# VALL-E-X / app.py
# Last change by soiz1: "Update app.py" (commit 32ed949, verified)
import logging
import os
import pathlib
import time
import tempfile
import platform
import gc
def get_available_npz_files(npz_directory='./presets/'):
    """Return the names of the ``.npz`` prompt files stored in *npz_directory*.

    Args:
        npz_directory: Directory to scan (non-recursively). Defaults to
            ``./presets/``, the directory the inference functions load
            preset prompts from.

    Returns:
        A sorted list of file names, including the ``.npz`` extension and
        without the directory part. Only regular files are listed. If the
        directory does not exist or cannot be read, an empty list is returned
        so that building the UI dropdown does not crash.
    """
    try:
        entries = os.listdir(npz_directory)
    except OSError:
        return []
    # Sorted so the dropdown order is stable across platforms/filesystems.
    return sorted(
        entry
        for entry in entries
        if entry.endswith('.npz')
        and os.path.isfile(os.path.join(npz_directory, entry))
    )
# Alias the "foreign" pathlib class to the native one for this OS, so code that
# later unpickles path objects created on another OS gets a usable class.
# NOTE(review): presumably this exists so torch.load can read checkpoints that
# were pickled on a different OS — confirm. The displaced class is kept in
# ``temp``.
_os_name = platform.system().lower()
if _os_name == 'windows':
    temp, pathlib.PosixPath = pathlib.PosixPath, pathlib.WindowsPath
elif _os_name == 'linux':
    temp, pathlib.WindowsPath = pathlib.WindowsPath, pathlib.PosixPath
# Select protobuf's pure-Python implementation. This is set before the heavy
# third-party imports below; NOTE(review): presumably because protobuf reads
# the variable only once, when it is first imported — keep it above them.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
import langid
# Restrict langid to English, Chinese and Japanese, so every
# langid.classify(...) call in this file returns one of 'en', 'zh', 'ja'.
langid.set_languages(['en', 'zh', 'ja'])
import torch
import torchaudio
import numpy as np
from data.tokenizer import (
AudioTokenizer,
tokenize_audio,
)
from data.collation import get_text_token_collater
from models.vallex import VALLE
from utils.g2p import PhonemeBpeTokenizer
from descriptions import *
from macros import *
from examples import *
import gradio as gr
from vocos import Vocos
from transformers import WhisperProcessor, WhisperForConditionalGeneration
# Globally disable TorchScript's profiling executor, profiling mode and
# graph-executor optimizations. These are private torch._C switches.
# NOTE(review): presumably to avoid warm-up / re-specialization cost for
# scripted code run on inputs of varying length — confirm the reason.
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)
# Text front end: phoneme BPE tokenizer plus the collater that turns token
# lists into padded id tensors (and their lengths).
text_tokenizer = PhonemeBpeTokenizer(tokenizer_path="./utils/g2p/bpe_69.json")
text_collater = get_text_token_collater()

# Run everything on the first CUDA device when available, otherwise on the CPU.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
# Build the VALL-E model from the hyper-parameters star-imported above
# (N_DIM, NUM_HEAD, NUM_LAYERS, PREFIX_MODE, NUM_QUANTIZERS — presumably from
# macros) and move it to the selected device.
model = VALLE(
    N_DIM,
    NUM_HEAD,
    NUM_LAYERS,
    norm_first=True,
    add_prenet=False,
    prefix_mode=PREFIX_MODE,
    share_embedding=True,
    nar_scale_factor=1.0,
    prepend_bos=True,
    num_quantizers=NUM_QUANTIZERS,
).to(device)
# Deserialize the checkpoint onto the CPU first; only its "model" state dict is
# copied into the (possibly GPU-resident) model.
checkpoint = torch.load("./epoch-10.pt", map_location='cpu')
# With strict=True, load_state_dict already raises on missing or unexpected
# keys, so the assert below is only a secondary check.
missing_keys, unexpected_keys = model.load_state_dict(
    checkpoint["model"], strict=True
)
# Drop the CPU copy of the weights once they are loaded into the model.
del checkpoint
assert not missing_keys
# Inference only: put the model in eval mode.
model.eval()
# Codec tokenizer that turns reference waveforms into discrete audio tokens
# (used via tokenize_audio below).
audio_tokenizer = AudioTokenizer(device)
# Vocos vocoder for EnCodec codes; the synthesis functions decode its output
# as 24 kHz audio.
vocos = Vocos.from_pretrained('charactr/vocos-encodec-24khz').to(device)
# Whisper (medium) transcribes reference clips and detects their language when
# the user supplies no transcript (see transcribe_one).
whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
whisper = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium").to(device)
# Do not force a decoder prompt. transcribe_one reads the language token that
# generate() produces at position 1. NOTE(review): this presumably relies on
# Whisper predicting the language itself once nothing is forced — confirm.
whisper.config.forced_decoder_ids = None
# Preset prompt names (file names without the ".npz" extension) found at the
# top level of ./presets/; sub-directories are not scanned.
preset_list = [
    entry[: -len(".npz")]
    for entry in next(os.walk("./presets/"))[2]
    if entry.endswith(".npz")
]
def clear_prompts(directory=None, max_age=60):
    """Best-effort removal of stale ``.npz`` prompt files.

    make_npz_prompt writes generated prompts into ``tempfile.gettempdir()``.
    This sweep deletes every regular ``*.npz`` file in *directory* whose
    modification time is more than *max_age* seconds in the past. Note that it
    does not distinguish this app's files from other ``.npz`` files in the
    same directory.

    Args:
        directory: Directory to clean. Defaults to ``tempfile.gettempdir()``.
        max_age: Minimum age in seconds (by mtime) before a file is removed.

    Never raises for filesystem errors: an unreadable directory aborts the
    sweep silently, and a file that cannot be stat'ed or removed (e.g. one
    already deleted by a concurrent request) is skipped.
    """
    path = tempfile.gettempdir() if directory is None else directory
    cutoff = time.time() - max_age
    try:
        entries = os.listdir(path)
    except OSError:
        return
    for entry in entries:
        filename = os.path.join(path, entry)
        if not filename.endswith(".npz"):
            continue
        try:
            if os.path.isfile(filename) and os.stat(filename).st_mtime < cutoff:
                os.remove(filename)
        except OSError:
            # Keep sweeping: one vanished/locked file must not stop the rest.
            continue
    gc.collect()
def transcribe_one(wav, sr):
    """Transcribe a mono waveform with Whisper and detect its language.

    Args:
        wav: Waveform tensor; the callers in this file pass shape
            ``(1, num_samples)``.
        sr: Sample rate of *wav* in Hz. The audio is resampled to 16 kHz,
            which is the rate passed to the Whisper processor.

    Returns:
        ``(lang, text)``: the language code decoded from the token at
        position 1 of Whisper's output (with its ``<|``/``|>`` wrapper
        stripped), and the transcription. A trailing ``"."`` is appended
        unless the text already ends with one of the listed punctuation marks
        (or is empty), so the prompt text reads as a complete sentence.
    """
    wav4trans = wav if sr == 16000 else torchaudio.transforms.Resample(sr, 16000)(wav)
    input_features = whisper_processor(
        wav4trans.squeeze(0), sampling_rate=16000, return_tensors="pt"
    ).input_features
    predicted_ids = whisper.generate(input_features.to(device))
    lang = whisper_processor.batch_decode(predicted_ids[:, 1])[0].strip("<|>")
    text_pr = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    print(text_pr)
    stripped = text_pr.strip(" ")
    # An empty transcription (e.g. silence) used to raise IndexError on [-1].
    if not stripped or stripped[-1] not in "?!.,。,?!。、":
        text_pr += "."
    del wav4trans, input_features, predicted_ids
    gc.collect()
    return lang, text_pr
def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
    """Create a reusable ``.npz`` voice prompt from a short reference clip.

    Args:
        name: Base name of the output file. Only its final path component is
            used, so the file is always written to the temp directory.
        uploaded_audio: ``(sample_rate, samples)`` tuple (as unpacked below),
            or ``None``.
        recorded_audio: Same format; used when *uploaded_audio* is ``None``.
        transcript_content: Optional transcript of the clip. When empty, the
            clip is transcribed and its language detected with Whisper.

    Returns:
        ``(message, path)`` with the saved ``.npz`` path. On rejection it
        returns ``(error_message, None)``: no audio, a clip longer than 15 s,
        or an unsupported language.
        The file stores ``audio_tokens``, ``text_tokens`` and ``lang_code``.
    """
    clear_prompts()
    audio_prompt = uploaded_audio if uploaded_audio is not None else recorded_audio
    if audio_prompt is None:
        return "Rejected, no reference audio was provided", None
    sr, wav_pr = audio_prompt
    if len(wav_pr) / sr > 15:
        return "Rejected, Audio too long (should be less than 15 seconds)", None
    if not isinstance(wav_pr, torch.FloatTensor):
        wav_pr = torch.FloatTensor(wav_pr)
    # Rescale so that the peak magnitude is at most 1.
    if wav_pr.abs().max() > 1:
        wav_pr /= wav_pr.abs().max()
    # A trailing dimension of 2 is treated as stereo; keep the first channel.
    if wav_pr.size(-1) == 2:
        wav_pr = wav_pr[:, 0]
    if wav_pr.ndim == 1:
        wav_pr = wav_pr.unsqueeze(0)
    assert wav_pr.ndim and wav_pr.size(0) == 1

    if transcript_content:
        # Use the user's transcript (it used to be discarded and replaced by
        # a Whisper transcription).
        lang_pr = langid.classify(str(transcript_content))[0]
        text_pr = str(transcript_content).replace("\n", "")
    else:
        # This result was previously bound to a misspelled name, leaving
        # lang_pr unbound and crashing every call without a transcript.
        lang_pr, text_pr = transcribe_one(wav_pr, sr)
        lang_pr = lang_pr if lang_pr else 'ja'
    # Validate before the lang2token lookup so an unsupported language gives a
    # message instead of a KeyError.
    if lang_pr not in ['ja', 'zh', 'en']:
        return f"Prompt can only made with one of model-supported languages, got {lang_pr} instead", None
    lang_token = lang2token[lang_pr]
    text_pr = lang_token + text_pr + lang_token

    encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
    audio_tokens = encoded_frames[0][0].transpose(2, 1).cpu().numpy()
    phonemes, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
    text_tokens, enroll_x_lens = text_collater([phonemes])
    message = f"Detected language: {lang_pr}\n Detected text {text_pr}\n"

    # *name* comes from the UI: strip any directory part so it cannot write
    # outside the temp directory.
    safe_name = os.path.basename(str(name)) or "prompt"
    out_path = os.path.join(tempfile.gettempdir(), f"{safe_name}.npz")
    np.savez(out_path,
             audio_tokens=audio_tokens, text_tokens=text_tokens, lang_code=lang2code[lang_pr])
    del audio_tokens, text_tokens, phonemes, encoded_frames, wav_pr
    gc.collect()
    return message, out_path
@torch.no_grad()
def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt, transcript_content):
    """Synthesize *text*, optionally cloning the voice of a reference clip.

    Args:
        text: Text to synthesize (at most 150 characters).
        language: ``'auto-detect'`` or a key of ``langdropdown2token``.
        accent: ``'no-accent'`` or a key of ``langdropdown2token``. Any other
            value overrides the language passed to the model for *text*.
        audio_prompt: ``(sample_rate, samples)`` tuple for an uploaded clip,
            or ``None``.
        record_audio_prompt: Same format; used when *audio_prompt* is
            ``None``. If both are ``None``, the model runs with empty prompts.
        transcript_content: Optional transcript of the reference clip; when
            empty the clip is transcribed with Whisper.

    Returns:
        ``(message, (24000, samples))`` on success, or ``(error_message, None)``.
    """
    if len(text) > 150:
        return "Rejected, Text too long (should be less than 150 characters)", None
    if audio_prompt is None and record_audio_prompt is None:
        # No reference audio: empty audio/text prompts. This branch used to
        # crash in tokenize_audio(…, (None, None)) and on "+= None".
        audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
        text_prompts = torch.zeros([1, 0]).type(torch.int32)
        lang_pr = 'en'
        text_pr = ""
        enroll_x_lens = 0
    else:
        audio_prompt = audio_prompt if audio_prompt is not None else record_audio_prompt
        sr, wav_pr = audio_prompt
        if len(wav_pr) / sr > 15:
            return "Rejected, Audio too long (should be less than 15 seconds)", None
        if not isinstance(wav_pr, torch.FloatTensor):
            wav_pr = torch.FloatTensor(wav_pr)
        # Rescale so that the peak magnitude is at most 1.
        if wav_pr.abs().max() > 1:
            wav_pr /= wav_pr.abs().max()
        # A trailing dimension of 2 is treated as stereo; keep the first channel.
        if wav_pr.size(-1) == 2:
            wav_pr = wav_pr[:, 0]
        if wav_pr.ndim == 1:
            wav_pr = wav_pr.unsqueeze(0)
        assert wav_pr.ndim and wav_pr.size(0) == 1
        if not transcript_content:
            lang_pr, text_pr = transcribe_one(wav_pr, sr)
        else:
            lang_pr = langid.classify(str(transcript_content))[0]
            text_pr = str(transcript_content).replace("\n", "")
        # Check both paths (Whisper can report languages lang2token lacks).
        if lang_pr not in ['ja', 'zh', 'en']:
            return f"Reference audio must be a speech of one of model-supported languages, got {lang_pr} instead", None
        lang_token = lang2token[lang_pr]
        text_pr = lang_token + text_pr + lang_token
        encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
        audio_prompts = encoded_frames[0][0].transpose(2, 1).to(device)
        text_prompts, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
        text_prompts, enroll_x_lens = text_collater([text_prompts])

    if language == 'auto-detect':
        lang_token = lang2token[langid.classify(text)[0]]
    else:
        lang_token = langdropdown2token[language]
    lang = token2lang[lang_token]
    text = text.replace("\n", "")
    text = lang_token + text + lang_token
    logging.info(f"synthesize text: {text}")
    phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
    text_tokens, text_tokens_lens = text_collater([phone_tokens])
    # The model is fed "prompt transcript + target text"; enroll_x_lens marks
    # where the prompt part ends.
    text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
    text_tokens_lens += enroll_x_lens
    lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
    encoded_frames = model.inference(
        text_tokens.to(device),
        text_tokens_lens.to(device),
        audio_prompts,
        enroll_x_lens=enroll_x_lens,
        top_k=-100,
        temperature=1,
        prompt_language=lang_pr,
        text_language=langs if accent == "no-accent" else lang,
    )
    frames = encoded_frames.permute(2, 0, 1)
    features = vocos.codes_to_features(frames)
    samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
    message = f"text prompt: {text_pr}\nsythesized text: {text}"
    del audio_prompts, text_tokens, text_prompts, phone_tokens, encoded_frames
    gc.collect()
    return message, (24000, samples.squeeze(0).cpu().numpy())
@torch.no_grad()
def infer_from_prompt(text, language, accent, preset_prompt, prompt_file):
    """Synthesize *text* with the voice stored in a ``.npz`` prompt.

    Args:
        text: Text to synthesize (at most 150 characters).
        language: ``'auto-detect'`` or a key of ``langdropdown2token``.
        accent: ``'no-accent'`` or a key of ``langdropdown2token``.
        preset_prompt: Name of a preset in ``./presets/``, with or without the
            ``.npz`` extension. Used when *prompt_file* is empty.
        prompt_file: Prompt to use. This may be an object with a ``.name``
            attribute (an uploaded file) or a path string. A bare file name
            (no directory), such as those returned by
            ``get_available_npz_files()``, is looked up in ``./presets/``.

    Returns:
        ``(message, (24000, samples))`` on success, or ``(error_message, None)``.
    """
    if len(text) > 150:
        return "Rejected, Text too long (should be less than 150 characters)", None
    clear_prompts()
    if prompt_file is not None and prompt_file != "":
        # Plain strings have no ".name"; the UI dropdown passes one, which
        # used to raise AttributeError here.
        prompt_path = str(getattr(prompt_file, "name", prompt_file))
        if not os.path.dirname(prompt_path):
            prompt_path = os.path.join("./presets/", prompt_path)
    elif preset_prompt:
        # Keep only the final component so a typed name cannot escape ./presets/.
        preset_name = os.path.basename(str(preset_prompt))
        if not preset_name.endswith(".npz"):
            preset_name += ".npz"
        prompt_path = os.path.join("./presets/", preset_name)
    else:
        return "Rejected, no prompt was selected", None

    if language == 'auto-detect':
        lang_token = lang2token[langid.classify(text)[0]]
    else:
        lang_token = langdropdown2token[language]
    lang = token2lang[lang_token]
    text = text.replace("\n", "")
    text = lang_token + text + lang_token

    # NpzFile keeps the archive open until closed; the context manager releases it.
    with np.load(prompt_path) as prompt_data:
        audio_prompts = prompt_data['audio_tokens']
        text_prompts = prompt_data['text_tokens']
        lang_pr = code2lang[int(prompt_data['lang_code'])]
    audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
    text_prompts = torch.tensor(text_prompts).type(torch.int32)
    enroll_x_lens = text_prompts.shape[-1]

    logging.info(f"synthesize text: {text}")
    phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
    text_tokens, text_tokens_lens = text_collater([phone_tokens])
    # Feed "prompt transcript + target text"; enroll_x_lens is the prompt length.
    text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
    text_tokens_lens += enroll_x_lens
    lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
    encoded_frames = model.inference(
        text_tokens.to(device),
        text_tokens_lens.to(device),
        audio_prompts,
        enroll_x_lens=enroll_x_lens,
        top_k=-100,
        temperature=1,
        prompt_language=lang_pr,
        text_language=langs if accent == "no-accent" else lang,
    )
    frames = encoded_frames.permute(2, 0, 1)
    features = vocos.codes_to_features(frames)
    samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
    message = f"sythesized text: {text}"
    del audio_prompts, text_tokens, text_prompts, phone_tokens, encoded_frames
    gc.collect()
    return message, (24000, samples.squeeze(0).cpu().numpy())
from utils.sentence_cutter import split_text_into_sentences
def _load_long_text_prompt(prompt, preset_prompt, language):
    """Return ``(audio_prompts, text_prompts, lang_pr)`` for infer_long_text.

    An uploaded *prompt* (object with ``.name``, or a path string) takes
    precedence over *preset_prompt* (a file name without extension in
    ``./presets/``). When neither is given, empty prompts are returned and the
    prompt language falls back to *language* (``'en'`` for ``'mix'``).
    """
    if prompt is not None and prompt != "":
        prompt_path = str(getattr(prompt, "name", prompt))
    elif preset_prompt is not None and preset_prompt != "":
        prompt_path = os.path.join("./presets/", f"{preset_prompt}.npz")
    else:
        audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
        text_prompts = torch.zeros([1, 0]).type(torch.int32)
        return audio_prompts, text_prompts, (language if language != 'mix' else 'en')
    # Close the .npz archive once its arrays have been read.
    with np.load(prompt_path) as prompt_data:
        audio_prompts = prompt_data['audio_tokens']
        text_prompts = prompt_data['text_tokens']
        lang_pr = code2lang[int(prompt_data['lang_code'])]
    audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
    text_prompts = torch.tensor(text_prompts).type(torch.int32)
    return audio_prompts, text_prompts, lang_pr


def _synthesize_long_text_sentence(sentence, text_prompts, audio_prompts, language, accent, lang_pr):
    """Generate codec frames for one sentence of infer_long_text.

    Returns ``(encoded_frames, text_tokens, enroll_x_lens)``. These are the
    model output, the prompt+sentence token tensor fed to the model, and the
    length of its prompt part. Returns ``None`` when *sentence* is blank.
    """
    sentence = sentence.replace("\n", "").strip(" ")
    if sentence == "":
        return None
    lang_token = lang2token[language]
    lang = token2lang[lang_token]
    sentence = lang_token + sentence + lang_token
    enroll_x_lens = text_prompts.shape[-1]
    logging.info(f"synthesize text: {sentence}")
    phone_tokens, langs = text_tokenizer.tokenize(text=f"_{sentence}".strip())
    text_tokens, text_tokens_lens = text_collater([phone_tokens])
    text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
    text_tokens_lens += enroll_x_lens
    lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
    encoded_frames = model.inference(
        text_tokens.to(device),
        text_tokens_lens.to(device),
        audio_prompts,
        enroll_x_lens=enroll_x_lens,
        top_k=-100,
        temperature=1,
        prompt_language=lang_pr,
        text_language=langs if accent == "no-accent" else lang,
    )
    return encoded_frames, text_tokens, enroll_x_lens


@torch.no_grad()
def infer_long_text(text, preset_prompt, prompt=None, language='auto', accent='no-accent'):
    """
    For long audio generation, two modes are available.
    fixed-prompt: This mode will keep using the same prompt the user has provided, and generate audio sentence by sentence.
    sliding-window: This mode will use the last sentence as the prompt for the next sentence, but has some concern on speaker maintenance.

    The mode is fixed-prompt when *prompt* or *preset_prompt* is given, and
    sliding-window otherwise.

    Args:
        text: Text to synthesize (at most 1000 characters). It is split into
            sentences with split_text_into_sentences.
        preset_prompt: Name (without extension) of a preset in ``./presets/``.
        prompt: Uploaded prompt (object with ``.name``, or a path string);
            takes precedence over *preset_prompt*.
        language: ``'auto'``/``'auto-detect'`` or a key of ``langdropdown2token``.
        accent: ``'no-accent'`` or a key of ``langdropdown2token``.

    Returns:
        ``(message, (24000, samples))`` on success, or ``(error_message, None)``.
    """
    if len(text) > 1000:
        return "Rejected, Text too long (should be less than 1000 characters)", None
    has_prompt = (prompt is not None and prompt != "") or bool(preset_prompt)
    mode = 'fixed-prompt' if has_prompt else 'sliding-window'
    sentences = split_text_into_sentences(text)
    # The default 'auto' was never recognised (only 'auto-detect' was), so
    # calling with the default language used to fail the dropdown lookup.
    if language in ("auto-detect", "auto"):
        language = langid.classify(text)[0]
    else:
        language = token2lang[langdropdown2token[language]]
    audio_prompts, text_prompts, lang_pr = _load_long_text_prompt(prompt, preset_prompt, language)

    complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
    for sentence in sentences:
        result = _synthesize_long_text_sentence(
            sentence, text_prompts, audio_prompts, language, accent, lang_pr
        )
        if result is None:
            continue
        encoded_frames, text_tokens, enroll_x_lens = result
        complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)
        if mode == 'sliding-window':
            # Condition the next sentence on the audio/text just generated.
            audio_prompts = encoded_frames[:, :, -NUM_QUANTIZERS:]
            text_prompts = text_tokens[:, enroll_x_lens:]
    if complete_tokens.shape[-1] == 0:
        # Every sentence was blank: nothing to vocode.
        return "Rejected, no text to synthesize", None
    frames = complete_tokens.permute(1, 0, 2)
    features = vocos.codes_to_features(frames)
    samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
    message = f"Cut into {len(sentences)} sentences"
    return message, (24000, samples.squeeze(0).cpu().numpy())
def _generate_from_preset(text, language, accent, preset_prompt, selected_npz):
    """Forward the inputs of the "NPZファイルで生成" tab to infer_from_prompt.

    The dropdown yields a bare file name from ./presets/ (e.g. "voice.npz"),
    not an uploaded-file object. It is therefore passed as the preset name,
    without the extension, instead of as ``prompt_file``. A dropdown selection
    takes precedence over the typed prompt name.
    """
    if selected_npz:
        preset_prompt = selected_npz[:-len(".npz")] if selected_npz.endswith(".npz") else selected_npz
    return infer_from_prompt(text, language, accent, preset_prompt, None)


with gr.Blocks() as app:
    with gr.Tabs():
        with gr.Tab("NPZファイルを作成"):
            name = gr.Textbox(label="ファイル名")
            uploaded_audio = gr.Audio(label="アップロード音声", type="numpy")
            transcript = gr.Textbox(label="文字起こし内容")
            result_message = gr.Textbox(label="結果", interactive=False)
            npz_output = gr.File(label=".npz ファイル")
            save_button = gr.Button("変換して保存")
            # make_npz_prompt expects (name, uploaded_audio, recorded_audio,
            # transcript_content). This tab has no recorder, so recorded_audio
            # is None. Wiring the three components directly used to send the
            # transcript into recorded_audio and leave an argument missing.
            save_button.click(
                lambda file_name, audio, transcript_text: make_npz_prompt(
                    file_name, audio, None, transcript_text
                ),
                [name, uploaded_audio, transcript],
                [result_message, npz_output],
            )
        with gr.Tab("NPZファイルで生成"):
            npz_files_dropdown = gr.Dropdown(label="利用可能な .npz ファイル", choices=get_available_npz_files(), interactive=True)
            text_input = gr.Textbox(label="生成するテキスト")
            # The handlers index langdropdown2token with these values, so offer
            # exactly its keys (plus the two sentinel values the handlers check).
            language = gr.Radio(label="言語", choices=["auto-detect"] + list(langdropdown2token), value="auto-detect")
            accent = gr.Radio(label="アクセント", choices=["no-accent"] + list(langdropdown2token), value="no-accent")
            preset_prompt = gr.Textbox(label="プロンプト名")
            synthesis_message = gr.Textbox(label="結果", interactive=False)
            audio_output = gr.Audio(label="生成音声", type="numpy")
            generate_button = gr.Button("生成開始")
            generate_button.click(_generate_from_preset, [text_input, language, accent, preset_prompt, npz_files_dropdown], [synthesis_message, audio_output])

if __name__ == "__main__":
    app.launch()