import argparse
import os

import numpy as np
import soundfile as sf
import torch
import torchaudio

from encodec import EncodecModel
from encodec.utils import convert_audio

from models.tts.naturalspeech2.ns2 import NaturalSpeech2
from text import text_to_sequence
from text.cmudict import valid_symbols
from text.g2p import preprocess_english, read_lexicon
from utils.util import load_config


class NS2Inference:
    """NaturalSpeech2 inference wrapper: encodes a reference prompt with EnCodec
    and synthesizes speech for a given text."""

    def __init__(self, args, cfg):
        self.cfg = cfg
        self.args = args

        self.model = self.build_model()
        self.codec = self.build_codec()

        # Phone inventory: CMUdict phones plus silence/pause tokens and
        # sequence boundary markers.
        self.symbols = valid_symbols + ["sp", "spn", "sil"] + ["<s>", "</s>"]
        self.phone2id = {s: i for i, s in enumerate(self.symbols)}
        self.id2phone = {i: s for s, i in self.phone2id.items()}

    def build_model(self):
        model = NaturalSpeech2(self.cfg.model)
        model.load_state_dict(
            torch.load(
                os.path.join(self.args.checkpoint_path, "pytorch_model.bin"),
                map_location="cpu",
            )
        )
        model = model.to(self.args.device)
        # Inference only: disable dropout and other training-time behavior.
        model.eval()
        return model

    def build_codec(self):
        # EnCodec 24 kHz codec; a 12 kbps target bandwidth uses 16 codebooks per frame.
        encodec_model = EncodecModel.encodec_model_24khz()
        encodec_model = encodec_model.to(device=self.args.device)
        encodec_model.set_target_bandwidth(12.0)
        return encodec_model

    def get_ref_code(self):
        # Encode the reference audio into discrete EnCodec codes used as the prompt.
        ref_wav_path = self.args.ref_audio
        ref_wav, sr = torchaudio.load(ref_wav_path)
        ref_wav = convert_audio(
            ref_wav, sr, self.codec.sample_rate, self.codec.channels
        )
        ref_wav = ref_wav.unsqueeze(0).to(device=self.args.device)

        with torch.no_grad():
            encoded_frames = self.codec.encode(ref_wav)
        # Each encoded frame is a (codes, scale) pair; concatenate the codes along
        # the time axis to obtain a single [B, n_q, T] tensor.
        ref_code = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)

        ref_mask = torch.ones(ref_code.shape[0], ref_code.shape[-1]).to(
            ref_code.device
        )

        return ref_code, ref_mask

    def inference(self):
        ref_code, ref_mask = self.get_ref_code()

        # Convert the input text into a phone sequence with the G2P frontend.
        lexicon = read_lexicon(self.cfg.preprocess.lexicon_path)
        phone_seq = preprocess_english(self.args.text, lexicon)
        print(phone_seq)

        # Map phones to IDs; the braces produced by preprocess_english are stripped first.
        phone_id = np.array(
            [
                *map(
                    self.phone2id.get,
                    phone_seq.replace("{", "").replace("}", "").split(),
                )
            ]
        )
        phone_id = torch.from_numpy(phone_id).unsqueeze(0).to(device=self.args.device)
        print(phone_id)

        # Run diffusion sampling conditioned on the reference codes and phone IDs.
        x0, prior_out = self.model.inference(
            ref_code, phone_id, ref_mask, self.args.inference_step
        )
        print(prior_out["dur_pred"])
        print(prior_out["dur_pred_round"])
        print(torch.sum(prior_out["dur_pred_round"]))

        # Decoded reference latents (not used further here).
        latent_ref = self.codec.quantizer.vq.decode(ref_code.transpose(0, 1))

        # Decode the predicted latents back to a waveform with the EnCodec decoder.
        rec_wav = self.codec.decoder(x0)

        os.makedirs(self.args.output_dir, exist_ok=True)

        sf.write(
            "{}/{}.wav".format(
                self.args.output_dir, self.args.text.replace(" ", "_", 100)
            ),
            rec_wav[0, 0].detach().cpu().numpy(),
            samplerate=24000,
        )


def add_arguments(parser: argparse.ArgumentParser):
    # Note: --text, --checkpoint_path, and --output_dir are read by NS2Inference
    # and are expected to be defined by the calling script's parser.
    parser.add_argument(
        "--ref_audio",
        type=str,
        default="",
        help="Reference audio path",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
    )
    parser.add_argument(
        "--inference_step",
        type=int,
        default=200,
        help="Total inference steps for the diffusion model",
    )
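

# A minimal entry-point sketch showing one way this module could be run directly.
# It assumes the full pipeline's CLI also supplies --config, --text,
# --checkpoint_path, and --output_dir (NS2Inference reads these from args/cfg);
# those flag names are illustrative here, not defined by this module.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_arguments(parser)
    # Assumed flags; in the full pipeline they may come from the outer inference script.
    parser.add_argument("--config", type=str, default="", help="Path to the config file")
    parser.add_argument("--text", type=str, default="", help="Text to synthesize")
    parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint directory")
    parser.add_argument("--output_dir", type=str, default="output", help="Output directory")
    args = parser.parse_args()

    cfg = load_config(args.config)
    NS2Inference(args, cfg).inference()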