import os

import librosa
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from modules.vocoder.commons.stft_loss import MultiResolutionSTFTLoss
from modules.vocoder.hifigan.hifigan import MultiPeriodDiscriminator, MultiScaleDiscriminator, \
    generator_loss, discriminator_loss
from modules.vocoder.hifigan.mel_utils import mel_spectrogram
from modules.vocoder.univnet.mrd import MultiResolutionDiscriminator
from modules.tts.wavvae.decoder.wavvae_v3 import WavVAE_V3
from tasks.tts.utils.audio import torch_wav2spec
from tasks.tts.utils.audio.align import mel2token_to_dur
from utils.commons.ckpt_utils import load_ckpt
from utils.commons.hparams import hparams

from attrdict import AttrDict
from tasks.tts.dataset_mixin import TTSDatasetMixin
from utils.commons.base_task import BaseTask
from utils.commons.import_utils import import_module_bystr
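

# WavVAETask trains the WavVAE_V3 waveform autoencoder adversarially:
# optimizer 0 updates the generator with mel / multi-resolution STFT / KL /
# adversarial losses, and optimizer 1 updates HiFi-GAN style discriminators
# (multi-period MPD, multi-scale MSD and, optionally, UnivNet's
# multi-resolution MRD). See `_training_step` for the loss bookkeeping.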
class WavVAETask(TTSDatasetMixin, BaseTask):
    def __init__(self):
        super().__init__()
        self.dataset_cls = import_module_bystr(hparams['dataset_cls'])
        self.val_dataset_cls = import_module_bystr(hparams['val_dataset_cls'])
        self.processer_fn = import_module_bystr(hparams['processer_fn'])
        self.build_fast_dataloader = import_module_bystr(hparams['build_fast_dataloader'])
        self.hparams = hparams
        self.config = AttrDict(hparams)

        # Precompute the mel filterbank once and bind it, together with the
        # STFT settings, into the wav->spec function used by the data pipeline.
        sample_rate = hparams["audio_sample_rate"]
        fft_size = hparams["win_size"]  # FFT size is tied to the window size here
        win_size = hparams["win_size"]
        hop_size = hparams["hop_size"]
        num_mels = hparams["audio_num_mel_bins"]
        fmin = hparams["fmin"]
        fmax = hparams["fmax"]
        mel_basis = librosa.filters.mel(
            sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        self.torch_wav2spec_ = partial(
            torch_wav2spec, mel_basis=mel_basis, fft_size=fft_size, hop_size=hop_size, win_length=win_size,
        )
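
    # A self-contained sketch of the wav->mel front-end bound above is given
    # at the bottom of this file (`_mel_frontend_sketch`).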

    def build_model(self):
        self.model_gen = WavVAE_V3(hparams=hparams)

        self.model_disc = torch.nn.ModuleDict()
        self.model_disc['mpd'] = MultiPeriodDiscriminator(hparams['mpd'], use_cond=hparams['use_cond_disc'])
        self.model_disc['msd'] = MultiScaleDiscriminator(use_cond=hparams['use_cond_disc'])
        if hparams['use_mrd']:
            self.model_disc['mrd'] = MultiResolutionDiscriminator(hparams)
        self.stft_loss = MultiResolutionSTFTLoss()

        # Warm-start: the encoder comes from a pretrained 25 Hz WavVAE run;
        # the decoder and discriminators come from a pretrained melgan-nsf run.
        load_ckpt(self.model_gen.encoder, './checkpoints/1231_megatts3_wavvae_v2_25hz', 'model.module.encoder', strict=False)
        load_ckpt(self.model_gen.decoder, './checkpoints/1117_melgan-nsf_full_1', 'model_gen', force=True, strict=True)
        load_ckpt(self.model_disc, './checkpoints/1117_melgan-nsf_full_1', 'model_disc', force=True, strict=True)

        # 'mrd' exists only when `use_mrd` is set, so the trainable list is
        # built conditionally.
        trainable = [self.model_gen, self.model_disc['mpd'], self.model_disc['msd']]
        if hparams['use_mrd']:
            trainable.append(self.model_disc['mrd'])
        return {'trainable': trainable, 'others': []}

    def load_model(self):
        if hparams.get('load_ckpt', '') != '':
            # `self.model` is never defined on this task; the checkpoint is
            # restored into the generator.
            load_ckpt(self.model_gen, hparams['load_ckpt'], 'model', strict=False)
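
    # Two optimizers, indexed by `optimizer_idx` in `_training_step`:
    # 0 -> generator (WavVAE_V3), 1 -> all discriminators. The discriminators
    # can use their own learning rate via `disc_lr`.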
    def build_optimizer(self):
        optimizer_gen = torch.optim.AdamW(self.model_gen.parameters(), lr=hparams['lr'],
                                          betas=(hparams['adam_b1'], hparams['adam_b2']))
        optimizer_disc = torch.optim.AdamW(self.model_disc.parameters(),
                                           lr=hparams.get('disc_lr', hparams['lr']),
                                           betas=(hparams['adam_b1'], hparams['adam_b2']))
        return [optimizer_gen, optimizer_disc]

    def build_scheduler(self, optimizer):
        # Constant learning rate; no warmup/decay schedule is used.
        return None
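
    # One GAN iteration runs this twice: first with optimizer_idx == 0
    # (generator losses; the reconstruction is cached in `self.y_`), then
    # with optimizer_idx == 1 (discriminator losses against that cache).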
    def _training_step(self, sample, batch_idx, optimizer_idx):
        sample['wavs'] = sample['wavs'].float()

        # Periodically kill stray "voidgpu" processes (presumably GPU-holding
        # placeholder jobs) on the visible devices.
        if self.global_step % 100 == 0:
            devices = os.environ.get('CUDA_VISIBLE_DEVICES', '').split(",")
            for d in devices:
                os.system(f'pkill -f "voidgpu{d}"')

        y = sample['wavs']
        loss_output = {}
        if optimizer_idx == 0:
            # Generator step: reconstruct the waveform through the VAE.
            y_, posterior = self.model_gen(y)
            y = y.unsqueeze(1)
            y_mel = mel_spectrogram(y.squeeze(1), hparams).transpose(1, 2)
            y_hat_mel = mel_spectrogram(y_.squeeze(1), hparams).transpose(1, 2)
            loss_output['mel'] = F.l1_loss(y_hat_mel, y_mel) * hparams['lambda_mel']
            if self.training:
                _, y_p_hat_g, fmap_f_r, fmap_f_g = self.model_disc['mpd'](y, y_, None)
                _, y_s_hat_g, fmap_s_r, fmap_s_g = self.model_disc['msd'](y, y_, None)
                loss_output['a_p'] = generator_loss(y_p_hat_g) * hparams['lambda_adv'] * hparams.get('lambda_mpd', 1.0)
                loss_output['a_s'] = generator_loss(y_s_hat_g) * hparams['lambda_adv'] * hparams.get('lambda_msd', 1.0)
                if hparams['use_mrd']:
                    y_r_hat_g = [x[1] for x in self.model_disc['mrd'](y_)]
                    loss_output['a_r'] = generator_loss(y_r_hat_g) \
                        * hparams['lambda_adv'] * hparams.get('lambda_mrd', 1.0)
                if hparams['use_ms_stft']:
                    loss_output['sc'], loss_output['mag'] = self.stft_loss(y.squeeze(1), y_.squeeze(1))
            loss_output['kl_loss'] = posterior.kl().mean() * hparams.get('lambda_kl', 1.0)
            # Cache the reconstruction for the discriminator step below.
            self.y_ = y_.detach()
        else:
            # Discriminator step: real waveform vs. the cached reconstruction.
            if not self.training:
                return None
            y = y.unsqueeze(1)
            y_ = self.y_

            y_p_hat_r, y_p_hat_g, _, _ = self.model_disc['mpd'](y, y_.detach(), None)
            loss_output['r_p'], loss_output['f_p'] = discriminator_loss(y_p_hat_r, y_p_hat_g)

            y_s_hat_r, y_s_hat_g, _, _ = self.model_disc['msd'](y, y_.detach(), None)
            loss_output['r_s'], loss_output['f_s'] = discriminator_loss(y_s_hat_r, y_s_hat_g)

            if hparams['use_mrd']:
                y_r_hat_r = [x[1] for x in self.model_disc['mrd'](y)]
                y_r_hat_g = [x[1] for x in self.model_disc['mrd'](y_.detach())]
                loss_output['r_r'], loss_output['f_r'] = discriminator_loss(y_r_hat_r, y_r_hat_g)
        total_loss = sum(loss_output.values())
        loss_output['bs'] = sample['wavs'].shape[0]
        return total_loss, loss_output
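
    # NOTE: `save_valid_result` (and `get_plot_dur_info` below) reference a
    # mel-TTS style sample dict (`mels`, `f0`, `txt_tokens`, `mel2ph`), an
    # external `self.vocoder` and `self.predict_f0`; they appear to be
    # inherited from the TTS task family rather than exercised by the
    # waveform-autoencoder training loop above.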
    def save_valid_result(self, sample, batch_idx, model_out):
        sr = hparams['audio_sample_rate']
        mel_out = model_out.get('mel_out')
        f0_gt = sample.get('f0')
        if f0_gt is not None:
            f0_gt = f0_gt.cpu()[-1]
        if mel_out is not None:
            f0_pred = self.predict_f0(sample['mels'])
            self.plot_mel(batch_idx, sample['mels'], mel_out, f0s={'f0': f0_pred, 'f0g': f0_gt})

        # Log the ground-truth audio once, early in training.
        if self.global_step <= hparams['valid_infer_interval']:
            mel_gt = sample['mels'][-1].cpu()
            f0 = self.predict_f0(sample['mels'][-1:])
            wav_gt = self.vocoder.spec2wav(mel_gt, f0=f0)
            self.logger.add_audio(f'wav_gt_{batch_idx}', wav_gt, self.global_step, sr)

        # Inference with ground-truth durations.
        if self.global_step >= 0:
            model_out = self.run_model(sample, infer=True, infer_use_gt_dur=True)
            dur_info = None
            f0 = self.predict_f0(model_out['mel_out'])
            wav_pred = self.vocoder.spec2wav(model_out['mel_out'][-1].cpu(), f0=f0)
            self.logger.add_audio(f'wav_gdur_{batch_idx}', wav_pred, self.global_step, sr)
            self.plot_mel(batch_idx, sample['mels'][-1:], model_out['mel_out'][-1], f'mel_gdur_{batch_idx}',
                          dur_info=dur_info, f0s={'f0': f0, 'f0g': f0_gt})

        # Inference with predicted durations.
        if not hparams['use_gt_dur'] and not hparams['use_gt_latent']:
            model_out = self.run_model(sample, infer=True, infer_use_gt_dur=False)
            dur_info = None
            f0 = self.predict_f0(model_out['mel_out'])
            self.plot_mel(
                batch_idx, sample['mels'], model_out['mel_out'][-1], f'mel_pdur_{batch_idx}',
                dur_info=dur_info, f0s={'f0': f0, 'f0g': f0_gt})
            wav_pred = self.vocoder.spec2wav(model_out['mel_out'][-1].cpu(), f0=f0)
            self.logger.add_audio(f'wav_pdur_{batch_idx}', wav_pred, self.global_step, sr)
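
    # `mel2token_to_dur` folds a frame-level mel2ph alignment into per-token
    # durations, so both ground-truth and (when available) predicted durations
    # can be drawn over the mel plots above.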
    def get_plot_dur_info(self, sample, model_out):
        T_txt = sample['txt_tokens'].shape[1]
        dur_gt = mel2token_to_dur(sample['mel2ph'], T_txt)[-1]
        dur_pred = model_out['dur'] if 'dur' in model_out else dur_gt
        txt = self.token_encoder.decode(sample['txt_tokens'][-1].cpu().numpy())
        txt = txt.split(" ")
        return {'dur_gt': dur_gt, 'dur_pred': dur_pred, 'txt': txt}
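
    # Gradients are clipped per optimizer right before each step: opt 0 is
    # the generator, opt 1 the discriminators (matching `build_optimizer`).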
    def on_before_optimization(self, opt_idx):
        if opt_idx == 0:
            nn.utils.clip_grad_norm_(self.model_gen.parameters(), hparams['generator_grad_norm'])
        else:
            nn.utils.clip_grad_norm_(self.model_disc.parameters(), hparams["discriminator_grad_norm"])
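
    # When EMA weights are enabled, the shadow copy must follow the model
    # across device/dtype moves, hence the overrides below.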
    def to(self, device=None, dtype=None):
        super().to(device=device, dtype=dtype)
        if hparams.get('use_ema', False):
            self.ema.to(device=device, dtype=dtype)

    def cuda(self, device=None):
        super().cuda(device)
        if hparams.get('use_ema', False):
            self.ema.to(device=device)

    @torch.no_grad()
    def validation_step(self, sample, batch_idx):
        infer_steps = self.hparams.get('infer_steps', 12)
        outputs = self._validation_step(sample, batch_idx, infer_steps)
        return outputs

    def _validation_step(self, sample, batch_idx, infer_steps):
        outputs = {}
        if self.trainer.proc_rank == 0:
            # Per-batch validation inference/logging is not implemented here.
            pass
        return outputs

    @torch.no_grad()
    def test_step(self, sample, batch_idx):
        infer_steps = hparams['infer_steps']
        return self._validation_step(sample, batch_idx, infer_steps)
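

# ------------------------------------------------------------------------
# Illustrative sketch (an assumption, not part of the task above): a minimal
# STFT->mel front-end built from the same ingredients that `__init__` binds
# into `self.torch_wav2spec_`. The real `torch_wav2spec` lives in
# tasks.tts.utils.audio and may pad, window and normalise differently; this
# stand-alone version only shows the shape of the computation.
def _mel_frontend_sketch(wav, mel_basis, fft_size, hop_size, win_length):
    """wav: float tensor [B, T] -> log-mel spectrogram [B, T', n_mels]."""
    window = torch.hann_window(win_length, device=wav.device)
    spec = torch.stft(wav, n_fft=fft_size, hop_length=hop_size,
                      win_length=win_length, window=window, return_complex=True)
    mag = spec.abs()  # linear magnitude, [B, fft_size // 2 + 1, T']
    mel = torch.from_numpy(mel_basis).to(wav) @ mag  # [B, n_mels, T']
    return torch.log(mel.clamp_min(1e-5)).transpose(1, 2)
# Example: _mel_frontend_sketch(torch.randn(2, 24000), mel_basis, 1024, 256, 1024)
# with `mel_basis = librosa.filters.mel(sr=24000, n_fft=1024, n_mels=80)`.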