# Spaces: Running on L4
import os | |
import torch | |
import gdown | |
import logging | |
import psutil | |
import langid | |
# Restrict langid's classifier to English, Chinese and Japanese.
langid.set_languages(['en', 'zh', 'ja'])
import pathlib
import platform

# Alias the "foreign" concrete Path class to the native one. Presumably this lets
# torch.load unpickle a checkpoint containing Path objects that were saved on the
# other OS family (TODO confirm against the checkpoint). The patch applies to the
# pathlib module for the whole process; `temp` keeps the class that was replaced.
# NOTE(review): platforms other than Windows and Linux (e.g. 'darwin') are left unpatched.
if platform.system().lower() == 'windows':
    temp, pathlib.PosixPath = pathlib.PosixPath, pathlib.WindowsPath
elif platform.system().lower() == 'linux':
    temp, pathlib.WindowsPath = pathlib.WindowsPath, pathlib.PosixPath
import numpy as np | |
from data.tokenizer import ( | |
AudioTokenizer, | |
tokenize_audio, | |
) | |
from data.collation import get_text_token_collater | |
from models.vallex import VALLE | |
from utils.g2p import PhonemeBpeTokenizer | |
from utils.sentence_cutter import split_text_into_sentences | |
from macros import * | |
# Inference device: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

# Checkpoint location. preload_models() downloads the checkpoint into
# checkpoints_dir when the file is missing.
# NOTE(review): `url` is not referenced by the code visible here; preload_models
# passes the Google Drive file id directly.
url = 'https://drive.google.com/file/d/10gdQWvP-K_e1undkvv0p2b7SU6I4Egyl/view?usp=sharing'
checkpoints_dir = "./checkpoints/"
model_checkpoint_name = "vallex-checkpoint.pt"

# Filled in by preload_models(); None until it has run.
model, codec = None, None

# Text front end shared by the generation functions.
text_tokenizer = PhonemeBpeTokenizer(tokenizer_path="./utils/g2p/bpe_69.json")
text_collater = get_text_token_collater()
def preload_models():
    """Load the VALL-E X model and the Encodec codec into module globals.

    If the checkpoint file is not present in ``checkpoints_dir``, it is first
    downloaded from Google Drive with ``gdown``.

    Sets the module-level ``model`` to a ``VALLE`` instance on ``device``,
    with weights from the checkpoint's ``"model"`` entry, in eval mode. Sets
    ``codec`` to an ``AudioTokenizer`` on ``device``.

    Raises:
        RuntimeError: From ``load_state_dict`` (``strict=True``) if the
            checkpoint keys do not match the model.
    """
    global model, codec
    checkpoint_path = os.path.join(checkpoints_dir, model_checkpoint_name)
    # exist_ok=True avoids the check-then-create race of exists() + mkdir().
    os.makedirs(checkpoints_dir, exist_ok=True)
    if not os.path.exists(checkpoint_path):
        gdown.download(id="10gdQWvP-K_e1undkvv0p2b7SU6I4Egyl", output=checkpoint_path, quiet=False)

    # VALL-E
    model = VALLE(
        N_DIM,
        NUM_HEAD,
        NUM_LAYERS,
        norm_first=True,
        add_prenet=False,
        prefix_mode=PREFIX_MODE,
        share_embedding=True,
        nar_scale_factor=1.0,
        prepend_bos=True,
        num_quantizers=NUM_QUANTIZERS,
    ).to(device)
    # NOTE(review): torch.load unpickles arbitrary objects. This is acceptable for
    # the fixed checkpoint downloaded above, but never point it at untrusted files.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # strict=True makes load_state_dict raise on any missing or unexpected key,
    # so a separate assertion (which `python -O` would strip) is unnecessary.
    model.load_state_dict(checkpoint["model"], strict=True)
    model.eval()

    # Encodec
    codec = AudioTokenizer(device)
def generate_audio(text, prompt=None, language='auto', accent='no-accent'):
    """Synthesize speech for ``text`` with the preloaded VALL-E X model.

    Reads the module-level ``model`` and ``codec`` set by ``preload_models()``,
    so that function must have been called first.

    Args:
        text: Text to speak. Newlines are removed and surrounding spaces are
            stripped before tokenization.
        prompt: Optional voice prompt. Either an existing file path, or a name
            resolved as ``./presets/<prompt>.npz`` and then
            ``./customs/<prompt>.npz``. The archive must contain
            ``audio_tokens``, ``text_tokens`` and ``lang_code``. ``None``
            synthesizes with empty prompt tensors.
        language: Key into ``lang2token``. ``'auto'`` detects it with langid.
        accent: ``'no-accent'`` passes the tokenizer's per-token languages to
            the model. Any other value is looked up in ``langdropdown2token``
            and that single language is used instead.

    Returns:
        ``samples[0][0]`` of the codec's decoded output, as a NumPy array.

    Raises:
        ValueError: If ``prompt`` is given but none of the candidate paths exist.
    """
    text = text.replace("\n", "").strip(" ")

    # detect language
    if language == "auto":
        language = langid.classify(text)[0]
    lang_token = lang2token[language]
    lang = token2lang[lang_token]
    text = lang_token + text + lang_token

    # load prompt
    if prompt is not None:
        prompt_path = prompt
        if not os.path.exists(prompt_path):
            prompt_path = "./presets/" + prompt + ".npz"
        if not os.path.exists(prompt_path):
            prompt_path = "./customs/" + prompt + ".npz"
        if not os.path.exists(prompt_path):
            raise ValueError(f"Cannot find prompt {prompt}")
        # For an .npz archive np.load returns an NpzFile that keeps the file
        # open; the with-block closes it once the arrays have been read.
        with np.load(prompt_path) as prompt_data:
            audio_prompts = prompt_data['audio_tokens']
            text_prompts = prompt_data['text_tokens']
            lang_pr = code2lang[int(prompt_data['lang_code'])]
        # numpy to tensor
        audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
        text_prompts = torch.tensor(text_prompts).type(torch.int32)
    else:
        # No prompt: zero-length prompt tensors; the prompt language follows the text.
        audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
        text_prompts = torch.zeros([1, 0]).type(torch.int32)
        lang_pr = lang if lang != 'mix' else 'en'

    # The prompt's text tokens are prepended to the target's tokens; the count of
    # those leading prompt tokens is passed to inference as enroll_x_lens.
    enroll_x_lens = text_prompts.shape[-1]
    logging.info("synthesize text: %s", text)
    phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
    text_tokens, text_tokens_lens = text_collater([phone_tokens])
    text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
    text_tokens_lens += enroll_x_lens

    # accent control
    lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
    encoded_frames = model.inference(
        text_tokens.to(device),
        text_tokens_lens.to(device),
        audio_prompts,
        enroll_x_lens=enroll_x_lens,
        top_k=-100,
        temperature=1,
        prompt_language=lang_pr,
        text_language=langs if accent == "no-accent" else lang,
    )
    samples = codec.decode([(encoded_frames.transpose(2, 1), None)])
    return samples[0][0].cpu().numpy()
def generate_audio_from_long_text(text, prompt=None, language='auto', accent='no-accent', mode='sliding-window'):
    """Synthesize long text sentence by sentence and decode it as one clip.

    The text is split with ``split_text_into_sentences``. Each non-empty
    sentence is synthesized separately, and the codec frames of all sentences
    are concatenated and decoded together at the end. Reads the module-level
    ``model`` and ``codec``, so ``preload_models()`` must have been called.

    Two modes are available:

    * ``'fixed-prompt'``: every sentence is conditioned on the prompt the user
      provided.
    * ``'sliding-window'``: after each sentence, with probability 0.5 that
      sentence's generated frames and text tokens become the prompt for the
      next sentence; otherwise the original prompt is restored. This can drift
      away from the original speaker.

    When ``prompt`` is ``None`` or ``""``, the mode is forced to
    ``'sliding-window'``.

    Args:
        text: Long input text.
        prompt: Optional voice prompt. Either an existing file path, or a name
            resolved as ``./presets/<prompt>.npz`` and then
            ``./customs/<prompt>.npz``.
        language: Key into ``lang2token``. ``'auto'`` detects it once from the
            whole text with langid.
        accent: ``'no-accent'`` uses the tokenizer's per-token languages. Any
            other value is looked up in ``langdropdown2token``.
        mode: ``'fixed-prompt'`` or ``'sliding-window'`` (see above).

    Returns:
        ``samples[0][0]`` of the codec's decoded output, as a NumPy array.

    Raises:
        ValueError: If ``mode`` is unknown, or ``prompt`` cannot be found.
    """
    if prompt is None or prompt == "":
        mode = 'sliding-window'  # If no prompt is given, use sliding-window mode
    # Fail fast, before any prompt I/O or synthesis work is done.
    if mode not in ('fixed-prompt', 'sliding-window'):
        raise ValueError(f"No such mode {mode}")
    sentences = split_text_into_sentences(text)

    # detect language
    if language == "auto":
        language = langid.classify(text)[0]

    # if initial prompt is given, encode it
    if prompt is not None and prompt != "":
        prompt_path = prompt
        if not os.path.exists(prompt_path):
            prompt_path = "./presets/" + prompt + ".npz"
        if not os.path.exists(prompt_path):
            prompt_path = "./customs/" + prompt + ".npz"
        if not os.path.exists(prompt_path):
            raise ValueError(f"Cannot find prompt {prompt}")
        # For an .npz archive np.load returns an NpzFile that keeps the file
        # open; the with-block closes it once the arrays have been read.
        with np.load(prompt_path) as prompt_data:
            audio_prompts = prompt_data['audio_tokens']
            text_prompts = prompt_data['text_tokens']
            lang_pr = code2lang[int(prompt_data['lang_code'])]
        # numpy to tensor
        audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
        text_prompts = torch.tensor(text_prompts).type(torch.int32)
    else:
        audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
        text_prompts = torch.zeros([1, 0]).type(torch.int32)
        lang_pr = language if language != 'mix' else 'en'

    # Kept so sliding-window mode can fall back to the user's prompt.
    original_audio_prompts = audio_prompts
    original_text_prompts = text_prompts
    complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)

    for sentence in sentences:
        sentence = sentence.replace("\n", "").strip(" ")
        if sentence == "":
            continue
        lang_token = lang2token[language]
        lang = token2lang[lang_token]
        sentence = lang_token + sentence + lang_token

        # Number of leading prompt tokens in text_tokens; the prompt can change
        # between sentences in sliding-window mode, so recompute it each time.
        enroll_x_lens = text_prompts.shape[-1]
        logging.info("synthesize text: %s", sentence)
        phone_tokens, langs = text_tokenizer.tokenize(text=f"_{sentence}".strip())
        text_tokens, text_tokens_lens = text_collater([phone_tokens])
        text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
        text_tokens_lens += enroll_x_lens

        # accent control
        lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
        encoded_frames = model.inference(
            text_tokens.to(device),
            text_tokens_lens.to(device),
            audio_prompts,
            enroll_x_lens=enroll_x_lens,
            top_k=-100,
            temperature=1,
            prompt_language=lang_pr,
            text_language=langs if accent == "no-accent" else lang,
        )
        complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)

        if mode == 'sliding-window':
            if torch.rand(1) < 0.5:
                # Given the transpose above, encoded_frames appears to be
                # (batch, frames, quantizers); if so this slice keeps all
                # quantizers of the whole sentence (TODO confirm).
                audio_prompts = encoded_frames[:, :, -NUM_QUANTIZERS:]
                # Drop the old prompt's tokens, keeping only this sentence's.
                text_prompts = text_tokens[:, enroll_x_lens:]
            else:
                audio_prompts = original_audio_prompts
                text_prompts = original_text_prompts

    samples = codec.decode([(complete_tokens, None)])
    return samples[0][0].cpu().numpy()