# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
from collections import OrderedDict
from pathlib import Path

import torch
from tqdm import tqdm

from models.tts.base.tts_inferece import TTSInference
from models.tts.fastspeech2.fs2 import FastSpeech2
from models.tts.fastspeech2.fs2_dataset import FS2TestDataset, FS2TestCollator
from models.vocoders.vocoder_inference import synthesis
from utils.io import save_audio
from utils.util import load_config


class FastSpeech2Inference(TTSInference):
    """Inference pipeline for FastSpeech2: predicts mel-spectrograms, then
    renders them to waveforms with a separately trained vocoder."""

    def __init__(self, args, cfg):
        TTSInference.__init__(self, args, cfg)
        self.args = args
        self.cfg = cfg
        self.infer_type = args.mode

    def _build_model(self):
        self.model = FastSpeech2(self.cfg)
        return self.model

    def load_model(self, state_dict):
        raw_dict = state_dict["model"]
        clean_dict = OrderedDict()
        # Strip any "module." prefix left by (Distributed)DataParallel wrapping
        # so the weights load into an unwrapped model.
        for k, v in raw_dict.items():
            if k.startswith("module."):
                clean_dict[k[7:]] = v
            else:
                clean_dict[k] = v

        self.model.load_state_dict(clean_dict)
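
    # For example, a checkpoint saved from a DDP-wrapped model stores a
    # parameter as "module.encoder.layers.0.weight" (name illustrative), and
    # the cleanup above remaps it to "encoder.layers.0.weight" before
    # load_state_dict() is called.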

    def _build_test_dataset(self):
        return FS2TestDataset, FS2TestCollator

    @staticmethod
    def _parse_vocoder(vocoder_dir):
        r"""Parse the vocoder config and locate its latest checkpoint."""
        vocoder_dir = os.path.abspath(vocoder_dir)
        ckpt_list = list(Path(vocoder_dir).glob("*.pt"))
        # Pick the newest checkpoint by training step (unlike the base class,
        # which sorts by *int(x.stem)*).
        ckpt_list.sort(
            key=lambda x: int(x.stem.split("_")[-2].split("-")[-1]), reverse=True
        )
        ckpt_path = str(ckpt_list[0])
        vocoder_cfg = load_config(
            os.path.join(vocoder_dir, "args.json"), lowercase=True
        )
        return vocoder_cfg, ckpt_path
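
    # The sort key above assumes the second-to-last "_"-separated field of the
    # checkpoint stem ends in "-<step>". For instance, a (hypothetical) file
    # named "epoch-0010_step-200000_loss-0.35.pt" sorts by int("200000");
    # checkpoints named differently would break this parser.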

    def inference(self):
        # Stage 1: predict mel-spectrograms batch by batch and cache one
        # tensor per utterance on disk.
        for i, batch in tqdm(enumerate(self.test_dataloader)):
            y_pred, mel_lens, _ = self._inference_each_batch(batch)

            # Split the padded batch back into per-utterance mels and trim
            # each to its true length before saving.
            y_ls = y_pred.chunk(self.test_batch_size)
            tgt_ls = mel_lens.chunk(self.test_batch_size)
            j = 0
            for it, l in zip(y_ls, tgt_ls):
                l = l.item()
                it = it.squeeze(0)[:l].detach().cpu()
                uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
                torch.save(it, os.path.join(self.args.output_dir, f"{uid}.pt"))
                j += 1

        # Stage 2: vocode the cached mel-spectrograms into waveforms.
        vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
        res = synthesis(
            cfg=vocoder_cfg,
            vocoder_weight_file=vocoder_ckpt,
            n_samples=None,
            pred=[
                torch.load(
                    os.path.join(self.args.output_dir, "{}.pt".format(item["Uid"]))
                ).numpy()
                for item in self.test_dataset.metadata
            ],
        )

        # Write each waveform and remove its intermediate mel file.
        for it, wav in zip(self.test_dataset.metadata, res):
            uid = it["Uid"]
            save_audio(
                os.path.join(self.args.output_dir, f"{uid}.wav"),
                wav.numpy(),
                self.cfg.preprocess.sample_rate,
                add_silence=True,
                turn_up=True,
            )
            os.remove(os.path.join(self.args.output_dir, f"{uid}.pt"))

    def _inference_each_batch(self, batch_data):
        device = self.accelerator.device
        # FastSpeech2 variance controls: 1.0 leaves the predicted pitch,
        # energy, and duration unchanged; other values scale them.
        control_values = (
            self.args.pitch_control,
            self.args.energy_control,
            self.args.duration_control,
        )

        for k, v in batch_data.items():
            batch_data[k] = v.to(device)

        pitch_control, energy_control, duration_control = control_values
        output = self.model(
            batch_data,
            p_control=pitch_control,
            e_control=energy_control,
            d_control=duration_control,
        )

        pred_res = output["postnet_output"]
        mel_lens = output["mel_lens"].cpu()
        return pred_res, mel_lens, 0
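

if __name__ == "__main__":
    # Minimal usage sketch, assuming an argparse-style namespace with the
    # fields referenced above; the base TTSInference class may expect extra
    # fields (e.g. checkpoint paths), so treat this as illustrative rather
    # than the project's real entry point.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--mode", type=str, default="batch")
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--vocoder_dir", type=str, required=True)
    parser.add_argument("--pitch_control", type=float, default=1.0)
    parser.add_argument("--energy_control", type=float, default=1.0)
    parser.add_argument("--duration_control", type=float, default=1.0)
    args = parser.parse_args()

    cfg = load_config(args.config)
    FastSpeech2Inference(args, cfg).inference()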