import os
import librosa
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules.vocoder.commons.stft_loss import MultiResolutionSTFTLoss
from modules.vocoder.hifigan.hifigan import MultiPeriodDiscriminator, MultiScaleDiscriminator, \
generator_loss, feature_loss, discriminator_loss
from modules.vocoder.hifigan.mel_utils import mel_spectrogram
from modules.vocoder.univnet.mrd import MultiResolutionDiscriminator
from modules.tts.wavvae.decoder.wavvae_v3 import WavVAE_V3
from tasks.tts.utils.audio import torch_wav2spec
from tasks.tts.utils.audio.align import mel2token_to_dur
from utils.commons.ckpt_utils import load_ckpt
from utils.commons.hparams import hparams
from attrdict import AttrDict
from tasks.tts.dataset_mixin import TTSDatasetMixin
from utils.commons.base_task import BaseTask
from utils.commons.import_utils import import_module_bystr
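

# Training task for WavVAE_V3: a VAE waveform codec trained adversarially with
# HiFi-GAN-style multi-period / multi-scale discriminators (optionally a
# multi-resolution discriminator) plus mel and multi-resolution STFT losses.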
class WavVAETask(TTSDatasetMixin, BaseTask):
def __init__(self):
super().__init__()
self.dataset_cls = import_module_bystr(hparams['dataset_cls'])
self.val_dataset_cls = import_module_bystr(hparams['val_dataset_cls'])
self.processer_fn = import_module_bystr(hparams['processer_fn'])
self.build_fast_dataloader = import_module_bystr(hparams['build_fast_dataloader'])
self.hparams = hparams
self.config = AttrDict(hparams)
        # Compute mel spectrograms online on the GPU
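        # Note: the FFT size is taken from 'win_size', so the FFT length always
        # equals the analysis window length here.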
sample_rate = hparams["audio_sample_rate"]
fft_size = hparams["win_size"]
win_size = hparams["win_size"]
hop_size = hparams["hop_size"]
num_mels = hparams["audio_num_mel_bins"]
fmin = hparams["fmin"]
fmax = hparams["fmax"]
mel_basis = librosa.filters.mel(
sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax
)
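        # Bind the precomputed mel filterbank and STFT parameters so waveform
        # batches can be converted to mel spectrograms on the fly.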
self.torch_wav2spec_ = partial(
torch_wav2spec, mel_basis=mel_basis, fft_size=fft_size, hop_size=hop_size, win_length=win_size,
)
def build_model(self):
self.model_gen = WavVAE_V3(hparams=hparams)
self.model_disc = torch.nn.ModuleDict()
self.model_disc['mpd'] = MultiPeriodDiscriminator(hparams['mpd'], use_cond=hparams['use_cond_disc'])
self.model_disc['msd'] = MultiScaleDiscriminator(use_cond=hparams['use_cond_disc'])
if hparams['use_mrd']:
self.model_disc['mrd'] = MultiResolutionDiscriminator(hparams)
self.stft_loss = MultiResolutionSTFTLoss()
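        # Warm-start from pretrained runs: encoder weights from an earlier
        # WavVAE checkpoint; decoder and discriminators from a MelGAN-NSF
        # vocoder run (paths are hard-coded relative to ./checkpoints).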
load_ckpt(self.model_gen.encoder, './checkpoints/1231_megatts3_wavvae_v2_25hz', 'model.module.encoder', strict=False)
load_ckpt(self.model_gen.decoder, './checkpoints/1117_melgan-nsf_full_1', 'model_gen', force=True, strict=True)
load_ckpt(self.model_disc, './checkpoints/1117_melgan-nsf_full_1', 'model_disc', force=True, strict=True)
        trainable = [self.model_gen, self.model_disc['mpd'], self.model_disc['msd']]
        if hparams['use_mrd']:
            trainable.append(self.model_disc['mrd'])
        return {'trainable': trainable, 'others': []}
    def load_model(self):
        if hparams.get('load_ckpt', '') != '':
            # Note: this task defines self.model_gen (there is no self.model).
            load_ckpt(self.model_gen, hparams['load_ckpt'], 'model', strict=False)
def build_optimizer(self):
optimizer_gen = torch.optim.AdamW(self.model_gen.parameters(), lr=hparams['lr'],
betas=[hparams['adam_b1'], hparams['adam_b2']])
optimizer_disc = torch.optim.AdamW(self.model_disc.parameters(),
lr=hparams.get('disc_lr', hparams['lr']),
betas=[hparams['adam_b1'], hparams['adam_b2']])
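        # Order matters: index 0 is the generator optimizer, index 1 the
        # discriminator optimizer (see optimizer_idx in _training_step).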
return [optimizer_gen, optimizer_disc]
def build_scheduler(self, optimizer):
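        # No LR schedule: both optimizers run at a constant learning rate.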
return None
def _training_step(self, sample, batch_idx, optimizer_idx):
log_outputs = {}
loss_weights = {}
sample['wavs'] = sample['wavs'].float()
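        # Housekeeping: periodically kill stray "voidgpu<idx>" placeholder
        # processes on the visible GPUs (site-specific; a no-op elsewhere).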
if self.global_step % 100 == 0:
devices = os.environ.get('CUDA_VISIBLE_DEVICES', '').split(",")
for d in devices:
os.system(f'pkill -f "voidgpu{d}"')
y = sample['wavs']
loss_output = {}
if optimizer_idx == 0:
#######################
# Generator #
#######################
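            # Forward pass: y_ is the reconstructed waveform, posterior is the
            # latent distribution whose KL to the prior is penalized below.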
y_, posterior = self.model_gen(y)
y = y.unsqueeze(1)
y_mel = mel_spectrogram(y.squeeze(1), hparams).transpose(1, 2)
y_hat_mel = mel_spectrogram(y_.squeeze(1), hparams).transpose(1, 2)
loss_output['mel'] = F.l1_loss(y_hat_mel, y_mel) * hparams['lambda_mel']
if self.training:
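                # Generator-side adversarial losses from the multi-period and
                # multi-scale discriminators (feature maps are returned but
                # currently unused, so no feature-matching loss is applied).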
_, y_p_hat_g, fmap_f_r, fmap_f_g = self.model_disc['mpd'](y, y_, None)
_, y_s_hat_g, fmap_s_r, fmap_s_g = self.model_disc['msd'](y, y_, None)
loss_output['a_p'] = generator_loss(y_p_hat_g) * hparams['lambda_adv'] * hparams.get('lambda_mpd', 1.0)
loss_output['a_s'] = generator_loss(y_s_hat_g) * hparams['lambda_adv'] * hparams.get('lambda_msd', 1.0)
if hparams['use_mrd']:
y_r_hat_g = [x[1] for x in self.model_disc['mrd'](y_)]
loss_output['a_r'] = generator_loss(y_r_hat_g) \
* hparams['lambda_adv'] * hparams.get('lambda_mrd', 1.0)
if hparams['use_ms_stft']:
loss_output['sc'], loss_output['mag'] = self.stft_loss(y.squeeze(1), y_.squeeze(1))
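            # VAE KL regularizer keeping the encoder posterior close to the prior.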
loss_output['kl_loss'] = posterior.kl().mean() * hparams.get('lambda_kl', 1.0)
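            # Cache the reconstruction (detached) for the discriminator step,
            # so discriminator gradients never reach the generator.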
self.y_ = y_.detach()
else:
#######################
# Discriminator #
#######################
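            # Discriminators are only updated while training; skip in eval.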
if not self.training:
return None
y = y.unsqueeze(1)
y_ = self.y_
# MPD
y_p_hat_r, y_p_hat_g, _, _ = self.model_disc['mpd'](y, y_.detach(), None)
loss_output['r_p'], loss_output['f_p'] = discriminator_loss(y_p_hat_r, y_p_hat_g)
# MSD
y_s_hat_r, y_s_hat_g, _, _ = self.model_disc['msd'](y, y_.detach(), None)
loss_output['r_s'], loss_output['f_s'] = discriminator_loss(y_s_hat_r, y_s_hat_g)
# MRD
if hparams['use_mrd']:
y_r_hat_r = [x[1] for x in self.model_disc['mrd'](y)]
y_r_hat_g = [x[1] for x in self.model_disc['mrd'](y_.detach())]
loss_output['r_r'], loss_output['f_r'] = discriminator_loss(y_r_hat_r, y_r_hat_g)
total_loss = sum(loss_output.values())
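        # 'bs' (batch size) is logged only; it is added after the sum so it
        # never contributes to the training loss.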
loss_output['bs'] = sample['wavs'].shape[0]
return total_loss, loss_output
def save_valid_result(self, sample, batch_idx, model_out):
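        # Plot mel reconstructions and log ground-truth / vocoded audio. Relies
        # on helpers (self.vocoder, self.predict_f0, self.plot_mel,
        # self.run_model) provided by the base task / mixin, not this file.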
sr = hparams['audio_sample_rate']
mel_out = model_out.get('mel_out')
        f0_gt = sample.get('f0')
        if f0_gt is not None:
            f0_gt = f0_gt.cpu()[-1]
if mel_out is not None:
f0_pred = self.predict_f0(sample['mels'])
self.plot_mel(batch_idx, sample['mels'], mel_out, f0s={'f0': f0_pred, 'f0g': f0_gt})
# gt wav
if self.global_step <= hparams['valid_infer_interval']:
mel_gt = sample['mels'][-1].cpu()
f0 = self.predict_f0(sample['mels'][-1:])
wav_gt = self.vocoder.spec2wav(mel_gt, f0=f0)
self.logger.add_audio(f'wav_gt_{batch_idx}', wav_gt, self.global_step, sr)
if self.global_step >= 0:
# with gt duration
model_out = self.run_model(sample, infer=True, infer_use_gt_dur=True)
# dur_info = self.get_plot_dur_info(sample, model_out)
# del dur_info['dur_pred']
dur_info = None
f0 = self.predict_f0(model_out['mel_out'])
wav_pred = self.vocoder.spec2wav(model_out['mel_out'][-1].cpu(), f0=f0)
self.logger.add_audio(f'wav_gdur_{batch_idx}', wav_pred, self.global_step, sr)
self.plot_mel(batch_idx, sample['mels'][-1:], model_out['mel_out'][-1], f'mel_gdur_{batch_idx}',
dur_info=dur_info, f0s={'f0': f0, 'f0g': f0_gt})
# with pred duration
if not hparams['use_gt_dur'] and not hparams['use_gt_latent']:
model_out = self.run_model(sample, infer=True, infer_use_gt_dur=False)
# dur_info = self.get_plot_dur_info(sample, model_out)
dur_info = None
f0 = self.predict_f0(model_out['mel_out'])
self.plot_mel(
batch_idx, sample['mels'], model_out['mel_out'][-1], f'mel_pdur_{batch_idx}',
dur_info=dur_info, f0s={'f0': f0, 'f0g': f0_gt})
wav_pred = self.vocoder.spec2wav(model_out['mel_out'][-1].cpu(), f0=f0)
self.logger.add_audio(f'wav_pdur_{batch_idx}', wav_pred, self.global_step, sr)
def get_plot_dur_info(self, sample, model_out):
T_txt = sample['txt_tokens'].shape[1]
dur_gt = mel2token_to_dur(sample['mel2ph'], T_txt)[-1]
dur_pred = model_out['dur'] if 'dur' in model_out else dur_gt
txt = self.token_encoder.decode(sample['txt_tokens'][-1].cpu().numpy())
txt = txt.split(" ")
return {'dur_gt': dur_gt, 'dur_pred': dur_pred, 'txt': txt}
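
    # Gradients are clipped per optimizer, with separate norm caps from hparams.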
def on_before_optimization(self, opt_idx):
if opt_idx == 0:
nn.utils.clip_grad_norm_(self.model_gen.parameters(), hparams['generator_grad_norm'])
else:
nn.utils.clip_grad_norm_(self.model_disc.parameters(), hparams["discriminator_grad_norm"])
def to(self, device=None, dtype=None):
super().to(device=device, dtype=dtype)
        # The trainer doesn't move the EMA weights to the device automatically,
        # so do it manually.
if hparams.get('use_ema', False):
self.ema.to(device=device, dtype=dtype)
    def cuda(self, device):
super().cuda(device)
if hparams.get('use_ema', False):
self.ema.to(device=device)
@torch.no_grad()
def validation_step(self, sample, batch_idx):
infer_steps = self.hparams.get('infer_steps', 12)
outputs = self._validation_step(sample, batch_idx, infer_steps)
return outputs
def _validation_step(self, sample, batch_idx, infer_steps):
outputs = {}
if self.trainer.proc_rank == 0:
            # Currently a no-op: validation losses and reconstructions are not
            # computed for this task.
            pass
return outputs
@torch.no_grad()
def test_step(self, sample, batch_idx):
infer_steps = hparams['infer_steps']
return self._validation_step(sample, batch_idx, infer_steps)
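

# Minimal usage sketch (hypothetical; the task is normally constructed and
# driven by the framework's trainer after hparams are loaded from a config):
#
#     task = WavVAETask()
#     task.build_model()
#     opt_gen, opt_disc = task.build_optimizer()
#     loss, logs = task._training_step(batch, batch_idx=0, optimizer_idx=0)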