File size: 3,963 Bytes
c968fc3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os
import torch
import soundfile as sf
import numpy as np
from models.tts.naturalspeech2.ns2 import NaturalSpeech2
from encodec import EncodecModel
from encodec.utils import convert_audio
from utils.util import load_config
from text import text_to_sequence
from text.cmudict import valid_symbols
from text.g2p import preprocess_english, read_lexicon
import torchaudio
class NS2Inference:
    """Run NaturalSpeech2 synthesis for one text prompt and one reference clip.

    Attributes read from ``args``: ``checkpoint_path``, ``device``,
    ``ref_audio``, ``text``, ``inference_step`` and ``output_dir``.
    Attributes read from ``cfg``: ``model`` and ``preprocess.lexicon_path``.
    """

    def __init__(self, args, cfg):
        """Load the acoustic model and the codec, and build the phone tables.

        Args:
            args: Parsed command-line namespace (see the class docstring).
            cfg: Loaded experiment configuration.
        """
        self.cfg = cfg
        self.args = args
        self.model = self.build_model()
        self.codec = self.build_codec()
        # The position of a symbol in this list is its integer id, so the
        # concatenation order must not change.
        self.symbols = valid_symbols + ["sp", "spn", "sil"] + ["<s>", "</s>"]
        self.phone2id = {s: i for i, s in enumerate(self.symbols)}
        self.id2phone = {i: s for s, i in self.phone2id.items()}

    def build_model(self):
        """Instantiate NaturalSpeech2 and load ``pytorch_model.bin``.

        Returns:
            The model, moved to ``args.device`` and switched to eval mode.
        """
        model = NaturalSpeech2(self.cfg.model)
        checkpoint_file = os.path.join(
            self.args.checkpoint_path, "pytorch_model.bin"
        )
        # Deserialize onto the CPU first; the move to the target device
        # happens once, below.
        state_dict = torch.load(checkpoint_file, map_location="cpu")
        model.load_state_dict(state_dict)
        model = model.to(self.args.device)
        # This class only ever synthesizes, so train-time behaviour such as
        # dropout must be switched off.
        model.eval()
        return model

    def build_codec(self):
        """Create the EnCodec 24 kHz model with a target bandwidth of 12.0.

        Returns:
            The codec, moved to ``args.device``.
        """
        encodec_model = EncodecModel.encodec_model_24khz()
        encodec_model = encodec_model.to(device=self.args.device)
        encodec_model.set_target_bandwidth(12.0)
        return encodec_model

    def get_ref_code(self):
        """Encode the reference audio into codec codes.

        Returns:
            A ``(ref_code, ref_mask)`` tuple. ``ref_code`` is the codes of all
            encoded frames concatenated along the last dimension.
            ``ref_mask`` is an all-ones tensor of shape
            ``(ref_code.shape[0], ref_code.shape[-1])`` on the same device.
        """
        ref_wav, sr = torchaudio.load(self.args.ref_audio)
        # Match the codec's expected sample rate and channel count.
        ref_wav = convert_audio(
            ref_wav, sr, self.codec.sample_rate, self.codec.channels
        )
        # Add a leading dimension (presumably the batch axis -- confirm
        # against the EnCodec API) before encoding.
        ref_wav = ref_wav.unsqueeze(0).to(device=self.args.device)

        with torch.no_grad():
            encoded_frames = self.codec.encode(ref_wav)

        # Element 0 of each frame holds the codes; the rest is not used here.
        ref_code = torch.cat([frame[0] for frame in encoded_frames], dim=-1)
        ref_mask = torch.ones(
            ref_code.shape[0], ref_code.shape[-1], device=ref_code.device
        )
        return ref_code, ref_mask

    def _phones_to_ids(self, phone_seq):
        """Translate a whitespace-separated phone string into integer ids.

        Curly braces in ``phone_seq`` are dropped before splitting.

        Args:
            phone_seq: Phone string such as ``"{HH AH0} sp"``.

        Returns:
            A list of ids looked up in ``self.phone2id``.

        Raises:
            ValueError: If the sequence is empty or contains a phone that is
                not in ``self.phone2id``.
        """
        phones = phone_seq.replace("{", "").replace("}", "").split()
        if not phones:
            raise ValueError("Phone sequence is empty; nothing to synthesize.")

        # Fail early with the offending symbols instead of letting a ``None``
        # id surface later as an opaque tensor-conversion error.
        unknown = sorted({p for p in phones if p not in self.phone2id})
        if unknown:
            raise ValueError(
                f"Phones missing from the symbol table: {unknown}"
            )
        return [self.phone2id[p] for p in phones]

    def _output_path(self):
        """Build the ``.wav`` path inside ``args.output_dir`` for ``args.text``.

        Returns:
            ``<output_dir>/<text with spaces replaced by underscores>.wav``.
        """
        # Only the first 100 spaces are replaced (count kept as-is so
        # existing file names do not change).
        stem = self.args.text.replace(" ", "_", 100)
        # A path separator inside the text would otherwise point into a
        # sub-directory that does not exist.
        for sep in {"/", os.sep}:
            stem = stem.replace(sep, "_")
        return os.path.join(self.args.output_dir, f"{stem}.wav")

    def inference(self):
        """Synthesize ``args.text`` and write the waveform to ``args.output_dir``.

        Prints the phone sequence, the phone ids and the predicted durations
        along the way, then writes a single ``.wav`` file at the codec's
        sample rate.

        Raises:
            ValueError: If the text yields no phones or an unknown phone.
        """
        ref_code, ref_mask = self.get_ref_code()

        lexicon = read_lexicon(self.cfg.preprocess.lexicon_path)
        phone_seq = preprocess_english(self.args.text, lexicon)
        print(phone_seq)

        phone_id = np.array(self._phones_to_ids(phone_seq))
        phone_id = (
            torch.from_numpy(phone_id).unsqueeze(0).to(device=self.args.device)
        )
        print(phone_id)

        x0, prior_out = self.model.inference(
            ref_code, phone_id, ref_mask, self.args.inference_step
        )
        print(prior_out["dur_pred"])
        print(prior_out["dur_pred_round"])
        print(torch.sum(prior_out["dur_pred_round"]))

        # Decoding is forward-only; skip autograd bookkeeping.
        with torch.no_grad():
            rec_wav = self.codec.decoder(x0)

        os.makedirs(self.args.output_dir, exist_ok=True)
        sf.write(
            self._output_path(),
            rec_wav[0, 0].detach().cpu().numpy(),
            # Take the rate from the codec rather than a literal, matching
            # what get_ref_code uses for the input side.
            samplerate=self.codec.sample_rate,
        )
def add_arguments(parser: argparse.ArgumentParser):
    """Register the NaturalSpeech2 inference options on ``parser``.

    Options added:
        --ref_audio: str, default ``""``.
        --device: str, default ``"cuda"``.
        --inference_step: int, default ``200``.

    Args:
        parser: The argument parser to extend in place.
    """
    # One (flag, keyword-arguments) pair per option, in registration order.
    option_table = (
        (
            "--ref_audio",
            {"type": str, "default": "", "help": "Reference audio path"},
        ),
        (
            "--device",
            {"type": str, "default": "cuda"},
        ),
        (
            "--inference_step",
            {
                "type": int,
                "default": 200,
                "help": "Total inference steps for the diffusion model",
            },
        ),
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
|