import argparse
import os

import numpy as np
import torch
import torchaudio

# NOTE: "tts_inferece" is spelled this way to match the module's filename.
from models.tts.base.tts_inferece import TTSInference
from models.tts.valle.valle import VALLE
from models.tts.valle.valle_dataset import VALLETestDataset, VALLETestCollator
from processors.phone_extractor import phoneExtractor
from text.g2p_module import G2PModule
from text.text_token_collation import phoneIDCollation
from utils.tokenizer import AudioTokenizer, tokenize_audio


class VALLEInference(TTSInference):
    """Inference wrapper for VALL-E zero-shot TTS.

    Converts input text into a phone-ID sequence, tokenizes the audio
    prompt into acoustic tokens, runs the VALL-E model (continuation or
    full inference), and decodes the predicted tokens back to a waveform.
    """

    def __init__(self, args=None, cfg=None):
        TTSInference.__init__(self, args, cfg)

        self.g2p_module = G2PModule(backend=self.cfg.preprocess.phone_extractor)
        # NOTE: currently unused; kept as a reference to the symbols-dict location.
        text_token_path = os.path.join(
            cfg.preprocess.processed_dir, cfg.dataset[0], cfg.preprocess.symbols_dict
        )
        self.audio_tokenizer = AudioTokenizer()

    def _build_model(self):
        model = VALLE(self.cfg.model)
        return model

    def _build_test_dataset(self):
        return VALLETestDataset, VALLETestCollator
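
    # Sketch of how the returned (dataset, collator) pair is presumably wired
    # into a DataLoader by the TTSInference base class (illustrative only; the
    # constructor signatures here are assumptions, not the actual internals):
    #
    #   dataset = VALLETestDataset(self.args, self.cfg)
    #   loader = torch.utils.data.DataLoader(
    #       dataset, batch_size=1, collate_fn=VALLETestCollator(self.cfg)
    #   )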

    def inference_one_clip(self, text, text_prompt, audio_file, save_name="pred"):
        # Locate the phone-symbol dictionary (not needed for lexicon-based
        # extraction, where phones map to IDs without an external symbol file).
        phone_symbol_file = None
        if self.cfg.preprocess.phone_extractor != "lexicon":
            phone_symbol_file = os.path.join(
                self.exp_dir, self.cfg.preprocess.symbols_dict
            )
            assert os.path.exists(phone_symbol_file)

        phone_extractor = phoneExtractor(self.cfg)
        phon_id_collator = phoneIDCollation(
            self.cfg, symbols_dict_file=phone_symbol_file
        )

        # Convert prompt text + target text into a single phone-ID sequence.
        text = f"{text_prompt} {text}".strip()
        phone_seq = phone_extractor.extract_phone(text)
        phone_id_seq = phon_id_collator.get_phone_id_sequence(self.cfg, phone_seq)
        phone_id_seq_len = torch.IntTensor([len(phone_id_seq)]).to(self.device)
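
        # For intuition: with an English G2P backend, phone_seq might look like
        # ["HH", "AH0", "L", "OW1"] for the word "hello" (ARPAbet-style; the
        # exact symbol set depends on cfg.preprocess.phone_extractor).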

        # Batch the phone IDs: a (1, len) int tensor on the target device.
        phone_id_seq = np.array([phone_id_seq])
        phone_id_seq = torch.from_numpy(phone_id_seq).to(self.device)

        # Tokenize the audio prompt into discrete acoustic tokens.
        encoded_frames = tokenize_audio(self.audio_tokenizer, audio_file)
        audio_prompt_token = encoded_frames[0][0].transpose(2, 1).to(self.device)
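
        # Assumption: the tokenizer is EnCodec-style, returning a list of
        # (codes, scale) pairs with codes shaped (B, n_q, T); the transpose
        # above yields (B, T, n_q), i.e. one n_q-way code stack per frame.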

        # Optional copy-synthesis: decode the prompt's own tokens back to audio
        # as a sanity check on the tokenizer round-trip.
        if self.args.copysyn:
            samples = self.audio_tokenizer.decode(encoded_frames)
            audio_copysyn = samples[0].cpu().detach()

            out_path = os.path.join(
                self.args.output_dir, self.infer_type, f"{save_name}_copysyn.wav"
            )
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            torchaudio.save(out_path, audio_copysyn, self.cfg.preprocess.sampling_rate)
        if self.args.continual:
            # Continuation: extend the audio prompt following the phone sequence.
            encoded_frames = self.model.continual(
                phone_id_seq,
                phone_id_seq_len,
                audio_prompt_token,
            )
        else:
            # Full inference: the model also needs the length of the enrolled
            # (prompt) phone sequence, so extract it separately.
            enroll_x_lens = None
            if text_prompt:
                prompt_phone_seq = phone_extractor.extract_phone(
                    text_prompt.strip()
                )
                prompt_phone_id_seq = phon_id_collator.get_phone_id_sequence(
                    self.cfg, prompt_phone_seq
                )
                enroll_x_lens = torch.IntTensor(
                    [len(prompt_phone_id_seq)]
                ).to(self.device)

            encoded_frames = self.model.inference(
                phone_id_seq,
                phone_id_seq_len,
                audio_prompt_token,
                enroll_x_lens=enroll_x_lens,
                top_k=self.args.top_k,
                temperature=self.args.temperature,
            )

        # Decode the predicted acoustic tokens back to a waveform.
        samples = self.audio_tokenizer.decode([(encoded_frames.transpose(2, 1), None)])
        audio = samples[0].squeeze(0).cpu().detach()

        return audio

    def inference_for_single_utterance(self):
        text = self.args.text
        text_prompt = self.args.text_prompt
        audio_file = self.args.audio_prompt

        if not self.args.continual:
            assert text != ""
        else:
            text = ""
        assert text_prompt != ""
        assert audio_file != ""

        audio = self.inference_one_clip(text, text_prompt, audio_file)
        return audio

    def inference_for_batches(self):
        test_list_file = self.args.test_list_file
        assert test_list_file is not None
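
        # Each line of test_list_file is "|"-separated. Illustrative layout
        # (the paths are hypothetical):
        #   continual mode : <text_prompt>|prompts/spk1.wav
        #   full synthesis : <text_prompt>|prompts/spk1.wav|<target text>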

        pred_res = []
        with open(test_list_file, "r") as fin:
            for idx, line in enumerate(fin.readlines()):
                fields = line.strip().split("|")
                if self.args.continual:
                    assert len(fields) == 2
                    text_prompt, audio_prompt_path = fields
                    text = ""
                else:
                    assert len(fields) == 3
                    text_prompt, audio_prompt_path, text = fields

                audio = self.inference_one_clip(
                    text, text_prompt, audio_prompt_path, str(idx)
                )
                pred_res.append(audio)

        return pred_res

    """
    TODO: batch inference
    ###### Construct test_batch ######
    n_batch = len(self.test_dataloader)
    now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
    print(
        "Model eval time: {}, batch_size = {}, n_batch = {}".format(
            now, self.test_batch_size, n_batch
        )
    )

    ###### Inference for each batch ######
    pred_res = []
    with torch.no_grad():
        for i, batch_data in enumerate(
            self.test_dataloader if n_batch == 1 else tqdm(self.test_dataloader)
        ):
            if self.args.continual:
                encoded_frames = self.model.continual(
                    batch_data["phone_seq"],
                    batch_data["phone_len"],
                    batch_data["acoustic_token"],
                )
            else:
                encoded_frames = self.model.inference(
                    batch_data["phone_seq"],
                    batch_data["phone_len"],
                    batch_data["acoustic_token"],
                    enroll_x_lens=batch_data["pmt_phone_len"],
                    top_k=self.args.top_k,
                    temperature=self.args.temperature,
                )

            samples = self.audio_tokenizer.decode(
                [(encoded_frames.transpose(2, 1), None)]
            )

            for idx in range(samples.size(0)):
                audio = samples[idx].cpu()
                pred_res.append(audio)

    return pred_res
    """


def add_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--text_prompt",
        type=str,
        default="",
        help="Text prompt; should be aligned with --audio_prompt.",
    )
    parser.add_argument(
        "--audio_prompt",
        type=str,
        default="",
        help="Audio prompt; should be aligned with --text_prompt.",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=-100,
        help="Use top-k sampling in the AR decoder when > 0.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="Temperature for the AR decoder's top-k sampling.",
    )
    parser.add_argument(
        "--continual",
        action="store_true",
        help="Run continuation inference (extend the audio prompt).",
    )
    parser.add_argument(
        "--copysyn",
        action="store_true",
        help="Copy-synthesis: reconstruct the prompt audio through the "
        "audio tokenizer's decoder.",
    )
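
# Minimal usage sketch (illustrative; argument parsing, config loading, and
# checkpoint handling normally happen in the repo's entry-point scripts, and
# `cfg` below stands for an already-loaded config object):
#
#   parser = argparse.ArgumentParser()
#   add_arguments(parser)
#   args = parser.parse_args()
#   inference = VALLEInference(args=args, cfg=cfg)
#   audio = inference.inference_for_single_utterance()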