import argparse
import filecmp
import multiprocessing
import os
import subprocess
from functools import partial
from multiprocessing import Pool, Process

import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.optim import AdamW

from modules.vocoder.commons.stft_loss import MultiResolutionSTFTLoss
from modules.vocoder.hifigan.hifigan import MultiPeriodDiscriminator, MultiScaleDiscriminator, \
    generator_loss, feature_loss, discriminator_loss
from modules.vocoder.hifigan.mel_utils import mel_spectrogram
from modules.vocoder.univnet.mrd import MultiResolutionDiscriminator
from modules.tts.wavvae.decoder.wavvae_v3 import WavVAE_V3
from tasks.tts.dataset_mixin import TTSDatasetMixin
from tasks.tts.utils.audio import torch_wav2spec
from tasks.tts.utils.audio.align import mel2token_to_dur
from utils.commons.base_task import BaseTask
from utils.commons.ckpt_utils import load_ckpt
from utils.commons.hparams import hparams
from utils.commons.import_utils import import_module_bystr
from utils.nn.schedulers import WarmupSchedule, CosineSchedule
from attrdict import AttrDict


class WavVAETask(TTSDatasetMixin, BaseTask):
    def __init__(self):
        super().__init__()
        self.dataset_cls = import_module_bystr(hparams['dataset_cls'])
        self.val_dataset_cls = import_module_bystr(hparams['val_dataset_cls'])
        self.processer_fn = import_module_bystr(hparams['processer_fn'])
        self.build_fast_dataloader = import_module_bystr(hparams['build_fast_dataloader'])
        self.hparams = hparams
        self.config = AttrDict(hparams)

        # Build the mel filterbank once so mel spectrograms can be computed
        # online on the GPU during training.
        sample_rate = hparams["audio_sample_rate"]
        fft_size = hparams["win_size"]  # n_fft is tied to the window size in this setup
        win_size = hparams["win_size"]
        hop_size = hparams["hop_size"]
        num_mels = hparams["audio_num_mel_bins"]
        fmin = hparams["fmin"]
        fmax = hparams["fmax"]
        mel_basis = librosa.filters.mel(
            sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        self.torch_wav2spec_ = partial(
            torch_wav2spec,
            mel_basis=mel_basis,
            fft_size=fft_size,
            hop_size=hop_size,
            win_length=win_size,
        )

    def build_model(self):
        self.model_gen = WavVAE_V3(hparams=hparams)
        self.model_disc = torch.nn.ModuleDict()
        self.model_disc['mpd'] = MultiPeriodDiscriminator(hparams['mpd'], use_cond=hparams['use_cond_disc'])
        self.model_disc['msd'] = MultiScaleDiscriminator(use_cond=hparams['use_cond_disc'])
        if hparams['use_mrd']:
            self.model_disc['mrd'] = MultiResolutionDiscriminator(hparams)
        self.stft_loss = MultiResolutionSTFTLoss()
        # Warm-start the encoder from the 25 Hz WavVAE checkpoint and the
        # decoder/discriminators from the MelGAN-NSF checkpoint.
        load_ckpt(self.model_gen.encoder, './checkpoints/1231_megatts3_wavvae_v2_25hz',
                  'model.module.encoder', strict=False)
        load_ckpt(self.model_gen.decoder, './checkpoints/1117_melgan-nsf_full_1',
                  'model_gen', force=True, strict=True)
        load_ckpt(self.model_disc, './checkpoints/1117_melgan-nsf_full_1',
                  'model_disc', force=True, strict=True)
        # Only list the MRD as trainable when it was actually built; indexing
        # self.model_disc['mrd'] unconditionally raises a KeyError otherwise.
        trainable = [self.model_gen, self.model_disc['mpd'], self.model_disc['msd']]
        if hparams['use_mrd']:
            trainable.append(self.model_disc['mrd'])
        return {'trainable': trainable, 'others': []}

    def load_model(self):
        if hparams.get('load_ckpt', '') != '':
            # This task builds model_gen/model_disc rather than a single
            # self.model, so the generator is the target of the restore.
            load_ckpt(self.model_gen, hparams['load_ckpt'], 'model', strict=False)

    def build_optimizer(self):
        optimizer_gen = torch.optim.AdamW(
            self.model_gen.parameters(), lr=hparams['lr'],
            betas=[hparams['adam_b1'], hparams['adam_b2']])
        optimizer_disc = torch.optim.AdamW(
            self.model_disc.parameters(), lr=hparams.get('disc_lr', hparams['lr']),
            betas=[hparams['adam_b1'], hparams['adam_b2']])
        return [optimizer_gen, optimizer_disc]

    def build_scheduler(self, optimizer):
        return None
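
    # One alternating-GAN update. optimizer_idx == 0 trains the WavVAE
    # generator with a mel-L1 reconstruction loss, adversarial losses from
    # the MPD/MSD (and optionally MRD) discriminators, an optional
    # multi-resolution STFT loss, and a KL term on the VAE posterior;
    # optimizer_idx == 1 trains the discriminators on real audio vs. the
    # generator output cached in self.y_.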
    def _training_step(self, sample, batch_idx, optimizer_idx):
        sample['wavs'] = sample['wavs'].float()
        if self.global_step % 100 == 0:
            # Periodically kill stray "voidgpu" processes pinned to the
            # visible devices.
            devices = os.environ.get('CUDA_VISIBLE_DEVICES', '').split(",")
            for d in devices:
                os.system(f'pkill -f "voidgpu{d}"')
        y = sample['wavs']
        loss_output = {}
        if optimizer_idx == 0:
            #######################
            #      Generator      #
            #######################
            y_, posterior = self.model_gen(y)
            y = y.unsqueeze(1)
            y_mel = mel_spectrogram(y.squeeze(1), hparams).transpose(1, 2)
            y_hat_mel = mel_spectrogram(y_.squeeze(1), hparams).transpose(1, 2)
            loss_output['mel'] = F.l1_loss(y_hat_mel, y_mel) * hparams['lambda_mel']
            if self.training:
                _, y_p_hat_g, fmap_f_r, fmap_f_g = self.model_disc['mpd'](y, y_, None)
                _, y_s_hat_g, fmap_s_r, fmap_s_g = self.model_disc['msd'](y, y_, None)
                loss_output['a_p'] = generator_loss(y_p_hat_g) * hparams['lambda_adv'] * hparams.get('lambda_mpd', 1.0)
                loss_output['a_s'] = generator_loss(y_s_hat_g) * hparams['lambda_adv'] * hparams.get('lambda_msd', 1.0)
                if hparams['use_mrd']:
                    y_r_hat_g = [x[1] for x in self.model_disc['mrd'](y_)]
                    loss_output['a_r'] = generator_loss(y_r_hat_g) \
                        * hparams['lambda_adv'] * hparams.get('lambda_mrd', 1.0)
            if hparams['use_ms_stft']:
                loss_output['sc'], loss_output['mag'] = self.stft_loss(y.squeeze(1), y_.squeeze(1))
            loss_output['kl_loss'] = posterior.kl().mean() * hparams.get('lambda_kl', 1.0)
            # Cache the generator output for the discriminator step.
            self.y_ = y_.detach()
        else:
            #######################
            #    Discriminator    #
            #######################
            if not self.training:
                return None
            y = y.unsqueeze(1)
            y_ = self.y_
            # MPD
            y_p_hat_r, y_p_hat_g, _, _ = self.model_disc['mpd'](y, y_.detach(), None)
            loss_output['r_p'], loss_output['f_p'] = discriminator_loss(y_p_hat_r, y_p_hat_g)
            # MSD
            y_s_hat_r, y_s_hat_g, _, _ = self.model_disc['msd'](y, y_.detach(), None)
            loss_output['r_s'], loss_output['f_s'] = discriminator_loss(y_s_hat_r, y_s_hat_g)
            # MRD
            if hparams['use_mrd']:
                y_r_hat_r = [x[1] for x in self.model_disc['mrd'](y)]
                y_r_hat_g = [x[1] for x in self.model_disc['mrd'](y_.detach())]
                loss_output['r_r'], loss_output['f_r'] = discriminator_loss(y_r_hat_r, y_r_hat_g)
        total_loss = sum(loss_output.values())
        loss_output['bs'] = sample['wavs'].shape[0]
        return total_loss, loss_output
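
    # A minimal sketch of how a trainer might drive the two optimizers
    # returned by build_optimizer() through _training_step (illustrative
    # only; the actual loop presumably lives in
    # utils.commons.base_task.BaseTask):
    #
    #   opt_gen, opt_disc = task.build_optimizer()
    #   for batch_idx, sample in enumerate(loader):
    #       for opt_idx, opt in enumerate([opt_gen, opt_disc]):
    #           ret = task._training_step(sample, batch_idx, opt_idx)
    #           if ret is None:
    #               continue
    #           total_loss, _ = ret
    #           opt.zero_grad()
    #           total_loss.backward()
    #           task.on_before_optimization(opt_idx)  # gradient clipping
    #           opt.step()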

    def save_valid_result(self, sample, batch_idx, model_out):
        sr = hparams['audio_sample_rate']
        mel_out = model_out.get('mel_out')
        f0_gt = sample.get('f0')
        if f0_gt is not None:
            f0_gt = f0_gt.cpu()[-1]
        if mel_out is not None:
            f0_pred = self.predict_f0(sample['mels'])
            self.plot_mel(batch_idx, sample['mels'], mel_out, f0s={'f0': f0_pred, 'f0g': f0_gt})
        # gt wav
        if self.global_step <= hparams['valid_infer_interval']:
            mel_gt = sample['mels'][-1].cpu()
            f0 = self.predict_f0(sample['mels'][-1:])
            wav_gt = self.vocoder.spec2wav(mel_gt, f0=f0)
            self.logger.add_audio(f'wav_gt_{batch_idx}', wav_gt, self.global_step, sr)
        if self.global_step >= 0:
            # with gt duration
            model_out = self.run_model(sample, infer=True, infer_use_gt_dur=True)
            # dur_info = self.get_plot_dur_info(sample, model_out)
            # del dur_info['dur_pred']
            dur_info = None
            f0 = self.predict_f0(model_out['mel_out'])
            wav_pred = self.vocoder.spec2wav(model_out['mel_out'][-1].cpu(), f0=f0)
            self.logger.add_audio(f'wav_gdur_{batch_idx}', wav_pred, self.global_step, sr)
            self.plot_mel(batch_idx, sample['mels'][-1:], model_out['mel_out'][-1],
                          f'mel_gdur_{batch_idx}', dur_info=dur_info,
                          f0s={'f0': f0, 'f0g': f0_gt})
            # with pred duration
            if not hparams['use_gt_dur'] and not hparams['use_gt_latent']:
                model_out = self.run_model(sample, infer=True, infer_use_gt_dur=False)
                # dur_info = self.get_plot_dur_info(sample, model_out)
                dur_info = None
                f0 = self.predict_f0(model_out['mel_out'])
                self.plot_mel(
                    batch_idx, sample['mels'], model_out['mel_out'][-1],
                    f'mel_pdur_{batch_idx}', dur_info=dur_info,
                    f0s={'f0': f0, 'f0g': f0_gt})
                wav_pred = self.vocoder.spec2wav(model_out['mel_out'][-1].cpu(), f0=f0)
                self.logger.add_audio(f'wav_pdur_{batch_idx}', wav_pred, self.global_step, sr)

    def get_plot_dur_info(self, sample, model_out):
        T_txt = sample['txt_tokens'].shape[1]
        dur_gt = mel2token_to_dur(sample['mel2ph'], T_txt)[-1]
        dur_pred = model_out['dur'] if 'dur' in model_out else dur_gt
        txt = self.token_encoder.decode(sample['txt_tokens'][-1].cpu().numpy())
        txt = txt.split(" ")
        return {'dur_gt': dur_gt, 'dur_pred': dur_pred, 'txt': txt}

    def on_before_optimization(self, opt_idx):
        if opt_idx == 0:
            nn.utils.clip_grad_norm_(self.model_gen.parameters(), hparams['generator_grad_norm'])
        else:
            nn.utils.clip_grad_norm_(self.model_disc.parameters(), hparams["discriminator_grad_norm"])

    def to(self, device=None, dtype=None):
        super().to(device=device, dtype=dtype)
        # The trainer doesn't move the EMA weights to the device
        # automatically, so we do it manually.
        if hparams.get('use_ema', False):
            self.ema.to(device=device, dtype=dtype)

    def cuda(self, device):
        super().cuda(device)
        if hparams.get('use_ema', False):
            self.ema.to(device=device)

    @torch.no_grad()
    def validation_step(self, sample, batch_idx):
        infer_steps = self.hparams.get('infer_steps', 12)
        outputs = self._validation_step(sample, batch_idx, infer_steps)
        return outputs

    def _validation_step(self, sample, batch_idx, infer_steps):
        # Validation is currently a no-op: loss computation and inference
        # are disabled, and only rank 0 would run them.
        outputs = {}
        if self.trainer.proc_rank == 0:
            pass
        return outputs

    @torch.no_grad()
    def test_step(self, sample, batch_idx):
        infer_steps = hparams['infer_steps']
        return self._validation_step(sample, batch_idx, infer_steps)
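

# A standalone sanity check of the mel filterbank built in __init__
# (illustrative; the concrete numbers are hypothetical stand-ins for the
# values normally read from hparams). librosa.filters.mel returns an array
# of shape [n_mels, n_fft // 2 + 1]:
#
#   import librosa
#   mel_basis = librosa.filters.mel(sr=24000, n_fft=1024, n_mels=80,
#                                   fmin=0, fmax=12000)
#   print(mel_basis.shape)  # (80, 513)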