RayeRen's picture
update
53fa903
raw
history blame
1.09 kB
import torch
# from inference.tts.fs import FastSpeechInfer
# from modules.tts.fs2_orig import FastSpeech2Orig
from inference.tts.base_tts_infer import BaseTTSInfer
from modules.tts.diffspeech.shallow_diffusion_tts import GaussianDiffusion
from utils.commons.ckpt_utils import load_ckpt
from utils.commons.hparams import hparams
class DiffSpeechInfer(BaseTTSInfer):
    """Inference entry point built around a ``GaussianDiffusion`` model.

    The class supplies the two hooks visible here: ``build_model`` creates the
    network and loads its weights, and ``forward_model`` turns one input into
    a waveform array via the model and the vocoder helper of the base class.
    """

    def build_model(self):
        """Create the diffusion model, put it in eval mode and load its weights.

        Returns:
            The ``GaussianDiffusion`` instance after ``load_ckpt`` was applied.
        """
        # NOTE(review): the network is configured from ``self.hparams`` while
        # the checkpoint directory is read from the module-level ``hparams``;
        # confirm both refer to the same configuration.
        net = GaussianDiffusion(len(self.ph_encoder), self.hparams)
        net.eval()
        load_ckpt(net, hparams['work_dir'], 'model')
        return net

    def forward_model(self, inp):
        """Synthesize audio for a single input.

        Args:
            inp: Raw input item; it is handed unchanged to
                ``self.input_to_batch``.

        Returns:
            Element ``0`` of the vocoder output after it has been moved to the
            CPU and converted to a NumPy array.
        """
        batch = self.input_to_batch(inp)
        tokens = batch['txt_tokens']  # [B, T_t]
        # ``.get`` is used so a batch without 'spk_ids' passes ``None`` on.
        speaker = batch.get('spk_ids')
        # Gradients are not needed at inference time.
        with torch.no_grad():
            mel = self.model(
                tokens, spk_id=speaker, ref_mels=None, infer=True
            )['mel_out']
            audio = self.run_vocoder(mel)
        return audio.cpu().numpy()[0]
# Script entry point: only runs the example when executed directly, not on import.
if __name__ == '__main__':
    DiffSpeechInfer.example_run()