# SyntaSpeech/tasks/tts/diffspeech.py
# Last commit: "init" (22871e7), committed by yerfor
import torch
from modules.tts.diffspeech.shallow_diffusion_tts import GaussianDiffusion
from tasks.tts.fs2_orig import FastSpeech2OrigTask
import utils
from utils.commons.hparams import hparams
from utils.commons.ckpt_utils import load_ckpt
from utils.audio.pitch.utils import denorm_f0
class DiffSpeechTask(FastSpeech2OrigTask):
    """TTS task whose acoustic model is a ``GaussianDiffusion``.

    The model exposes a FastSpeech 2 sub-model as ``self.model.fs2``. Its
    parameters are frozen in :meth:`build_tts_model`, and
    :meth:`build_optimizer` only hands parameters with ``requires_grad=True``
    to the optimizer.
    """

    def build_tts_model(self):
        """Create ``self.model`` and freeze its ``fs2`` sub-model.

        When ``hparams['fs2_ckpt']`` is not the empty string it is passed to
        ``load_ckpt`` to (strictly) initialise ``self.model.fs2``.
        """
        # get min and max (scratch snippet, kept for reference)
        #   import torch
        #   from tqdm import tqdm
        #   v_min = torch.ones([80]) * 100
        #   v_max = torch.ones([80]) * -100
        #   for i, ds in enumerate(tqdm(self.dataset_cls('train'))):
        #       v_max = torch.max(torch.max(ds['mel'].reshape(-1, 80), 0)[0], v_max)
        #       v_min = torch.min(torch.min(ds['mel'].reshape(-1, 80), 0)[0], v_min)
        #       if i % 100 == 0:
        #           print(i, v_min, v_max)
        #   print('final', v_min, v_max)
        dict_size = len(self.token_encoder)
        self.model = GaussianDiffusion(dict_size, hparams)
        if hparams['fs2_ckpt'] != '':
            load_ckpt(self.model.fs2, hparams['fs2_ckpt'], 'model', strict=True)

        # Alternative kept for reference: freeze everything except predictors.
        #   for k, v in self.model.fs2.named_parameters():
        #       if 'predictor' not in k:
        #           v.requires_grad = False
        # NOTE(review): the flattened source does not show whether the freeze
        # below was nested under the checkpoint check above; it is treated as
        # unconditional here -- confirm against the upstream file.
        for param in self.model.fs2.parameters():
            param.requires_grad = False

    def build_optimizer(self, model):
        """Build an AdamW optimizer over the trainable parameters of ``model``.

        Parameters with ``requires_grad=False`` are excluded. The optimizer is
        stored on ``self.optimizer`` and also returned.
        """
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(
            trainable_params,
            lr=hparams['lr'],
            betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
            weight_decay=hparams['weight_decay'],
        )
        self.optimizer = optimizer
        return optimizer

    def build_scheduler(self, optimizer):
        """Return a ``StepLR`` scheduler: step size ``hparams['decay_steps']``, gamma 0.5."""
        return torch.optim.lr_scheduler.StepLR(optimizer, hparams['decay_steps'], gamma=0.5)

    def run_model(self, sample, infer=False, *args, **kwargs):
        """Run ``self.model`` on a batch, for training or for inference.

        Args:
            sample: Batch mapping. ``txt_tokens`` is always read; ``spk_embed``
                and ``spk_ids`` are optional. Training also reads ``mels`` and
                ``mel2ph`` (plus optional ``f0`` / ``uv``). Inference reads
                ``mel2ph`` and/or ``f0`` + ``uv`` only when the corresponding
                ground-truth switch is enabled.
            infer: When true, run the inference path.
            **kwargs: ``infer_use_gt_dur`` / ``infer_use_gt_f0`` override
                ``hparams['use_gt_dur']`` / ``hparams['use_gt_f0']`` during
                inference.

        Returns:
            ``(losses, output)`` when ``infer`` is false, otherwise ``output``.
        """
        txt_tokens = sample['txt_tokens']  # [B, T_t]
        spk_embed = sample.get('spk_embed')
        spk_id = sample.get('spk_ids')

        if infer:
            use_gt_dur = kwargs.get('infer_use_gt_dur', hparams['use_gt_dur'])
            use_gt_f0 = kwargs.get('infer_use_gt_f0', hparams['use_gt_f0'])
            # Ground-truth alignment / pitch are passed only on request;
            # otherwise the model receives None for them.
            mel2ph = sample['mel2ph'] if use_gt_dur else None
            f0, uv = (sample['f0'], sample['uv']) if use_gt_f0 else (None, None)
            return self.model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
                              ref_mels=None, f0=f0, uv=uv, infer=True)

        target = sample['mels']  # [B, T_s, 80]
        mel2ph = sample['mel2ph']  # [B, T_s]
        output = self.model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
                            ref_mels=target, f0=sample.get('f0'), uv=sample.get('uv'),
                            infer=False)
        losses = {}
        if 'diff_loss' in output:
            # The diffusion loss is reported under the 'mel' key.
            losses['mel'] = output['diff_loss']
        self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses)
        if hparams['use_pitch_embed']:
            self.add_pitch_loss(output, sample, losses)
        return losses, output

    def save_valid_result(self, sample, batch_idx, model_out):
        """Log validation audio and mel plots for the first item of ``sample``.

        The incoming ``model_out`` is not read; predictions are recomputed via
        :meth:`run_model` in inference mode (with ground-truth durations, and
        with predicted durations unless ``hparams['use_gt_dur']`` is set).

        Args:
            sample: Validation batch; ``mels`` is required, ``f0`` / ``uv`` are
                used for the ground-truth F0 when ``f0`` is present.
            batch_idx: Index used in the logged tag names.
            model_out: Unused (overwritten before first use).
        """
        sr = hparams['audio_sample_rate']
        f0_gt = None
        # mel_out = model_out['mel_out']
        if sample.get('f0') is not None:
            f0_gt = denorm_f0(sample['f0'][0].cpu(), sample['uv'][0].cpu())
        # self.plot_mel(batch_idx, sample['mels'], mel_out, f0s=f0_gt)

        if self.global_step > 0:
            # wav_pred = self.vocoder.spec2wav(mel_out[0].cpu(), f0=f0_gt)
            # self.logger.add_audio(f'wav_val_{batch_idx}', wav_pred, self.global_step, sr)

            # with gt duration
            model_out = self.run_model(sample, infer=True, infer_use_gt_dur=True)
            dur_info = self.get_plot_dur_info(sample, model_out)
            # Durations were given, so the predicted-duration entry is dropped
            # before plotting.
            del dur_info['dur_pred']
            wav_pred = self.vocoder.spec2wav(model_out['mel_out'][0].cpu(), f0=f0_gt)
            self.logger.add_audio(f'wav_gdur_{batch_idx}', wav_pred, self.global_step, sr)
            self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'][0],
                          f'diffmel_gdur_{batch_idx}', dur_info=dur_info, f0s=f0_gt)
            # gt mel vs. fs2 mel
            self.plot_mel(batch_idx, sample['mels'], model_out['fs2_mel'][0],
                          f'fs2mel_gdur_{batch_idx}', dur_info=dur_info, f0s=f0_gt)

            # with pred duration
            if not hparams['use_gt_dur']:
                model_out = self.run_model(sample, infer=True, infer_use_gt_dur=False)
                dur_info = self.get_plot_dur_info(sample, model_out)
                self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'][0],
                              f'mel_pdur_{batch_idx}', dur_info=dur_info, f0s=f0_gt)
                wav_pred = self.vocoder.spec2wav(model_out['mel_out'][0].cpu(), f0=f0_gt)
                self.logger.add_audio(f'wav_pdur_{batch_idx}', wav_pred, self.global_step, sr)

        # gt wav
        # NOTE(review): reconstructed as independent of the `global_step > 0`
        # check above (flattened source is ambiguous) -- confirm upstream.
        if self.global_step <= hparams['valid_infer_interval']:
            mel_gt = sample['mels'][0].cpu()
            wav_gt = self.vocoder.spec2wav(mel_gt, f0=f0_gt)
            self.logger.add_audio(f'wav_gt_{batch_idx}', wav_gt, self.global_step, sr)