Commit 30cb84c (verified) · Parent(s): 0097ce0
Respair committed: Upload 3 files
pkanade_24_multi_gpu_train_finetune_accelerate.py ADDED
@@ -0,0 +1,787 @@
1
+ # load packages
2
+ import random
3
+ import yaml
4
+ import time
5
+ from munch import Munch
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ import torchaudio
11
+ import librosa
12
+ import click
13
+ import shutil
14
+ import warnings
15
+ warnings.simplefilter('ignore')
16
+ from torch.utils.tensorboard import SummaryWriter
17
+
18
+ from meldataset import build_dataloader
19
+
20
+ from Utils.ASR.models import ASRCNN
21
+ from Utils.JDC.model import JDCNet
22
+ from Utils.PLBERT.util import load_plbert
23
+
24
+ from models import *
25
+ from losses import *
26
+ from utils import *
27
+
28
+ from Modules.slmadv import SLMAdversarialLoss
29
+ from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
30
+
31
+ from optimizers import build_optimizer
32
+
33
+
34
+ from accelerate import Accelerator, DistributedDataParallelKwargs
35
+ from accelerate.utils import tqdm, ProjectConfiguration
36
+
37
+
38
+
39
+
40
+ # # simple fix for dataparallel that allows access to class attributes
41
+ # class MyDataParallel(torch.nn.DataParallel):
42
+ # def __getattr__(self, name):
43
+ # try:
44
+ # return super().__getattr__(name)
45
+ # except AttributeError:
46
+ # return getattr(self.module, name)
47
+
48
+ # import logging
49
+ # from logging import StreamHandler
50
+ # logger = logging.getLogger(__name__)
51
+ # logger.setLevel(logging.DEBUG)
52
+ # handler = StreamHandler()
53
+ # handler.setLevel(logging.DEBUG)
54
+ # logger.addHandler(handler)
55
+
56
+ import logging
57
+ from accelerate.logging import get_logger
58
+ from logging import StreamHandler
59
+
60
+ logger = get_logger(__name__)
61
+ logger.setLevel(logging.DEBUG)
62
+
63
+ @click.command()
64
+ @click.option('-p', '--config_path', default='Configs/config_ft.yml', type=str)
65
+ def main(config_path):
66
+ config = yaml.safe_load(open(config_path))
67
+
68
+ log_dir = config['log_dir']
69
+ if not osp.exists(log_dir): os.makedirs(log_dir, exist_ok=True)
70
+ shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
71
+ writer = SummaryWriter(log_dir + "/tensorboard")
72
+
73
+ # write logs
74
+ file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
75
+ file_handler.setLevel(logging.DEBUG)
76
+ file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
77
+ logger.logger.addHandler(file_handler)
78
+
79
+ batch_size = config.get('batch_size', 10)
80
+
81
+ epochs = config.get('epochs', 200)
82
+ save_freq = config.get('save_freq', 2)
83
+ log_interval = config.get('log_interval', 10)
84
+ saving_epoch = config.get('save_freq', 2)
85
+
86
+ data_params = config.get('data_params', None)
87
+ sr = config['preprocess_params'].get('sr', 24000)
88
+ train_path = data_params['train_data']
89
+ val_path = data_params['val_data']
90
+ root_path = data_params['root_path']
91
+ min_length = data_params['min_length']
92
+ OOD_data = data_params['OOD_data']
93
+
94
+ max_len = config.get('max_len', 200)
95
+
96
+ loss_params = Munch(config['loss_params'])
97
+ diff_epoch = loss_params.diff_epoch
98
+ joint_epoch = loss_params.joint_epoch
99
+
100
+ optimizer_params = Munch(config['optimizer_params'])
101
+
102
+ train_list, val_list = get_data_path_list(train_path, val_path)
103
+
104
+ try:
105
+ tracker = data_params['logger']
106
+ except KeyError:
107
+ tracker = "mlflow"
108
+
109
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True, broadcast_buffers=False)
110
+ configAcc = ProjectConfiguration(project_dir=log_dir, logging_dir=log_dir)
111
+ accelerator = Accelerator(log_with=tracker,
112
+ project_config=configAcc,
113
+ split_batches=True,
114
+ kwargs_handlers=[ddp_kwargs],
115
+ mixed_precision='bf16')
116
+
117
+
118
+
119
+ device = accelerator.device
120
+
121
+
122
+ with accelerator.main_process_first():
123
+
124
+ train_dataloader = build_dataloader(train_list,
125
+ root_path,
126
+ OOD_data=OOD_data,
127
+ min_length=min_length,
128
+ batch_size=batch_size,
129
+ num_workers=2,
130
+ dataset_config={},
131
+ device=device)
132
+
133
+ val_dataloader = build_dataloader(val_list,
134
+ root_path,
135
+ OOD_data=OOD_data,
136
+ min_length=min_length,
137
+ batch_size=batch_size,
138
+ validation=True,
139
+ num_workers=0,
140
+ device=device,
141
+ dataset_config={})
142
+
143
+ # load pretrained ASR model
144
+ ASR_config = config.get('ASR_config', False)
145
+ ASR_path = config.get('ASR_path', False)
146
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
147
+
148
+ # load pretrained F0 model
149
+ F0_path = config.get('F0_path', False)
150
+ pitch_extractor = load_F0_models(F0_path)
151
+
152
+ # load PL-BERT model
153
+ BERT_path = config.get('PLBERT_dir', False)
154
+ plbert = load_plbert(BERT_path)
155
+
156
+ # build model
157
+ model_params = recursive_munch(config['model_params'])
158
+ multispeaker = model_params.multispeaker
159
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
160
+ _ = [model[key].to(device) for key in model]
161
+
162
+ # DP
163
+ for key in model:
164
+ if key != "mpd" and key != "msd" and key != "wd":
165
+ model[key] = accelerator.prepare(model[key])
166
+
167
+ start_epoch = 0
168
+ iters = 0
169
+
170
+ load_pretrained = config.get('pretrained_model', '') != '' and config.get('second_stage_load_pretrained', False)
171
+
172
+ if not load_pretrained:
173
+ if config.get('first_stage_path', '') != '':
174
+ first_stage_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
175
+ print('Loading the first stage model at %s ...' % first_stage_path)
176
+ model, _, start_epoch, iters = load_checkpoint(model,
177
+ None,
178
+ first_stage_path,
179
+ load_only_params=True,
180
+ ignore_modules=['bert', 'bert_encoder', 'predictor', 'predictor_encoder', 'msd', 'mpd', 'wd', 'diffusion']) # keep starting epoch for tensorboard log
181
+
182
+ # these epochs should be counted from the start epoch
183
+ diff_epoch += start_epoch
184
+ joint_epoch += start_epoch
185
+ epochs += start_epoch
186
+
187
+ model.predictor_encoder = copy.deepcopy(model.style_encoder)
188
+ else:
189
+ raise ValueError('You need to specify the path to the first stage model.')
190
+
191
+ gl = GeneratorLoss(model.mpd, model.msd).to(device)
192
+ dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
193
+ wl = WavLMLoss(model_params.slm.model,
194
+ model.wd,
195
+ sr,
196
+ model_params.slm.sr).to(device)
197
+
198
+ gl = accelerator.prepare(gl)
199
+ dl = accelerator.prepare(dl)
200
+ wl = accelerator.prepare(wl)
201
+
202
+ sampler = DiffusionSampler(
203
+ model.diffusion.module.diffusion,
204
+ sampler=ADPM2Sampler(),
205
+ sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
206
+ clamp=False
207
+ )
208
+
209
+ scheduler_params = {
210
+ "max_lr": optimizer_params.lr,
211
+ "pct_start": float(0),
212
+ "epochs": epochs,
213
+ "steps_per_epoch": len(train_dataloader),
214
+ }
215
+ scheduler_params_dict= {key: scheduler_params.copy() for key in model}
216
+ scheduler_params_dict['bert']['max_lr'] = optimizer_params.bert_lr * 2
217
+ scheduler_params_dict['decoder']['max_lr'] = optimizer_params.ft_lr * 2
218
+ scheduler_params_dict['style_encoder']['max_lr'] = optimizer_params.ft_lr * 2
219
+
220
+ optimizer = build_optimizer({key: model[key].parameters() for key in model},
221
+ scheduler_params_dict=scheduler_params_dict, lr=optimizer_params.lr)
222
+
223
+ # adjust BERT learning rate
224
+ for g in optimizer.optimizers['bert'].param_groups:
225
+ g['betas'] = (0.9, 0.99)
226
+ g['lr'] = optimizer_params.bert_lr
227
+ g['initial_lr'] = optimizer_params.bert_lr
228
+ g['min_lr'] = 0
229
+ g['weight_decay'] = 0.01
230
+
231
+ # adjust acoustic module learning rate
232
+ for module in ["decoder", "style_encoder"]:
233
+ for g in optimizer.optimizers[module].param_groups:
234
+ g['betas'] = (0.0, 0.99)
235
+ g['lr'] = optimizer_params.ft_lr
236
+ g['initial_lr'] = optimizer_params.ft_lr
237
+ g['min_lr'] = 0
238
+ g['weight_decay'] = 1e-4
239
+
240
+ # load models if there is a model
241
+ if load_pretrained:
242
+ model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'],
243
+ load_only_params=config.get('load_only_params', True))
244
+
245
+ n_down = model.text_aligner.module.n_down
246
+
247
+ best_loss = float('inf') # best test loss
248
+ loss_train_record = list([])
249
+ loss_test_record = list([])
250
+ iters = 0
251
+
252
+ criterion = nn.L1Loss() # F0 loss (regression)
253
+ torch.cuda.empty_cache()
254
+
255
+ stft_loss = MultiResolutionSTFTLoss().to(device)
256
+
257
+ print('BERT', optimizer.optimizers['bert'])
258
+ print('decoder', optimizer.optimizers['decoder'])
259
+
260
+ start_ds = False
261
+
262
+ running_std = []
263
+
264
+ slmadv_params = Munch(config['slmadv_params'])
265
+ slmadv = SLMAdversarialLoss(model, wl, sampler,
266
+ slmadv_params.min_len,
267
+ slmadv_params.max_len,
268
+ batch_percentage=slmadv_params.batch_percentage,
269
+ skip_update=slmadv_params.iter,
270
+ sig=slmadv_params.sig
271
+ )
272
+
273
+ for k, v in optimizer.optimizers.items():
274
+ optimizer.optimizers[k] = accelerator.prepare(optimizer.optimizers[k])
275
+ optimizer.schedulers[k] = accelerator.prepare(optimizer.schedulers[k])
276
+
277
+ train_dataloader = accelerator.prepare(train_dataloader)
278
+ val_dataloader = accelerator.prepare(val_dataloader)
279
+
280
+ for epoch in range(start_epoch, epochs):
281
+ running_loss = 0
282
+ start_time = time.time()
283
+
284
+ _ = [model[key].eval() for key in model]
285
+
286
+ model.text_aligner.train()
287
+ model.text_encoder.train()
288
+
289
+ model.predictor.train()
290
+ model.bert_encoder.train()
291
+ model.bert.train()
292
+ model.msd.train()
293
+ model.mpd.train()
294
+
295
+ for i, batch in enumerate(train_dataloader):
296
+ waves = batch[0]
297
+ batch = [b.to(device) for b in batch[1:]]
298
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
299
+ with torch.no_grad():
300
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
301
+ mel_mask = length_to_mask(mel_input_length).to(device)
302
+ text_mask = length_to_mask(input_lengths).to(texts.device)
303
+
304
+ # compute reference styles
305
+ if multispeaker and epoch >= diff_epoch:
306
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
307
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
308
+ ref = torch.cat([ref_ss, ref_sp], dim=1)
309
+
310
+ try:
311
+ ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
312
+ s2s_attn = s2s_attn.transpose(-1, -2)
313
+ s2s_attn = s2s_attn[..., 1:]
314
+ s2s_attn = s2s_attn.transpose(-1, -2)
315
+ except:
316
+ continue
317
+
318
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
319
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
320
+
321
+ # encode
322
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
323
+
324
+ # 50% chance of using the monotonic version
325
+ if bool(random.getrandbits(1)):
326
+ asr = (t_en @ s2s_attn)
327
+ else:
328
+ asr = (t_en @ s2s_attn_mono)
329
+
330
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
331
+
332
+ # compute the style of the entire utterance
333
+ # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
334
+ ss = []
335
+ gs = []
336
+ for bib in range(len(mel_input_length)):
337
+ mel_length = int(mel_input_length[bib].item())
338
+ mel = mels[bib, :, :mel_input_length[bib]]
339
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
340
+ ss.append(s)
341
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
342
+ gs.append(s)
343
+
344
+ s_dur = torch.stack(ss).squeeze() # global prosodic styles
345
+ gs = torch.stack(gs).squeeze() # global acoustic styles
346
+ s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser
347
+
348
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
349
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
350
+
351
+ # denoiser training
352
+ if epoch >= diff_epoch:
353
+ num_steps = np.random.randint(3, 5)
354
+
355
+ if model_params.diffusion.dist.estimate_sigma_data:
356
+ model.diffusion.module.diffusion.sigma_data = s_trg.std(axis=-1).mean().item() # batch-wise std estimation
357
+ running_std.append(model.diffusion.module.diffusion.sigma_data)
358
+
359
+ if multispeaker:
360
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
361
+ embedding=bert_dur,
362
+ embedding_scale=1,
363
+ features=ref, # reference from the same speaker as the embedding
364
+ embedding_mask_proba=0.1,
365
+ num_steps=num_steps).squeeze(1)
366
+ loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean() # EDM loss
367
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
368
+ else:
369
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
370
+ embedding=bert_dur,
371
+ embedding_scale=1,
372
+ embedding_mask_proba=0.1,
373
+ num_steps=num_steps).squeeze(1)
374
+ loss_diff = model.diffusion.module.diffusion(s_trg.unsqueeze(1), embedding=bert_dur).mean() # EDM loss
375
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
376
+ else:
377
+ loss_sty = 0
378
+ loss_diff = 0
379
+
380
+
381
+ s_loss = 0
382
+
383
+
384
+ d, p = model.predictor(d_en, s_dur,
385
+ input_lengths,
386
+ s2s_attn_mono,
387
+ text_mask)
388
+
389
+ mel_len_st = int(mel_input_length.min().item() / 2 - 1)
390
+
391
+
392
+ mel_input_length_all = accelerator.gather(mel_input_length) # for balanced load
393
+ mel_len = min([int(mel_input_length_all.min().item() / 2 - 1), max_len // 2])
394
+
395
+
396
+ en = []
397
+ gt = []
398
+ p_en = []
399
+ wav = []
400
+ st = []
401
+
402
+ for bib in range(len(mel_input_length)):
403
+ mel_length = int(mel_input_length[bib].item() / 2)
404
+
405
+ random_start = np.random.randint(0, mel_length - mel_len)
406
+ en.append(asr[bib, :, random_start:random_start+mel_len])
407
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
408
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
409
+
410
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
411
+ wav.append(torch.from_numpy(y).to(device))
412
+
413
+ # style reference (better to be different from the GT)
414
+ random_start = np.random.randint(0, mel_length - mel_len_st)
415
+ st.append(mels[bib, :, (random_start * 2):((random_start+mel_len_st) * 2)])
416
+
417
+ wav = torch.stack(wav).float().detach()
418
+
419
+ en = torch.stack(en)
420
+ p_en = torch.stack(p_en)
421
+ gt = torch.stack(gt).detach()
422
+ st = torch.stack(st).detach()
423
+
424
+
425
+ if gt.size(-1) < 80:
426
+ continue
427
+
428
+ s = model.style_encoder(gt.unsqueeze(1))
429
+ s_dur = model.predictor_encoder(gt.unsqueeze(1))
430
+
431
+ with torch.no_grad():
432
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
433
+ F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
434
+
435
+ N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
436
+
437
+ y_rec_gt = wav.unsqueeze(1)
438
+ y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
439
+
440
+ wav = y_rec_gt
441
+
442
+ F0_fake, N_fake = model.predictor(texts=p_en, style=s, f0=True)
443
+
444
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
445
+
446
+ loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10
447
+ loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)
448
+
449
+ optimizer.zero_grad()
450
+ d_loss = dl(wav.detach(), y_rec.detach()).mean()
451
+ accelerator.backward(d_loss)
452
+ optimizer.step('msd')
453
+ optimizer.step('mpd')
454
+
455
+ # generator loss
456
+ optimizer.zero_grad()
457
+
458
+ loss_mel = stft_loss(y_rec, wav)
459
+ loss_gen_all = gl(wav, y_rec).mean()
460
+ loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()
461
+
462
+ loss_ce = 0
463
+ loss_dur = 0
464
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
465
+ _s2s_pred = _s2s_pred[:_text_length, :]
466
+ _text_input = _text_input[:_text_length].long()
467
+ _s2s_trg = torch.zeros_like(_s2s_pred)
468
+ for p in range(_s2s_trg.shape[0]):
469
+ _s2s_trg[p, :_text_input[p]] = 1
470
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
471
+
472
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
473
+ _text_input[1:_text_length-1])
474
+ loss_ce += F.binary_cross_entropy_with_logits(_s2s_pred.flatten(), _s2s_trg.flatten())
475
+
476
+ loss_ce /= texts.size(0)
477
+ loss_dur /= texts.size(0)
478
+
479
+ loss_s2s = 0
480
+ for _s2s_pred, _text_input, _text_length in zip(s2s_pred, texts, input_lengths):
481
+ loss_s2s += F.cross_entropy(_s2s_pred[:_text_length], _text_input[:_text_length])
482
+ loss_s2s /= texts.size(0)
483
+
484
+ loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10
485
+
486
+ g_loss = loss_params.lambda_mel * loss_mel + \
487
+ loss_params.lambda_F0 * loss_F0_rec + \
488
+ loss_params.lambda_ce * loss_ce + \
489
+ loss_params.lambda_norm * loss_norm_rec + \
490
+ loss_params.lambda_dur * loss_dur + \
491
+ loss_params.lambda_gen * loss_gen_all + \
492
+ loss_params.lambda_slm * loss_lm + \
493
+ loss_params.lambda_sty * loss_sty + \
494
+ loss_params.lambda_diff * loss_diff + \
495
+ loss_params.lambda_mono * loss_mono + \
496
+ loss_params.lambda_s2s * loss_s2s
497
+
498
+ running_loss += accelerator.gather(loss_mel).mean().item()
499
+ accelerator.backward(g_loss)
500
+
501
+ # if torch.isnan(g_loss):
502
+ # from IPython.core.debugger import set_trace
503
+ # set_trace()
504
+
505
+ optimizer.step('bert_encoder')
506
+ optimizer.step('bert')
507
+ optimizer.step('predictor')
508
+ optimizer.step('predictor_encoder')
509
+ optimizer.step('style_encoder')
510
+ optimizer.step('decoder')
511
+
512
+ optimizer.step('text_encoder')
513
+ optimizer.step('text_aligner')
514
+
515
+ if epoch >= diff_epoch:
516
+ optimizer.step('diffusion')
517
+
518
+ d_loss_slm, loss_gen_lm = 0, 0
519
+ if epoch >= joint_epoch:
520
+ # randomly pick whether to use in-distribution text
521
+ if np.random.rand() < 0.5:
522
+ use_ind = True
523
+ else:
524
+ use_ind = False
525
+
526
+ if use_ind:
527
+ ref_lengths = input_lengths
528
+ ref_texts = texts
529
+
530
+ slm_out = slmadv(i,
531
+ y_rec_gt,
532
+ y_rec_gt_pred,
533
+ waves,
534
+ mel_input_length,
535
+ ref_texts,
536
+ ref_lengths, use_ind, s_trg.detach(), ref if multispeaker else None)
537
+
538
+ if slm_out is not None:
539
+ d_loss_slm, loss_gen_lm, y_pred = slm_out
540
+
541
+ # SLM generator loss
542
+ optimizer.zero_grad()
543
+ accelerator.backward(loss_gen_lm)
544
+
545
+ # compute the gradient norm
546
+ total_norm = {}
547
+ for key in model.keys():
548
+ total_norm[key] = 0
549
+ parameters = [p for p in model[key].parameters() if p.grad is not None and p.requires_grad]
550
+ for p in parameters:
551
+ param_norm = p.grad.detach().data.norm(2)
552
+ total_norm[key] += param_norm.item() ** 2
553
+ total_norm[key] = total_norm[key] ** 0.5
554
+
555
+ # gradient scaling
556
+ if total_norm['predictor'] > slmadv_params.thresh:
557
+ for key in model.keys():
558
+ for p in model[key].parameters():
559
+ if p.grad is not None:
560
+ p.grad *= (1 / total_norm['predictor'])
561
+
562
+ for p in model.predictor.duration_proj.parameters():
563
+ if p.grad is not None:
564
+ p.grad *= slmadv_params.scale
565
+
566
+ for p in model.predictor.lstm.parameters():
567
+ if p.grad is not None:
568
+ p.grad *= slmadv_params.scale
569
+
570
+ for p in model.diffusion.parameters():
571
+ if p.grad is not None:
572
+ p.grad *= slmadv_params.scale
573
+
574
+ optimizer.step('bert_encoder')
575
+ optimizer.step('bert')
576
+ optimizer.step('predictor')
577
+ optimizer.step('diffusion')
578
+
579
+ # SLM discriminator loss
580
+ if d_loss_slm != 0:
581
+ optimizer.zero_grad()
582
+ accelerator.backward(d_loss_slm)
583
+ optimizer.step('wd')
584
+
585
+ iters = iters + 1
586
+
587
+ if (i + 1) % log_interval == 0:
588
+ logger.info ('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f, SLoss: %.5f, S2S Loss: %.5f, Mono Loss: %.5f'
589
+ %(epoch+1, epochs, i+1, len(train_list)//batch_size, running_loss / log_interval, d_loss, loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff, d_loss_slm, loss_gen_lm, s_loss, loss_s2s, loss_mono), main_process_only=True)
590
+ if accelerator.is_main_process:
591
+ print ('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f, SLoss: %.5f, S2S Loss: %.5f, Mono Loss: %.5f'
592
+ %(epoch+1, epochs, i+1, len(train_list)//batch_size, running_loss / log_interval, d_loss, loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff, d_loss_slm, loss_gen_lm, s_loss, loss_s2s, loss_mono))
593
+ accelerator.log({'train/mel_loss': float(running_loss / log_interval),
594
+ 'train/gen_loss': float(loss_gen_all),
595
+ 'train/d_loss': float(d_loss),
596
+ 'train/ce_loss': float(loss_ce),
597
+ 'train/dur_loss': float(loss_dur),
598
+ 'train/slm_loss': float(loss_lm),
599
+ 'train/norm_loss': float(loss_norm_rec),
600
+ 'train/F0_loss': float(loss_F0_rec),
601
+ 'train/sty_loss': float(loss_sty),
602
+ 'train/diff_loss': float(loss_diff),
603
+ 'train/d_loss_slm': float(d_loss_slm),
604
+ 'train/gen_loss_slm': float(loss_gen_lm),
605
+ 'epoch': int(epoch) + 1}, step=iters)
606
+
607
+ running_loss = 0
608
+
609
+ accelerator.print('Time elapsed:', time.time() - start_time)
610
+
611
+ loss_test = 0
612
+ loss_align = 0
613
+ loss_f = 0
614
+ _ = [model[key].eval() for key in model]
615
+
616
+ with torch.no_grad():
617
+ iters_test = 0
618
+ for batch_idx, batch in enumerate(val_dataloader):
619
+ optimizer.zero_grad()
620
+
621
+ try:
622
+ waves = batch[0]
623
+ batch = [b.to(device) for b in batch[1:]]
624
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
625
+ with torch.no_grad():
626
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
627
+ text_mask = length_to_mask(input_lengths).to(texts.device)
628
+
629
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
630
+ s2s_attn = s2s_attn.transpose(-1, -2)
631
+ s2s_attn = s2s_attn[..., 1:]
632
+ s2s_attn = s2s_attn.transpose(-1, -2)
633
+
634
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
635
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
636
+
637
+ # encode
638
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
639
+ asr = (t_en @ s2s_attn_mono)
640
+
641
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
642
+
643
+ ss = []
644
+ gs = []
645
+
646
+ for bib in range(len(mel_input_length)):
647
+ mel_length = int(mel_input_length[bib].item())
648
+ mel = mels[bib, :, :mel_input_length[bib]]
649
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
650
+ ss.append(s)
651
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
652
+ gs.append(s)
653
+
654
+ s = torch.stack(ss).squeeze()
655
+ gs = torch.stack(gs).squeeze()
656
+ s_trg = torch.cat([s, gs], dim=-1).detach()
657
+
658
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
659
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
660
+ d, p = model.predictor(d_en, s,
661
+ input_lengths,
662
+ s2s_attn_mono,
663
+ text_mask)
664
+ # get clips
665
+ mel_len = int(mel_input_length.min().item() / 2 - 1)
666
+ en = []
667
+ gt = []
668
+
669
+ p_en = []
670
+ wav = []
671
+
672
+ for bib in range(len(mel_input_length)):
673
+ mel_length = int(mel_input_length[bib].item() / 2)
674
+
675
+ random_start = np.random.randint(0, mel_length - mel_len)
676
+ en.append(asr[bib, :, random_start:random_start+mel_len])
677
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
678
+
679
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
680
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
681
+ wav.append(torch.from_numpy(y).to(device))
682
+
683
+ wav = torch.stack(wav).float().detach()
684
+
685
+ en = torch.stack(en)
686
+ p_en = torch.stack(p_en)
687
+ gt = torch.stack(gt).detach()
688
+ s = model.predictor_encoder(gt.unsqueeze(1))
689
+
690
+ F0_fake, N_fake = model.predictor(texts=p_en, style=s, f0=True)
691
+
692
+ loss_dur = 0
693
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
694
+ _s2s_pred = _s2s_pred[:_text_length, :]
695
+ _text_input = _text_input[:_text_length].long()
696
+ _s2s_trg = torch.zeros_like(_s2s_pred)
697
+ for bib in range(_s2s_trg.shape[0]):
698
+ _s2s_trg[bib, :_text_input[bib]] = 1
699
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
700
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
701
+ _text_input[1:_text_length-1])
702
+
703
+ loss_dur /= texts.size(0)
704
+
705
+ s = model.style_encoder(gt.unsqueeze(1))
706
+
707
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
708
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
709
+
710
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
711
+
712
+ loss_F0 = F.l1_loss(F0_real, F0_fake) / 10
713
+
714
+
715
+
716
+ loss_test += (loss_mel).mean()
717
+ loss_align += (loss_dur).mean()
718
+ loss_f += (loss_F0).mean()
719
+
720
+
721
+ iters_test += 1
722
+ except:
723
+ continue
724
+
725
+ accelerator.print('Epochs:', epoch + 1)
726
+ accelerator.print('iters test:', iters_test)
727
+ try:
728
+ logger.info('Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f' % (
729
+ loss_test / iters_test, loss_align / iters_test, loss_f / iters_test) + '\n', main_process_only=True)
730
+
731
+
732
+ accelerator.log({'eval/mel_loss': float(loss_test / iters_test),
733
+ 'eval/dur_loss': float(loss_align / iters_test),
734
+ 'eval/F0_loss': float(loss_f / iters_test)},
735
+ step=(i + 1) * (epoch + 1))
736
+ except ZeroDivisionError:
737
+ accelerator.print("Eval loss was divided by zero... skipping eval cycle")
738
+
739
+ if epoch % saving_epoch == 0:
740
+ if (loss_test / iters_test) < best_loss:
741
+ best_loss = loss_test / iters_test
742
+ try:
743
+ accelerator.print('Saving..')
744
+ state = {
745
+ 'net': {key: model[key].state_dict() for key in model},
746
+ 'optimizer': optimizer.state_dict(),
747
+ 'iters': iters,
748
+ 'val_loss': loss_test / iters_test,
749
+ 'epoch': epoch,
750
+ }
751
+ except ZeroDivisionError:
752
+ accelerator.print('No iter test, Re-Saving..')
753
+ state = {
754
+ 'net': {key: model[key].state_dict() for key in model},
755
+ 'optimizer': optimizer.state_dict(),
756
+ 'iters': iters,
757
+ 'val_loss': 0.1, # not zero just in case
758
+ 'epoch': epoch,
759
+ }
760
+
761
+ if accelerator.is_main_process:
762
+ save_path = osp.join(log_dir, 'epoch_2nd_%05d.pth' % epoch)
763
+ torch.save(state, save_path)
764
+
765
+ # if estimating sigma, save the estimated sigma
766
+ if model_params.diffusion.dist.estimate_sigma_data:
767
+ config['model_params']['diffusion']['dist']['sigma_data'] = float(np.mean(running_std))
768
+
769
+ with open(osp.join(log_dir, osp.basename(config_path)), 'w') as outfile:
770
+ yaml.dump(config, outfile, default_flow_style=True)
771
+ if accelerator.is_main_process:
772
+ print('Saving last pth..')
773
+ state = {
774
+ 'net': {key: model[key].state_dict() for key in model},
775
+ 'optimizer': optimizer.state_dict(),
776
+ 'iters': iters,
777
+ 'val_loss': loss_test / iters_test,
778
+ 'epoch': epoch,
779
+ }
780
+ save_path = osp.join(log_dir, '2nd_phase_last.pth')
781
+ torch.save(state, save_path)
782
+
783
+ accelerator.end_training()
784
+
785
+
786
+ if __name__ == "__main__":
787
+ main()
pkanade_24_train_finetune.py ADDED
@@ -0,0 +1,707 @@
1
+ # load packages
2
+ import random
3
+ import yaml
4
+ import time
5
+ from munch import Munch
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ import torchaudio
11
+ import librosa
12
+ import click
13
+ import shutil
14
+ import warnings
15
+ warnings.simplefilter('ignore')
16
+ from torch.utils.tensorboard import SummaryWriter
17
+
18
+ from meldataset import build_dataloader
19
+
20
+ from Utils.ASR.models import ASRCNN
21
+ from Utils.JDC.model import JDCNet
22
+ from Utils.PLBERT.util import load_plbert
23
+
24
+ from models import *
25
+ from losses import *
26
+ from utils import *
27
+
28
+ from Modules.slmadv import SLMAdversarialLoss
29
+ from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
30
+
31
+ from optimizers import build_optimizer
32
+
33
+ # simple fix for dataparallel that allows access to class attributes
34
+ class MyDataParallel(torch.nn.DataParallel):
35
+ def __getattr__(self, name):
36
+ try:
37
+ return super().__getattr__(name)
38
+ except AttributeError:
39
+ return getattr(self.module, name)
40
+
41
+ import logging
42
+ from logging import StreamHandler
43
+ logger = logging.getLogger(__name__)
44
+ logger.setLevel(logging.DEBUG)
45
+ handler = StreamHandler()
46
+ handler.setLevel(logging.DEBUG)
47
+ logger.addHandler(handler)
48
+
49
+
50
+ @click.command()
51
+ @click.option('-p', '--config_path', default='Configs/config_ft.yml', type=str)
52
+ def main(config_path):
53
+ config = yaml.safe_load(open(config_path))
54
+
55
+ log_dir = config['log_dir']
56
+ if not osp.exists(log_dir): os.makedirs(log_dir, exist_ok=True)
57
+ shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
58
+ writer = SummaryWriter(log_dir + "/tensorboard")
59
+
60
+ # write logs
61
+ file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
62
+ file_handler.setLevel(logging.DEBUG)
63
+ file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
64
+ logger.addHandler(file_handler)
65
+
66
+
67
+ batch_size = config.get('batch_size', 10)
68
+
69
+ epochs = config.get('epochs', 200)
70
+ save_freq = config.get('save_freq', 2)
71
+ log_interval = config.get('log_interval', 10)
72
+ saving_epoch = config.get('save_freq', 2)
73
+
74
+ data_params = config.get('data_params', None)
75
+ sr = config['preprocess_params'].get('sr', 24000)
76
+ train_path = data_params['train_data']
77
+ val_path = data_params['val_data']
78
+ root_path = data_params['root_path']
79
+ min_length = data_params['min_length']
80
+ OOD_data = data_params['OOD_data']
81
+
82
+ max_len = config.get('max_len', 200)
83
+
84
+ loss_params = Munch(config['loss_params'])
85
+ diff_epoch = loss_params.diff_epoch
86
+ joint_epoch = loss_params.joint_epoch
87
+
88
+ optimizer_params = Munch(config['optimizer_params'])
89
+
90
+ train_list, val_list = get_data_path_list(train_path, val_path)
91
+ device = 'cuda'
92
+
93
+ train_dataloader = build_dataloader(train_list,
94
+ root_path,
95
+ OOD_data=OOD_data,
96
+ min_length=min_length,
97
+ batch_size=batch_size,
98
+ num_workers=2,
99
+ dataset_config={},
100
+ device=device)
101
+
102
+ val_dataloader = build_dataloader(val_list,
103
+ root_path,
104
+ OOD_data=OOD_data,
105
+ min_length=min_length,
106
+ batch_size=batch_size,
107
+ validation=True,
108
+ num_workers=0,
109
+ device=device,
110
+ dataset_config={})
111
+
112
+ # load pretrained ASR model
113
+ ASR_config = config.get('ASR_config', False)
114
+ ASR_path = config.get('ASR_path', False)
115
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
116
+
117
+ # load pretrained F0 model
118
+ F0_path = config.get('F0_path', False)
119
+ pitch_extractor = load_F0_models(F0_path)
120
+
121
+ # load PL-BERT model
122
+ BERT_path = config.get('PLBERT_dir', False)
123
+ plbert = load_plbert(BERT_path)
124
+
125
+ # build model
126
+ model_params = recursive_munch(config['model_params'])
127
+ multispeaker = model_params.multispeaker
128
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
129
+ _ = [model[key].to(device) for key in model]
130
+
131
+ # DP
132
+ for key in model:
133
+ if key != "mpd" and key != "msd" and key != "wd":
134
+ model[key] = MyDataParallel(model[key])
135
+
136
+ start_epoch = 0
137
+ iters = 0
138
+
139
+ load_pretrained = config.get('pretrained_model', '') != '' and config.get('second_stage_load_pretrained', False)
140
+
141
+ if not load_pretrained:
142
+ if config.get('first_stage_path', '') != '':
143
+ first_stage_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
144
+ print('Loading the first stage model at %s ...' % first_stage_path)
145
+ model, _, start_epoch, iters = load_checkpoint(model,
146
+ None,
147
+ first_stage_path,
148
+ load_only_params=True,
149
+ ignore_modules=['bert', 'bert_encoder', 'predictor', 'predictor_encoder', 'msd', 'mpd', 'wd', 'diffusion']) # keep starting epoch for tensorboard log
150
+
151
+ # these epochs should be counted from the start epoch
152
+ diff_epoch += start_epoch
153
+ joint_epoch += start_epoch
154
+ epochs += start_epoch
155
+
156
+ model.predictor_encoder = copy.deepcopy(model.style_encoder)
157
+ else:
158
+ raise ValueError('You need to specify the path to the first stage model.')
159
+
160
+ gl = GeneratorLoss(model.mpd, model.msd).to(device)
161
+ dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
162
+ wl = WavLMLoss(model_params.slm.model,
163
+ model.wd,
164
+ sr,
165
+ model_params.slm.sr).to(device)
166
+
167
+ gl = MyDataParallel(gl)
168
+ dl = MyDataParallel(dl)
169
+ wl = MyDataParallel(wl)
170
+
171
+ sampler = DiffusionSampler(
172
+ model.diffusion.diffusion,
173
+ sampler=ADPM2Sampler(),
174
+ sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
175
+ clamp=False
176
+ )
177
+
178
+ scheduler_params = {
179
+ "max_lr": optimizer_params.lr,
180
+ "pct_start": float(0),
181
+ "epochs": epochs,
182
+ "steps_per_epoch": len(train_dataloader),
183
+ }
184
+ scheduler_params_dict= {key: scheduler_params.copy() for key in model}
185
+ scheduler_params_dict['bert']['max_lr'] = optimizer_params.bert_lr * 2
186
+ scheduler_params_dict['decoder']['max_lr'] = optimizer_params.ft_lr * 2
187
+ scheduler_params_dict['style_encoder']['max_lr'] = optimizer_params.ft_lr * 2
188
+
189
+ optimizer = build_optimizer({key: model[key].parameters() for key in model},
190
+ scheduler_params_dict=scheduler_params_dict, lr=optimizer_params.lr)
191
+
192
+ # adjust BERT learning rate
193
+ for g in optimizer.optimizers['bert'].param_groups:
194
+ g['betas'] = (0.9, 0.99)
195
+ g['lr'] = optimizer_params.bert_lr
196
+ g['initial_lr'] = optimizer_params.bert_lr
197
+ g['min_lr'] = 0
198
+ g['weight_decay'] = 0.01
199
+
200
+ # adjust acoustic module learning rate
201
+ for module in ["decoder", "style_encoder"]:
202
+ for g in optimizer.optimizers[module].param_groups:
203
+ g['betas'] = (0.0, 0.99)
204
+ g['lr'] = optimizer_params.ft_lr
205
+ g['initial_lr'] = optimizer_params.ft_lr
206
+ g['min_lr'] = 0
207
+ g['weight_decay'] = 1e-4
208
+
209
+ # load models if there is a model
210
+ if load_pretrained:
211
+ model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'],
212
+ load_only_params=config.get('load_only_params', True))
213
+
214
+ n_down = model.text_aligner.n_down
215
+
216
+ best_loss = float('inf') # best test loss
217
+ loss_train_record = list([])
218
+ loss_test_record = list([])
219
+ iters = 0
220
+
221
+ criterion = nn.L1Loss() # F0 loss (regression)
222
+ torch.cuda.empty_cache()
223
+
224
+ stft_loss = MultiResolutionSTFTLoss().to(device)
225
+
226
+ print('BERT', optimizer.optimizers['bert'])
227
+ print('decoder', optimizer.optimizers['decoder'])
228
+
229
+ start_ds = False
230
+
231
+ running_std = []
232
+
233
+ slmadv_params = Munch(config['slmadv_params'])
234
+ slmadv = SLMAdversarialLoss(model, wl, sampler,
235
+ slmadv_params.min_len,
236
+ slmadv_params.max_len,
237
+ batch_percentage=slmadv_params.batch_percentage,
238
+ skip_update=slmadv_params.iter,
239
+ sig=slmadv_params.sig
240
+ )
241
+
242
+
243
+ for epoch in range(start_epoch, epochs):
244
+ running_loss = 0
245
+ start_time = time.time()
246
+
247
+ _ = [model[key].eval() for key in model]
248
+
249
+ model.text_aligner.train()
250
+ model.text_encoder.train()
251
+
252
+ model.predictor.train()
253
+ model.bert_encoder.train()
254
+ model.bert.train()
255
+ model.msd.train()
256
+ model.mpd.train()
257
+
258
+ for i, batch in enumerate(train_dataloader):
259
+ waves = batch[0]
260
+ batch = [b.to(device) for b in batch[1:]]
261
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
262
+ with torch.no_grad():
263
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
264
+ mel_mask = length_to_mask(mel_input_length).to(device)
265
+ text_mask = length_to_mask(input_lengths).to(texts.device)
266
+
267
+ # compute reference styles
268
+ if multispeaker and epoch >= diff_epoch:
269
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
270
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
271
+ ref = torch.cat([ref_ss, ref_sp], dim=1)
272
+
273
+ try:
274
+ ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
275
+ s2s_attn = s2s_attn.transpose(-1, -2)
276
+ s2s_attn = s2s_attn[..., 1:]
277
+ s2s_attn = s2s_attn.transpose(-1, -2)
278
+ except:
279
+ continue
280
+
281
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
282
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
283
+
284
+ # encode
285
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
286
+
287
+ # 50% chance of using the monotonic version
288
+ if bool(random.getrandbits(1)):
289
+ asr = (t_en @ s2s_attn)
290
+ else:
291
+ asr = (t_en @ s2s_attn_mono)
292
+
293
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
294
+
295
+ # compute the style of the entire utterance
296
+ # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
297
+ ss = []
298
+ gs = []
299
+ for bib in range(len(mel_input_length)):
300
+ mel_length = int(mel_input_length[bib].item())
301
+ mel = mels[bib, :, :mel_input_length[bib]]
302
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
303
+ ss.append(s)
304
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
305
+ gs.append(s)
306
+
307
+ s_dur = torch.stack(ss).squeeze() # global prosodic styles
308
+ gs = torch.stack(gs).squeeze() # global acoustic styles
309
+ s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser
310
+
311
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
312
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
313
+
314
+ # denoiser training
315
+ if epoch >= diff_epoch:
316
+ num_steps = np.random.randint(3, 5)
317
+
318
+ if model_params.diffusion.dist.estimate_sigma_data:
319
+ model.diffusion.module.diffusion.sigma_data = s_trg.std(axis=-1).mean().item() # batch-wise std estimation
320
+ running_std.append(model.diffusion.module.diffusion.sigma_data)
321
+
322
+ if multispeaker:
323
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
324
+ embedding=bert_dur,
325
+ embedding_scale=1,
326
+ features=ref, # reference from the same speaker as the embedding
327
+ embedding_mask_proba=0.1,
328
+ num_steps=num_steps).squeeze(1)
329
+ loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean() # EDM loss
330
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
331
+ else:
332
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
333
+ embedding=bert_dur,
334
+ embedding_scale=1,
335
+ embedding_mask_proba=0.1,
336
+ num_steps=num_steps).squeeze(1)
337
+ loss_diff = model.diffusion.module.diffusion(s_trg.unsqueeze(1), embedding=bert_dur).mean() # EDM loss
338
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
339
+ else:
340
+ loss_sty = 0
341
+ loss_diff = 0
342
+
343
+
344
+ s_loss = 0
345
+
346
+
347
+ d, p = model.predictor(d_en, s_dur,
348
+ input_lengths,
349
+ s2s_attn_mono,
350
+ text_mask)
351
+
352
+ mel_len_st = int(mel_input_length.min().item() / 2 - 1)
353
+ mel_len = min(int(mel_input_length.min().item() / 2 - 1), max_len // 2)
354
+ en = []
355
+ gt = []
356
+ p_en = []
357
+ wav = []
358
+ st = []
359
+
360
+ for bib in range(len(mel_input_length)):
361
+ mel_length = int(mel_input_length[bib].item() / 2)
362
+
363
+ random_start = np.random.randint(0, mel_length - mel_len)
364
+ en.append(asr[bib, :, random_start:random_start+mel_len])
365
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
366
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
367
+
368
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
369
+ wav.append(torch.from_numpy(y).to(device))
370
+
371
+ # style reference (better to be different from the GT)
372
+ random_start = np.random.randint(0, mel_length - mel_len_st)
373
+ st.append(mels[bib, :, (random_start * 2):((random_start+mel_len_st) * 2)])
374
+
375
+ wav = torch.stack(wav).float().detach()
376
+
377
+ en = torch.stack(en)
378
+ p_en = torch.stack(p_en)
379
+ gt = torch.stack(gt).detach()
380
+ st = torch.stack(st).detach()
381
+
382
+
383
+ if gt.size(-1) < 80:
384
+ continue
385
+
386
+ s = model.style_encoder(gt.unsqueeze(1))
387
+ s_dur = model.predictor_encoder(gt.unsqueeze(1))
388
+
389
+ with torch.no_grad():
390
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
391
+ F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
392
+
393
+ N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
394
+
395
+ y_rec_gt = wav.unsqueeze(1)
396
+ y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
397
+
398
+ wav = y_rec_gt
399
+
400
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur)
401
+
402
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
403
+
404
+ loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10
405
+ loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)
406
+
407
+ optimizer.zero_grad()
408
+ d_loss = dl(wav.detach(), y_rec.detach()).mean()
409
+ d_loss.backward()
410
+ optimizer.step('msd')
411
+ optimizer.step('mpd')
412
+
413
+ # generator loss
414
+ optimizer.zero_grad()
415
+
416
+ loss_mel = stft_loss(y_rec, wav)
417
+ loss_gen_all = gl(wav, y_rec).mean()
418
+ loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()
419
+
420
+ loss_ce = 0
421
+ loss_dur = 0
422
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
423
+ _s2s_pred = _s2s_pred[:_text_length, :]
424
+ _text_input = _text_input[:_text_length].long()
425
+ _s2s_trg = torch.zeros_like(_s2s_pred)
426
+ for p in range(_s2s_trg.shape[0]):
427
+ _s2s_trg[p, :_text_input[p]] = 1
428
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
429
+
430
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
431
+ _text_input[1:_text_length-1])
432
+ loss_ce += F.binary_cross_entropy_with_logits(_s2s_pred.flatten(), _s2s_trg.flatten())
433
+
434
+ loss_ce /= texts.size(0)
435
+ loss_dur /= texts.size(0)
436
+
437
+ loss_s2s = 0
438
+ for _s2s_pred, _text_input, _text_length in zip(s2s_pred, texts, input_lengths):
439
+ loss_s2s += F.cross_entropy(_s2s_pred[:_text_length], _text_input[:_text_length])
440
+ loss_s2s /= texts.size(0)
441
+
442
+ loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10
443
+
444
+ g_loss = loss_params.lambda_mel * loss_mel + \
445
+ loss_params.lambda_F0 * loss_F0_rec + \
446
+ loss_params.lambda_ce * loss_ce + \
447
+ loss_params.lambda_norm * loss_norm_rec + \
448
+ loss_params.lambda_dur * loss_dur + \
449
+ loss_params.lambda_gen * loss_gen_all + \
450
+ loss_params.lambda_slm * loss_lm + \
451
+ loss_params.lambda_sty * loss_sty + \
452
+ loss_params.lambda_diff * loss_diff + \
453
+ loss_params.lambda_mono * loss_mono + \
454
+ loss_params.lambda_s2s * loss_s2s
455
+
456
+ running_loss += loss_mel.item()
457
+ g_loss.backward()
458
+ if torch.isnan(g_loss):
459
+ from IPython.core.debugger import set_trace
460
+ set_trace()
461
+
462
+ optimizer.step('bert_encoder')
463
+ optimizer.step('bert')
464
+ optimizer.step('predictor')
465
+ optimizer.step('predictor_encoder')
466
+ optimizer.step('style_encoder')
467
+ optimizer.step('decoder')
468
+
469
+ optimizer.step('text_encoder')
470
+ optimizer.step('text_aligner')
471
+
472
+ if epoch >= diff_epoch:
473
+ optimizer.step('diffusion')
474
+
475
+ d_loss_slm, loss_gen_lm = 0, 0
476
+ if epoch >= joint_epoch:
477
+ # randomly pick whether to use in-distribution text
478
+ if np.random.rand() < 0.5:
479
+ use_ind = True
480
+ else:
481
+ use_ind = False
482
+
483
+ if use_ind:
484
+ ref_lengths = input_lengths
485
+ ref_texts = texts
486
+
487
+ slm_out = slmadv(i,
488
+ y_rec_gt,
489
+ y_rec_gt_pred,
490
+ waves,
491
+ mel_input_length,
492
+ ref_texts,
493
+ ref_lengths, use_ind, s_trg.detach(), ref if multispeaker else None)
494
+
495
+ if slm_out is not None:
496
+ d_loss_slm, loss_gen_lm, y_pred = slm_out
497
+
498
+ # SLM generator loss
499
+ optimizer.zero_grad()
500
+ loss_gen_lm.backward()
501
+
502
+ # compute the gradient norm
503
+ total_norm = {}
504
+ for key in model.keys():
505
+ total_norm[key] = 0
506
+ parameters = [p for p in model[key].parameters() if p.grad is not None and p.requires_grad]
507
+ for p in parameters:
508
+ param_norm = p.grad.detach().data.norm(2)
509
+ total_norm[key] += param_norm.item() ** 2
510
+ total_norm[key] = total_norm[key] ** 0.5
511
+
512
+ # gradient scaling
513
+ if total_norm['predictor'] > slmadv_params.thresh:
514
+ for key in model.keys():
515
+ for p in model[key].parameters():
516
+ if p.grad is not None:
517
+ p.grad *= (1 / total_norm['predictor'])
518
+
519
+ for p in model.predictor.duration_proj.parameters():
520
+ if p.grad is not None:
521
+ p.grad *= slmadv_params.scale
522
+
523
+ for p in model.predictor.lstm.parameters():
524
+ if p.grad is not None:
525
+ p.grad *= slmadv_params.scale
526
+
527
+ for p in model.diffusion.parameters():
528
+ if p.grad is not None:
529
+ p.grad *= slmadv_params.scale
530
+
531
+ optimizer.step('bert_encoder')
532
+ optimizer.step('bert')
533
+ optimizer.step('predictor')
534
+ optimizer.step('diffusion')
535
+
536
+ # SLM discriminator loss
537
+ if d_loss_slm != 0:
538
+ optimizer.zero_grad()
539
+ d_loss_slm.backward(retain_graph=True)
540
+ optimizer.step('wd')
541
+
542
+ iters = iters + 1
543
+
544
+ if (i+1)%log_interval == 0:
545
+ logger.info ('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f, SLoss: %.5f, S2S Loss: %.5f, Mono Loss: %.5f'
546
+ %(epoch+1, epochs, i+1, len(train_list)//batch_size, running_loss / log_interval, d_loss, loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff, d_loss_slm, loss_gen_lm, s_loss, loss_s2s, loss_mono))
547
+
548
+ writer.add_scalar('train/mel_loss', running_loss / log_interval, iters)
549
+ writer.add_scalar('train/gen_loss', loss_gen_all, iters)
550
+ writer.add_scalar('train/d_loss', d_loss, iters)
551
+ writer.add_scalar('train/ce_loss', loss_ce, iters)
552
+ writer.add_scalar('train/dur_loss', loss_dur, iters)
553
+ writer.add_scalar('train/slm_loss', loss_lm, iters)
554
+ writer.add_scalar('train/norm_loss', loss_norm_rec, iters)
555
+ writer.add_scalar('train/F0_loss', loss_F0_rec, iters)
556
+ writer.add_scalar('train/sty_loss', loss_sty, iters)
557
+ writer.add_scalar('train/diff_loss', loss_diff, iters)
558
+ writer.add_scalar('train/d_loss_slm', d_loss_slm, iters)
559
+ writer.add_scalar('train/gen_loss_slm', loss_gen_lm, iters)
560
+
561
+ running_loss = 0
562
+
563
+ print('Time elapsed:', time.time() - start_time)
564
+
565
+ loss_test = 0
566
+ loss_align = 0
567
+ loss_f = 0
568
+ _ = [model[key].eval() for key in model]
569
+
570
+ with torch.no_grad():
571
+ iters_test = 0
572
+ for batch_idx, batch in enumerate(val_dataloader):
573
+ optimizer.zero_grad()
574
+
575
+ try:
576
+ waves = batch[0]
577
+ batch = [b.to(device) for b in batch[1:]]
578
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
579
+ with torch.no_grad():
580
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to('cuda')
581
+ text_mask = length_to_mask(input_lengths).to(texts.device)
582
+
583
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
584
+ s2s_attn = s2s_attn.transpose(-1, -2)
585
+ s2s_attn = s2s_attn[..., 1:]
586
+ s2s_attn = s2s_attn.transpose(-1, -2)
587
+
588
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
589
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
590
+
591
+ # encode
592
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
593
+ asr = (t_en @ s2s_attn_mono)
594
+
595
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
596
+
597
+ ss = []
598
+ gs = []
599
+
600
+ for bib in range(len(mel_input_length)):
601
+ mel_length = int(mel_input_length[bib].item())
602
+ mel = mels[bib, :, :mel_input_length[bib]]
603
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
604
+ ss.append(s)
605
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
606
+ gs.append(s)
607
+
608
+ s = torch.stack(ss).squeeze()
609
+ gs = torch.stack(gs).squeeze()
610
+ s_trg = torch.cat([s, gs], dim=-1).detach()
611
+
612
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
613
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
614
+ d, p = model.predictor(d_en, s,
615
+ input_lengths,
616
+ s2s_attn_mono,
617
+ text_mask)
618
+ # get clips
619
+ mel_len = int(mel_input_length.min().item() / 2 - 1)
620
+ en = []
621
+ gt = []
622
+
623
+ p_en = []
624
+ wav = []
625
+
626
+ for bib in range(len(mel_input_length)):
627
+ mel_length = int(mel_input_length[bib].item() / 2)
628
+
629
+ random_start = np.random.randint(0, mel_length - mel_len)
630
+ en.append(asr[bib, :, random_start:random_start+mel_len])
631
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
632
+
633
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
634
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
635
+ wav.append(torch.from_numpy(y).to(device))
636
+
637
+ wav = torch.stack(wav).float().detach()
638
+
639
+ en = torch.stack(en)
640
+ p_en = torch.stack(p_en)
641
+ gt = torch.stack(gt).detach()
642
+ s = model.predictor_encoder(gt.unsqueeze(1))
643
+
644
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s)
645
+
646
+ loss_dur = 0
647
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
648
+ _s2s_pred = _s2s_pred[:_text_length, :]
649
+ _text_input = _text_input[:_text_length].long()
650
+ _s2s_trg = torch.zeros_like(_s2s_pred)
651
+ for bib in range(_s2s_trg.shape[0]):
652
+ _s2s_trg[bib, :_text_input[bib]] = 1
653
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
654
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
655
+ _text_input[1:_text_length-1])
656
+
657
+ loss_dur /= texts.size(0)
658
+
659
+ s = model.style_encoder(gt.unsqueeze(1))
660
+
661
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
662
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
663
+
664
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
665
+
666
+ loss_F0 = F.l1_loss(F0_real, F0_fake) / 10
667
+
668
+ loss_test += (loss_mel).mean()
669
+ loss_align += (loss_dur).mean()
670
+ loss_f += (loss_F0).mean()
671
+
672
+ iters_test += 1
673
+ except:
674
+ continue
675
+
676
+ print('Epochs:', epoch + 1)
677
+ logger.info('Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f' % (loss_test / iters_test, loss_align / iters_test, loss_f / iters_test) + '\n\n\n')
678
+ print('\n\n\n')
679
+ writer.add_scalar('eval/mel_loss', loss_test / iters_test, epoch + 1)
680
+ writer.add_scalar('eval/dur_loss', loss_align / iters_test, epoch + 1)
681
+ writer.add_scalar('eval/F0_loss', loss_f / iters_test, epoch + 1)
682
+
683
+
684
+ if (epoch + 1) % save_freq == 0 :
685
+ if (loss_test / iters_test) < best_loss:
686
+ best_loss = loss_test / iters_test
687
+ print('Saving..')
688
+ state = {
689
+ 'net': {key: model[key].state_dict() for key in model},
690
+ 'optimizer': optimizer.state_dict(),
691
+ 'iters': iters,
692
+ 'val_loss': loss_test / iters_test,
693
+ 'epoch': epoch,
694
+ }
695
+ save_path = osp.join(log_dir, 'epoch_2nd_%05d.pth' % epoch)
696
+ torch.save(state, save_path)
697
+
698
+ # if estimating sigma, save the estimated sigma
699
+ if model_params.diffusion.dist.estimate_sigma_data:
700
+ config['model_params']['diffusion']['dist']['sigma_data'] = float(np.mean(running_std))
701
+
702
+ with open(osp.join(log_dir, osp.basename(config_path)), 'w') as outfile:
703
+ yaml.dump(config, outfile, default_flow_style=True)
704
+
705
+
706
+ if __name__=="__main__":
707
+ main()
pkanade_24_train_finetune_accelerate.py ADDED
@@ -0,0 +1,788 @@
1
+ # load packages
2
+ import random
3
+ import yaml
4
+ import time
5
+ from munch import Munch
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ import torchaudio
11
+ import librosa
12
+ import click
13
+ import shutil
14
+ import warnings
15
+ warnings.simplefilter('ignore')
16
+ from torch.utils.tensorboard import SummaryWriter
17
+
18
+ from meldataset import build_dataloader
19
+
20
+ from Utils.ASR.models import ASRCNN
21
+ from Utils.JDC.model import JDCNet
22
+ from Utils.PLBERT.util import load_plbert
23
+
24
+ from models import *
25
+ from losses import *
26
+ from utils import *
27
+
28
+ from Modules.slmadv import SLMAdversarialLoss
29
+ from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
30
+
31
+ from optimizers import build_optimizer
32
+
33
+
34
+ from accelerate import Accelerator, DistributedDataParallelKwargs
35
+ from accelerate.utils import tqdm, ProjectConfiguration
36
+
37
+
38
+
39
+
40
+ # simple fix for dataparallel that allows access to class attributes
41
+ class MyDataParallel(torch.nn.DataParallel):
42
+ def __getattr__(self, name):
43
+ try:
44
+ return super().__getattr__(name)
45
+ except AttributeError:
46
+ return getattr(self.module, name)
47
+
48
+
49
+ # from logging import StreamHandler
50
+ # logger = logging.getLogger(__name__)
51
+ # logger.setLevel(logging.DEBUG)
52
+ # handler = StreamHandler()
53
+ # handler.setLevel(logging.DEBUG)
54
+ # logger.addHandler(handler)
55
+
56
+
57
+ import logging
58
+ from accelerate.logging import get_logger
59
+ from logging import StreamHandler
60
+
61
+ logger = get_logger(__name__)
62
+ logger.setLevel(logging.DEBUG)
63
+
64
+ @click.command()
65
+ @click.option('-p', '--config_path', default='Configs/config_ft.yml', type=str)
66
+ def main(config_path):
67
+ config = yaml.safe_load(open(config_path))
68
+
69
+ log_dir = config['log_dir']
70
+ if not osp.exists(log_dir): os.makedirs(log_dir, exist_ok=True)
71
+ shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
72
+ writer = SummaryWriter(log_dir + "/tensorboard")
73
+ # write logs
74
+ file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
75
+ file_handler.setLevel(logging.DEBUG)
76
+ file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
77
+ logger.logger.addHandler(file_handler)
78
+
79
+
80
+
81
+ batch_size = config.get('batch_size', 10)
82
+
83
+ epochs = config.get('epochs', 200)
84
+ save_freq = config.get('save_freq', 2)
85
+ log_interval = config.get('log_interval', 10)
86
+ saving_epoch = config.get('save_freq', 2)
87
+
88
+ data_params = config.get('data_params', None)
89
+ sr = config['preprocess_params'].get('sr', 24000)
90
+ train_path = data_params['train_data']
91
+ val_path = data_params['val_data']
92
+ root_path = data_params['root_path']
93
+ min_length = data_params['min_length']
94
+ OOD_data = data_params['OOD_data']
95
+
96
+ max_len = config.get('max_len', 200)
97
+
98
+ loss_params = Munch(config['loss_params'])
99
+ diff_epoch = loss_params.diff_epoch
100
+ joint_epoch = loss_params.joint_epoch
101
+
102
+ optimizer_params = Munch(config['optimizer_params'])
103
+
104
+ train_list, val_list = get_data_path_list(train_path, val_path)
105
+
106
+ try:
107
+ tracker = data_params['logger']
108
+ except KeyError:
109
+ tracker = "mlflow"
110
+
111
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True, broadcast_buffers=False)
112
+ configAcc = ProjectConfiguration(project_dir=log_dir, logging_dir=log_dir)
113
+ accelerator = Accelerator(log_with=tracker,
114
+ project_config=configAcc,
115
+ split_batches=True,
116
+ kwargs_handlers=[ddp_kwargs],
117
+ mixed_precision='bf16')
118
+
119
+
120
+
121
+ device = accelerator.device
122
+
123
+
124
+ with accelerator.main_process_first():
125
+
126
+ train_dataloader = build_dataloader(train_list,
127
+ root_path,
128
+ OOD_data=OOD_data,
129
+ min_length=min_length,
130
+ batch_size=batch_size,
131
+ num_workers=2,
132
+ dataset_config={},
133
+ device=device)
134
+
135
+ val_dataloader = build_dataloader(val_list,
136
+ root_path,
137
+ OOD_data=OOD_data,
138
+ min_length=min_length,
139
+ batch_size=batch_size,
140
+ validation=True,
141
+ num_workers=0,
142
+ device=device,
143
+ dataset_config={})
144
+
145
+ # load pretrained ASR model
146
+ ASR_config = config.get('ASR_config', False)
147
+ ASR_path = config.get('ASR_path', False)
148
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
149
+
150
+ # load pretrained F0 model
151
+ F0_path = config.get('F0_path', False)
152
+ pitch_extractor = load_F0_models(F0_path)
153
+
154
+ # load PL-BERT model
155
+ BERT_path = config.get('PLBERT_dir', False)
156
+ plbert = load_plbert(BERT_path)
157
+
158
+ # build model
159
+ model_params = recursive_munch(config['model_params'])
160
+ multispeaker = model_params.multispeaker
161
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
162
+ _ = [model[key].to(device) for key in model]
163
+
164
+ # DP
165
+ for key in model:
166
+ if key != "mpd" and key != "msd" and key != "wd":
167
+ model[key] = accelerator.prepare(model[key])
168
+
169
+ start_epoch = 0
170
+ iters = 0
171
+
172
+ load_pretrained = config.get('pretrained_model', '') != '' and config.get('second_stage_load_pretrained', False)
173
+
174
+ if not load_pretrained:
175
+ if config.get('first_stage_path', '') != '':
176
+ first_stage_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
177
+ print('Loading the first stage model at %s ...' % first_stage_path)
178
+ model, _, start_epoch, iters = load_checkpoint(model,
179
+ None,
180
+ first_stage_path,
181
+ load_only_params=True,
182
+ ignore_modules=['bert', 'bert_encoder', 'predictor', 'predictor_encoder', 'msd', 'mpd', 'wd', 'diffusion']) # keep starting epoch for tensorboard log
183
+
184
+ # these epochs should be counted from the start epoch
185
+ diff_epoch += start_epoch
186
+ joint_epoch += start_epoch
187
+ epochs += start_epoch
188
+
189
+ model.predictor_encoder = copy.deepcopy(model.style_encoder)
190
+ else:
191
+ raise ValueError('You need to specify the path to the first stage model.')
192
+
193
+ gl = GeneratorLoss(model.mpd, model.msd).to(device)
194
+ dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
195
+ wl = WavLMLoss(model_params.slm.model,
196
+ model.wd,
197
+ sr,
198
+ model_params.slm.sr).to(device)
199
+
200
+ gl = accelerator.prepare(gl)
201
+ dl = accelerator.prepare(dl)
202
+ wl = accelerator.prepare(wl)
203
+
204
+ sampler = DiffusionSampler(
205
+ model.diffusion.diffusion,
206
+ sampler=ADPM2Sampler(),
207
+ sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
208
+ clamp=False
209
+ )
210
+
211
+ scheduler_params = {
212
+ "max_lr": optimizer_params.lr,
213
+ "pct_start": float(0),
214
+ "epochs": epochs,
215
+ "steps_per_epoch": len(train_dataloader),
216
+ }
217
+ scheduler_params_dict= {key: scheduler_params.copy() for key in model}
218
+ scheduler_params_dict['bert']['max_lr'] = optimizer_params.bert_lr * 2
219
+ scheduler_params_dict['decoder']['max_lr'] = optimizer_params.ft_lr * 2
220
+ scheduler_params_dict['style_encoder']['max_lr'] = optimizer_params.ft_lr * 2
221
+
222
+ optimizer = build_optimizer({key: model[key].parameters() for key in model},
223
+ scheduler_params_dict=scheduler_params_dict, lr=optimizer_params.lr)
224
+
225
+ # adjust BERT learning rate
226
+ for g in optimizer.optimizers['bert'].param_groups:
227
+ g['betas'] = (0.9, 0.99)
228
+ g['lr'] = optimizer_params.bert_lr
229
+ g['initial_lr'] = optimizer_params.bert_lr
230
+ g['min_lr'] = 0
231
+ g['weight_decay'] = 0.01
232
+
233
+ # adjust acoustic module learning rate
234
+ for module in ["decoder", "style_encoder"]:
235
+ for g in optimizer.optimizers[module].param_groups:
236
+ g['betas'] = (0.0, 0.99)
237
+ g['lr'] = optimizer_params.ft_lr
238
+ g['initial_lr'] = optimizer_params.ft_lr
239
+ g['min_lr'] = 0
240
+ g['weight_decay'] = 1e-4
241
+
242
+ # load models if there is a model
243
+ if load_pretrained:
244
+ model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'],
245
+ load_only_params=config.get('load_only_params', True))
246
+
247
+ n_down = model.text_aligner.n_down
248
+
249
+ best_loss = float('inf') # best test loss
250
+ loss_train_record = list([])
251
+ loss_test_record = list([])
252
+ iters = 0
253
+
254
+ criterion = nn.L1Loss() # F0 loss (regression)
255
+ torch.cuda.empty_cache()
256
+
257
+ stft_loss = MultiResolutionSTFTLoss().to(device)
258
+
259
+ print('BERT', optimizer.optimizers['bert'])
260
+ print('decoder', optimizer.optimizers['decoder'])
261
+
262
+ start_ds = False
263
+
264
+ running_std = []
265
+
266
+ slmadv_params = Munch(config['slmadv_params'])
267
+ slmadv = SLMAdversarialLoss(model, wl, sampler,
268
+ slmadv_params.min_len,
269
+ slmadv_params.max_len,
270
+ batch_percentage=slmadv_params.batch_percentage,
271
+ skip_update=slmadv_params.iter,
272
+ sig=slmadv_params.sig
273
+ )
274
+
275
+ for k, v in optimizer.optimizers.items():
276
+ optimizer.optimizers[k] = accelerator.prepare(optimizer.optimizers[k])
277
+ optimizer.schedulers[k] = accelerator.prepare(optimizer.schedulers[k])
278
+
279
+ train_dataloader = accelerator.prepare(train_dataloader)
280
+
281
+
282
+ for epoch in range(start_epoch, epochs):
283
+ running_loss = 0
284
+ start_time = time.time()
285
+
286
+ _ = [model[key].eval() for key in model]
287
+
288
+ model.text_aligner.train()
289
+ model.text_encoder.train()
290
+
291
+ model.predictor.train()
292
+ model.bert_encoder.train()
293
+ model.bert.train()
294
+ model.msd.train()
295
+ model.mpd.train()
296
+
297
+ for i, batch in enumerate(train_dataloader):
298
+ waves = batch[0]
299
+ batch = [b.to(device) for b in batch[1:]]
300
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
301
+ with torch.no_grad():
302
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
303
+ mel_mask = length_to_mask(mel_input_length).to(device)
304
+ text_mask = length_to_mask(input_lengths).to(texts.device)
305
+
306
+ # compute reference styles
307
+ if multispeaker and epoch >= diff_epoch:
308
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
309
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
310
+ ref = torch.cat([ref_ss, ref_sp], dim=1)
311
+
312
+ try:
313
+ ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
314
+ s2s_attn = s2s_attn.transpose(-1, -2)
315
+ s2s_attn = s2s_attn[..., 1:]
316
+ s2s_attn = s2s_attn.transpose(-1, -2)
317
+ except Exception:  # skip batches where the text aligner fails
318
+ continue
319
+
320
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
321
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
322
+
323
+ # encode
324
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
325
+
326
+ # 50% chance of using the monotonic version
327
+ if bool(random.getrandbits(1)):
328
+ asr = (t_en @ s2s_attn)
329
+ else:
330
+ asr = (t_en @ s2s_attn_mono)
331
+
332
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
333
+
334
+ # compute the style of the entire utterance
335
+ # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
336
+ ss = []
337
+ gs = []
338
+ for bib in range(len(mel_input_length)):
339
+ mel_length = int(mel_input_length[bib].item())
340
+ mel = mels[bib, :, :mel_input_length[bib]]
341
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
342
+ ss.append(s)
343
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
344
+ gs.append(s)
345
+
346
+ s_dur = torch.stack(ss).squeeze() # global prosodic styles
347
+ gs = torch.stack(gs).squeeze() # global acoustic styles
348
+ s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser
349
+
350
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
351
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
352
+
353
+ # denoiser training
354
+ if epoch >= diff_epoch:
355
+ num_steps = np.random.randint(3, 5)
356
+
357
+ if model_params.diffusion.dist.estimate_sigma_data:
358
+ model.diffusion.diffusion.sigma_data = s_trg.std(axis=-1).mean().item() # batch-wise std estimation
359
+ running_std.append(model.diffusion.diffusion.sigma_data)
360
+
361
+ if multispeaker:
362
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
363
+ embedding=bert_dur,
364
+ embedding_scale=1,
365
+ features=ref, # reference from the same speaker as the embedding
366
+ embedding_mask_proba=0.1,
367
+ num_steps=num_steps).squeeze(1)
368
+ loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean() # EDM loss
369
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
370
+ else:
371
+ s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
372
+ embedding=bert_dur,
373
+ embedding_scale=1,
374
+ embedding_mask_proba=0.1,
375
+ num_steps=num_steps).squeeze(1)
376
+ loss_diff = model.diffusion.diffusion(s_trg.unsqueeze(1), embedding=bert_dur).mean() # EDM loss
377
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
378
+ else:
379
+ loss_sty = 0
380
+ loss_diff = 0
381
+
382
+
383
+ s_loss = 0
384
+
385
+
386
+ d, p = model.predictor(d_en, s_dur,
387
+ input_lengths,
388
+ s2s_attn_mono,
389
+ text_mask)
390
+
391
+ mel_len_st = int(mel_input_length.min().item() / 2 - 1)
392
+
393
+
394
+ mel_input_length_all = accelerator.gather(mel_input_length) # for balanced load
395
+ mel_len = min([int(mel_input_length_all.min().item() / 2 - 1), max_len // 2])
396
+
397
+
398
+ en = []
399
+ gt = []
400
+ p_en = []
401
+ wav = []
402
+ st = []
403
+
404
+ for bib in range(len(mel_input_length)):
405
+ mel_length = int(mel_input_length[bib].item() / 2)
406
+
407
+ random_start = np.random.randint(0, mel_length - mel_len)
408
+ en.append(asr[bib, :, random_start:random_start+mel_len])
409
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
410
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
411
+
412
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
413
+ wav.append(torch.from_numpy(y).to(device))
414
+
415
+ # style reference (better to be different from the GT)
416
+ random_start = np.random.randint(0, mel_length - mel_len_st)
417
+ st.append(mels[bib, :, (random_start * 2):((random_start+mel_len_st) * 2)])
418
+
419
+ wav = torch.stack(wav).float().detach()
420
+
421
+ en = torch.stack(en)
422
+ p_en = torch.stack(p_en)
423
+ gt = torch.stack(gt).detach()
424
+ st = torch.stack(st).detach()
425
+
426
+
427
+ if gt.size(-1) < 80:
428
+ continue
429
+
430
+ s = model.style_encoder(gt.unsqueeze(1))
431
+ s_dur = model.predictor_encoder(gt.unsqueeze(1))
432
+
433
+ with torch.no_grad():
434
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
435
+ F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
436
+
437
+ N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
438
+
439
+ y_rec_gt = wav.unsqueeze(1)
440
+ y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
441
+
442
+ wav = y_rec_gt
443
+
444
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur)
445
+
446
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
447
+
448
+ loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10
449
+ loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)
450
+
451
+ optimizer.zero_grad()
452
+ d_loss = dl(wav.detach(), y_rec.detach()).mean()
453
+ accelerator.backward(d_loss)
454
+ optimizer.step('msd')
455
+ optimizer.step('mpd')
456
+
457
+ # generator loss
458
+ optimizer.zero_grad()
459
+
460
+ loss_mel = stft_loss(y_rec, wav)
461
+ loss_gen_all = gl(wav, y_rec).mean()
462
+ loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()
463
+
464
+ loss_ce = 0
465
+ loss_dur = 0
466
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
467
+ _s2s_pred = _s2s_pred[:_text_length, :]
468
+ _text_input = _text_input[:_text_length].long()
469
+ _s2s_trg = torch.zeros_like(_s2s_pred)
470
+ for bib in range(_s2s_trg.shape[0]):
471
+ _s2s_trg[bib, :_text_input[bib]] = 1
472
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
473
+
474
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
475
+ _text_input[1:_text_length-1])
476
+ loss_ce += F.binary_cross_entropy_with_logits(_s2s_pred.flatten(), _s2s_trg.flatten())
477
+
478
+ loss_ce /= texts.size(0)
479
+ loss_dur /= texts.size(0)
480
+
481
+ loss_s2s = 0
482
+ for _s2s_pred, _text_input, _text_length in zip(s2s_pred, texts, input_lengths):
483
+ loss_s2s += F.cross_entropy(_s2s_pred[:_text_length], _text_input[:_text_length])
484
+ loss_s2s /= texts.size(0)
485
+
486
+ loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10
487
+
488
+ g_loss = loss_params.lambda_mel * loss_mel + \
489
+ loss_params.lambda_F0 * loss_F0_rec + \
490
+ loss_params.lambda_ce * loss_ce + \
491
+ loss_params.lambda_norm * loss_norm_rec + \
492
+ loss_params.lambda_dur * loss_dur + \
493
+ loss_params.lambda_gen * loss_gen_all + \
494
+ loss_params.lambda_slm * loss_lm + \
495
+ loss_params.lambda_sty * loss_sty + \
496
+ loss_params.lambda_diff * loss_diff + \
497
+ loss_params.lambda_mono * loss_mono + \
498
+ loss_params.lambda_s2s * loss_s2s
499
+
500
+ running_loss += accelerator.gather(loss_mel).mean().item()
501
+ accelerator.backward(g_loss)
502
+
503
+ # if torch.isnan(g_loss):
504
+ # from IPython.core.debugger import set_trace
505
+ # set_trace()
506
+
507
+ optimizer.step('bert_encoder')
508
+ optimizer.step('bert')
509
+ optimizer.step('predictor')
510
+ optimizer.step('predictor_encoder')
511
+ optimizer.step('style_encoder')
512
+ optimizer.step('decoder')
513
+
514
+ optimizer.step('text_encoder')
515
+ optimizer.step('text_aligner')
516
+
517
+ if epoch >= diff_epoch:
518
+ optimizer.step('diffusion')
519
+
520
+ d_loss_slm, loss_gen_lm = 0, 0
521
+ if epoch >= joint_epoch:
522
+ # randomly pick whether to use in-distribution text
523
+ if np.random.rand() < 0.5:
524
+ use_ind = True
525
+ else:
526
+ use_ind = False
527
+
528
+ if use_ind:
529
+ ref_lengths = input_lengths
530
+ ref_texts = texts
531
+
532
+ slm_out = slmadv(i,
533
+ y_rec_gt,
534
+ y_rec_gt_pred,
535
+ waves,
536
+ mel_input_length,
537
+ ref_texts,
538
+ ref_lengths, use_ind, s_trg.detach(), ref if multispeaker else None)
539
+
540
+ if slm_out is not None:
541
+ d_loss_slm, loss_gen_lm, y_pred = slm_out
542
+
543
+ # SLM generator loss
544
+ optimizer.zero_grad()
545
+ accelerator.backward(loss_gen_lm)
546
+
547
+ # compute the gradient norm
548
+ total_norm = {}
549
+ for key in model.keys():
550
+ total_norm[key] = 0
551
+ parameters = [p for p in model[key].parameters() if p.grad is not None and p.requires_grad]
552
+ for p in parameters:
553
+ param_norm = p.grad.detach().data.norm(2)
554
+ total_norm[key] += param_norm.item() ** 2
555
+ total_norm[key] = total_norm[key] ** 0.5
556
+
557
+ # gradient scaling
558
+ if total_norm['predictor'] > slmadv_params.thresh:
559
+ for key in model.keys():
560
+ for p in model[key].parameters():
561
+ if p.grad is not None:
562
+ p.grad *= (1 / total_norm['predictor'])
563
+
564
+ for p in model.predictor.duration_proj.parameters():
565
+ if p.grad is not None:
566
+ p.grad *= slmadv_params.scale
567
+
568
+ for p in model.predictor.lstm.parameters():
569
+ if p.grad is not None:
570
+ p.grad *= slmadv_params.scale
571
+
572
+ for p in model.diffusion.parameters():
573
+ if p.grad is not None:
574
+ p.grad *= slmadv_params.scale
575
+
576
+ optimizer.step('bert_encoder')
577
+ optimizer.step('bert')
578
+ optimizer.step('predictor')
579
+ optimizer.step('diffusion')
580
+
581
+ # SLM discriminator loss
582
+ if d_loss_slm != 0:
583
+ optimizer.zero_grad()
584
+ accelerator.backward(d_loss_slm)
585
+ optimizer.step('wd')
586
+
587
+ iters = iters + 1
588
+
589
+ if (i + 1) % log_interval == 0:
590
+ logger.info ('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f, SLoss: %.5f, S2S Loss: %.5f, Mono Loss: %.5f'
591
+ %(epoch+1, epochs, i+1, len(train_list)//batch_size, running_loss / log_interval, d_loss, loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff, d_loss_slm, loss_gen_lm, s_loss, loss_s2s, loss_mono), main_process_only=True)
592
+ if accelerator.is_main_process:
593
+ print ('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f, SLoss: %.5f, S2S Loss: %.5f, Mono Loss: %.5f'
594
+ %(epoch+1, epochs, i+1, len(train_list)//batch_size, running_loss / log_interval, d_loss, loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff, d_loss_slm, loss_gen_lm, s_loss, loss_s2s, loss_mono))
595
+ accelerator.log({'train/mel_loss': float(running_loss / log_interval),
596
+ 'train/gen_loss': float(loss_gen_all),
597
+ 'train/d_loss': float(d_loss),
598
+ 'train/ce_loss': float(loss_ce),
599
+ 'train/dur_loss': float(loss_dur),
600
+ 'train/slm_loss': float(loss_lm),
601
+ 'train/norm_loss': float(loss_norm_rec),
602
+ 'train/F0_loss': float(loss_F0_rec),
603
+ 'train/sty_loss': float(loss_sty),
604
+ 'train/diff_loss': float(loss_diff),
605
+ 'train/d_loss_slm': float(d_loss_slm),
606
+ 'train/gen_loss_slm': float(loss_gen_lm),
607
+ 'epoch': int(epoch) + 1}, step=iters)
608
+
609
+ running_loss = 0
610
+
611
+ accelerator.print('Time elapsed:', time.time() - start_time)
612
+
613
+
614
+ loss_test = 0
615
+ loss_align = 0
616
+ loss_f = 0
617
+ _ = [model[key].eval() for key in model]
618
+
619
+ with torch.no_grad():
620
+ iters_test = 0
621
+ for batch_idx, batch in enumerate(val_dataloader):
622
+ optimizer.zero_grad()
623
+
624
+ try:
625
+ waves = batch[0]
626
+ batch = [b.to(device) for b in batch[1:]]
627
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
628
+ with torch.no_grad():
629
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
630
+ text_mask = length_to_mask(input_lengths).to(texts.device)
631
+
632
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
633
+ s2s_attn = s2s_attn.transpose(-1, -2)
634
+ s2s_attn = s2s_attn[..., 1:]
635
+ s2s_attn = s2s_attn.transpose(-1, -2)
636
+
637
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
638
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
639
+
640
+ # encode
641
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
642
+ asr = (t_en @ s2s_attn_mono)
643
+
644
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
645
+
646
+ ss = []
647
+ gs = []
648
+
649
+ for bib in range(len(mel_input_length)):
650
+ mel_length = int(mel_input_length[bib].item())
651
+ mel = mels[bib, :, :mel_input_length[bib]]
652
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
653
+ ss.append(s)
654
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
655
+ gs.append(s)
656
+
657
+ s = torch.stack(ss).squeeze()
658
+ gs = torch.stack(gs).squeeze()
659
+ s_trg = torch.cat([s, gs], dim=-1).detach()
660
+
661
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
662
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
663
+ d, p = model.predictor(d_en, s,
664
+ input_lengths,
665
+ s2s_attn_mono,
666
+ text_mask)
667
+ # get clips
668
+ mel_len = int(mel_input_length.min().item() / 2 - 1)
669
+ en = []
670
+ gt = []
671
+
672
+ p_en = []
673
+ wav = []
674
+
675
+ for bib in range(len(mel_input_length)):
676
+ mel_length = int(mel_input_length[bib].item() / 2)
677
+
678
+ random_start = np.random.randint(0, mel_length - mel_len)
679
+ en.append(asr[bib, :, random_start:random_start+mel_len])
680
+ p_en.append(p[bib, :, random_start:random_start+mel_len])
681
+
682
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
683
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
684
+ wav.append(torch.from_numpy(y).to(device))
685
+
686
+ wav = torch.stack(wav).float().detach()
687
+
688
+ en = torch.stack(en)
689
+ p_en = torch.stack(p_en)
690
+ gt = torch.stack(gt).detach()
691
+ s = model.predictor_encoder(gt.unsqueeze(1))
692
+
693
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s)
694
+
695
+ loss_dur = 0
696
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
697
+ _s2s_pred = _s2s_pred[:_text_length, :]
698
+ _text_input = _text_input[:_text_length].long()
699
+ _s2s_trg = torch.zeros_like(_s2s_pred)
700
+ for bib in range(_s2s_trg.shape[0]):
701
+ _s2s_trg[bib, :_text_input[bib]] = 1
702
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
703
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
704
+ _text_input[1:_text_length-1])
705
+
706
+ loss_dur /= texts.size(0)
707
+
708
+ s = model.style_encoder(gt.unsqueeze(1))
709
+
710
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
711
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
712
+
713
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
714
+
715
+ loss_F0 = F.l1_loss(F0_real, F0_fake) / 10
716
+
717
+
718
+ loss_test += accelerator.gather(loss_mel).mean()
719
+ loss_align += accelerator.gather(loss_dur).mean()
720
+ loss_f += accelerator.gather(loss_F0).mean()
721
+
722
+ iters_test += 1
723
+ except Exception:  # skip batches that fail here (e.g. clips shorter than the sampled window)
724
+ continue
725
+
726
+
727
+ accelerator.print('Epochs:', epoch + 1)
728
+ try:
729
+ logger.info('Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f' % (
730
+ loss_test / iters_test, loss_align / iters_test, loss_f / iters_test) + '\n', main_process_only=True)
731
+
732
+
733
+ accelerator.log({'eval/mel_loss': float(loss_test / iters_test),
734
+ 'eval/dur_loss': float(loss_align / iters_test),
735
+ 'eval/F0_loss': float(loss_f / iters_test)},
736
+ step=(i + 1) * (epoch + 1))
737
+ except ZeroDivisionError:
738
+ accelerator.print("Eval loss was divided by zero... skipping eval cycle")
739
+
740
+ if epoch % saving_epoch == 0:
741
+ if (loss_test / iters_test) < best_loss:
742
+ best_loss = loss_test / iters_test
743
+ try:
744
+ accelerator.print('Saving..')
745
+ state = {
746
+ 'net': {key: model[key].state_dict() for key in model},
747
+ 'optimizer': optimizer.state_dict(),
748
+ 'iters': iters,
749
+ 'val_loss': loss_test / iters_test,
750
+ 'epoch': epoch,
751
+ }
752
+ except ZeroDivisionError:
753
+ accelerator.print('No iter test, Re-Saving..')
754
+ state = {
755
+ 'net': {key: model[key].state_dict() for key in model},
756
+ 'optimizer': optimizer.state_dict(),
757
+ 'iters': iters,
758
+ 'val_loss': 0.1, # not zero just in case
759
+ 'epoch': epoch,
760
+ }
761
+
762
+ if accelerator.is_main_process:
763
+ save_path = osp.join(log_dir, 'epoch_2nd_%05d.pth' % epoch)
764
+ torch.save(state, save_path)
765
+
766
+ # if estimating sigma, save the estimated sigma
767
+ if model_params.diffusion.dist.estimate_sigma_data:
768
+ config['model_params']['diffusion']['dist']['sigma_data'] = float(np.mean(running_std))
769
+
770
+ with open(osp.join(log_dir, osp.basename(config_path)), 'w') as outfile:
771
+ yaml.dump(config, outfile, default_flow_style=True)
772
+ if accelerator.is_main_process:
773
+ print('Saving last pth..')
774
+ state = {
775
+ 'net': {key: model[key].state_dict() for key in model},
776
+ 'optimizer': optimizer.state_dict(),
777
+ 'iters': iters,
778
+ 'val_loss': loss_test / iters_test,
779
+ 'epoch': epoch,
780
+ }
781
+ save_path = osp.join(log_dir, '2nd_phase_last.pth')
782
+ torch.save(state, save_path)
783
+
784
+ accelerator.end_training()
785
+
786
+
787
+ if __name__ == "__main__":
788
+ main()
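For reference, the duration objective computed in both the training and validation loops above can be reproduced in isolation. The sketch below uses dummy tensors in place of the predictor logits (`d`) and the aligner durations (`d_gt`); shapes and values are illustrative only. The target matrix switches on the first d_gt[i] frames of token i, the soft duration estimate is the sigmoid of the per-frame logits summed over the frame axis, and the first and last tokens are excluded from the L1 term.

import torch
import torch.nn.functional as F

text_length = 6
max_frames = 20
d_gt = torch.tensor([2, 3, 4, 3, 2, 1])            # dummy ground-truth frames per token
s2s_pred = torch.randn(text_length, max_frames)    # dummy per-token, per-frame logits

s2s_trg = torch.zeros_like(s2s_pred)
for i in range(text_length):
    s2s_trg[i, :int(d_gt[i])] = 1                  # first d_gt[i] frames are "on"

dur_pred = torch.sigmoid(s2s_pred).sum(dim=1)      # soft duration estimate per token
loss_dur = F.l1_loss(dur_pred[1:text_length - 1], d_gt[1:text_length - 1].float())
loss_ce = F.binary_cross_entropy_with_logits(s2s_pred.flatten(), s2s_trg.flatten())

In the scripts, loss_dur and loss_ce are then averaged over the batch by dividing by texts.size(0).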