# Hugging Face Spaces status: Build error
import torch | |
from modules.tts.diffspeech.shallow_diffusion_tts import GaussianDiffusion | |
from tasks.tts.fs2_orig import FastSpeech2OrigTask | |
import utils | |
from utils.commons.hparams import hparams | |
from utils.commons.ckpt_utils import load_ckpt | |
from utils.audio.pitch.utils import denorm_f0 | |
class DiffSpeechTask(FastSpeech2OrigTask):
    """DiffSpeech task: trains a ``GaussianDiffusion`` model that wraps a FastSpeech2 sub-model (``fs2``).

    The diffusion loss is reported as ``losses['mel']``; duration (and optionally pitch) losses are
    added via ``add_dur_loss`` / ``add_pitch_loss``. Those helpers, as well as ``plot_mel``,
    ``get_plot_dur_info``, ``self.vocoder`` and ``self.logger``, are not defined here — presumably
    they come from ``FastSpeech2OrigTask`` or its bases.
    """

    def build_tts_model(self):
        """Create ``self.model`` and freeze its FastSpeech2 sub-model.

        If ``hparams['fs2_ckpt']`` is non-empty, ``self.model.fs2`` is initialised from that
        checkpoint with strict key matching. All ``fs2`` parameters are then marked
        ``requires_grad = False`` so that ``build_optimizer`` excludes them.
        """
        # To recompute per-bin (80-dim) mel min/max statistics over the training set, iterate
        # ``self.dataset_cls('train')`` and accumulate ``torch.min`` / ``torch.max`` of
        # ``item['mel'].reshape(-1, 80)`` across all items.
        dict_size = len(self.token_encoder)
        self.model = GaussianDiffusion(dict_size, hparams)
        if hparams['fs2_ckpt'] != '':
            load_ckpt(self.model.fs2, hparams['fs2_ckpt'], 'model', strict=True)
        # Freeze the whole FastSpeech2 sub-model. An alternative is to keep parameters whose
        # name contains 'predictor' trainable and freeze only the rest.
        # NOTE(review): fs2 is frozen even when no 'fs2_ckpt' was loaded — confirm this is intended.
        for param in self.model.fs2.parameters():
            param.requires_grad = False

    def build_optimizer(self, model):
        """Return an AdamW optimizer over ``model``'s trainable parameters.

        Parameters with ``requires_grad == False`` (e.g. the frozen ``fs2`` weights) are skipped.
        The optimizer is also stored on ``self.optimizer``.
        """
        self.optimizer = optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=hparams['lr'],
            betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
            weight_decay=hparams['weight_decay'])
        return optimizer

    def build_scheduler(self, optimizer):
        """Return a StepLR scheduler that halves the LR every ``hparams['decay_steps']`` scheduler steps."""
        return torch.optim.lr_scheduler.StepLR(optimizer, hparams['decay_steps'], gamma=0.5)

    def run_model(self, sample, infer=False, *args, **kwargs):
        """Run the diffusion model on a batch.

        Training (``infer=False``): feeds ground-truth mels, ``mel2ph``, ``f0`` and ``uv`` and
        returns ``(losses, output)``. ``losses['mel']`` holds ``output['diff_loss']`` when the model
        reports it, plus duration loss and, if ``hparams['use_pitch_embed']``, pitch loss.

        Inference (``infer=True``): returns only the model output. Ground-truth durations / f0 are
        used when ``kwargs['infer_use_gt_dur']`` / ``kwargs['infer_use_gt_f0']`` are true, falling
        back to ``hparams['use_gt_dur']`` / ``hparams['use_gt_f0']`` when the kwargs are absent.
        """
        txt_tokens = sample['txt_tokens']  # [B, T_t]
        spk_embed = sample.get('spk_embed')
        spk_id = sample.get('spk_ids')
        if not infer:
            target = sample['mels']  # [B, T_s, 80]
            mel2ph = sample['mel2ph']  # [B, T_s]
            f0 = sample.get('f0')
            uv = sample.get('uv')
            output = self.model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
                                ref_mels=target, f0=f0, uv=uv, infer=False)
            losses = {}
            if 'diff_loss' in output:
                losses['mel'] = output['diff_loss']
            self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses)
            if hparams['use_pitch_embed']:
                self.add_pitch_loss(output, sample, losses)
            return losses, output
        else:
            use_gt_dur = kwargs.get('infer_use_gt_dur', hparams['use_gt_dur'])
            use_gt_f0 = kwargs.get('infer_use_gt_f0', hparams['use_gt_f0'])
            # None means "let the model predict it".
            mel2ph, uv, f0 = None, None, None
            if use_gt_dur:
                mel2ph = sample['mel2ph']
            if use_gt_f0:
                f0 = sample['f0']
                uv = sample['uv']
            output = self.model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
                                ref_mels=None, f0=f0, uv=uv, infer=True)
            return output

    def save_valid_result(self, sample, batch_idx, model_out):
        """Log validation audio and mel plots for the first item of ``sample``.

        For ``global_step > 0`` the model is re-run in inference mode with ground-truth durations
        (logging diffusion and fs2 mels). If ``hparams['use_gt_dur']`` is false, it is also re-run
        with predicted durations. Ground-truth audio is logged while
        ``global_step <= hparams['valid_infer_interval']``. The passed-in ``model_out`` is not used;
        it is replaced by the fresh inference outputs.
        """
        sr = hparams['audio_sample_rate']
        f0_gt = None
        if sample.get('f0') is not None:
            f0_gt = denorm_f0(sample['f0'][0].cpu(), sample['uv'][0].cpu())
        if self.global_step > 0:
            # with gt duration
            model_out = self.run_model(sample, infer=True, infer_use_gt_dur=True)
            dur_info = self.get_plot_dur_info(sample, model_out)
            # With gt durations the predicted durations are not shown; tolerate a missing key.
            dur_info.pop('dur_pred', None)
            wav_pred = self.vocoder.spec2wav(model_out['mel_out'][0].cpu(), f0=f0_gt)
            self.logger.add_audio(f'wav_gdur_{batch_idx}', wav_pred, self.global_step, sr)
            self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'][0], f'diffmel_gdur_{batch_idx}',
                          dur_info=dur_info, f0s=f0_gt)
            self.plot_mel(batch_idx, sample['mels'], model_out['fs2_mel'][0], f'fs2mel_gdur_{batch_idx}',
                          dur_info=dur_info, f0s=f0_gt)  # gt mel vs. fs2 mel

            # with pred duration
            if not hparams['use_gt_dur']:
                model_out = self.run_model(sample, infer=True, infer_use_gt_dur=False)
                dur_info = self.get_plot_dur_info(sample, model_out)
                self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'][0], f'mel_pdur_{batch_idx}',
                              dur_info=dur_info, f0s=f0_gt)
                wav_pred = self.vocoder.spec2wav(model_out['mel_out'][0].cpu(), f0=f0_gt)
                self.logger.add_audio(f'wav_pdur_{batch_idx}', wav_pred, self.global_step, sr)
        # gt wav
        if self.global_step <= hparams['valid_infer_interval']:
            mel_gt = sample['mels'][0].cpu()
            wav_gt = self.vocoder.spec2wav(mel_gt, f0=f0_gt)
            self.logger.add_audio(f'wav_gt_{batch_idx}', wav_gt, self.global_step, sr)