# This demo is adapted from https://github.com/coqui-ai/TTS/blob/dev/TTS/demos/xtts_ft_demo/xtts_demo.py
# with some modifications to fit the viXTTS model.
import argparse
import hashlib
import logging
import os
import string
import subprocess
import sys
from datetime import datetime
from pprint import pprint

import gradio as gr
import torch
import torchaudio
from huggingface_hub import hf_hub_download, snapshot_download
from underthesea import sent_tokenize
from unidecode import unidecode
from vinorm import TTSnorm

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
XTTS_MODEL = None
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_DIR = os.path.join(SCRIPT_DIR, "model")
OUTPUT_DIR = os.path.join(SCRIPT_DIR, "output")
FILTER_SUFFIX = "_DeepFilterNet3.wav"

os.makedirs(OUTPUT_DIR, exist_ok=True)
def clear_gpu_cache():
    """Release cached GPU memory between runs."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def load_model(
    checkpoint_dir="model/",
    repo_id="tuandaodev/xtts-vi-vinai-100h-custom-dvae",
    use_deepspeed=False,
):
    global XTTS_MODEL
    clear_gpu_cache()
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Download the checkpoint (plus the XTTS-v2 speaker file) if anything is missing
    required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
    files_in_dir = os.listdir(checkpoint_dir)
    if not all(file in files_in_dir for file in required_files):
        yield f"Missing model files! Downloading from {repo_id}..."
        snapshot_download(
            repo_id=repo_id,
            repo_type="model",
            local_dir=checkpoint_dir,
        )
        hf_hub_download(
            repo_id="coqui/XTTS-v2",
            filename="speakers_xtts.pth",
            local_dir=checkpoint_dir,
        )
        yield "Model download finished..."

    xtts_config = os.path.join(checkpoint_dir, "config.json")
    config = XttsConfig()
    config.load_json(xtts_config)
    XTTS_MODEL = Xtts.init_from_config(config)
    yield "Loading model..."
    XTTS_MODEL.load_checkpoint(
        config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed
    )
    if torch.cuda.is_available():
        XTTS_MODEL.cuda()

    print("Model Loaded!")
    yield "Model Loaded!"
# Dictionaries to store cached results
cache_queue = []
speaker_audio_cache = {}
filter_cache = {}
conditioning_latents_cache = {}
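# Cache layout:
#   cache_queue                - FIFO of reference-audio paths, drives eviction order
#   speaker_audio_cache        - currently unused
#   filter_cache               - reference path -> denoised wav path (DeepFilterNet output)
#   conditioning_latents_cache - (reference path, conditioning params) -> (gpt_cond_latent, speaker_embedding)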
def invalidate_cache(cache_limit=50):
    """Invalidate the cache for the oldest key"""
    if len(cache_queue) > cache_limit:
        key_to_remove = cache_queue.pop(0)
        print("Invalidating cache", key_to_remove)
        if os.path.exists(key_to_remove):
            os.remove(key_to_remove)
        if os.path.exists(key_to_remove.replace(".wav", FILTER_SUFFIX)):
            os.remove(key_to_remove.replace(".wav", FILTER_SUFFIX))
        if key_to_remove in filter_cache:
            del filter_cache[key_to_remove]
        # Latent-cache keys are tuples whose first element is the audio path,
        # so match on that element rather than on the path itself.
        for cache_key in list(conditioning_latents_cache):
            if cache_key[0] == key_to_remove:
                del conditioning_latents_cache[cache_key]


def generate_hash(data):
    """Return the MD5 hex digest of a bytes-like object."""
    hash_object = hashlib.md5()
    hash_object.update(data)
    return hash_object.hexdigest()
def get_file_name(text, max_char=50):
    filename = text[:max_char]
    filename = filename.lower()
    filename = filename.replace(" ", "_")
    filename = filename.translate(
        str.maketrans("", "", string.punctuation.replace("_", ""))
    )
    filename = unidecode(filename)
    current_datetime = datetime.now().strftime("%m%d%H%M%S")
    filename = f"{current_datetime}_{filename}"
    return filename
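# For illustration: get_file_name("Xin chào các bạn") returns something like
# "0101120000_xin_chao_cac_ban" - a month/day/hour/minute/second timestamp
# followed by the lowercased, ASCII-folded, underscore-joined text.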
def normalize_vietnamese_text(text):
    text = text.encode("utf-8", "ignore").decode("utf-8")
    text = (
        TTSnorm(text, unknown=False, lower=False, rule=True)
        .replace("..", ".")
        .replace("!.", "!")
        .replace("?.", "?")
        .replace(" .", ".")
        .replace(" ,", ",")
        .replace('"', "")
        .replace("'", "")
    )
    return text
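# TTSnorm (from vinorm) expands numbers, dates, and abbreviations into
# Vietnamese words; for example "25" would typically become "hai mươi lăm"
# (assumed vinorm behavior). The .replace() chain then cleans up punctuation
# artifacts the normalizer can leave behind, such as doubled periods.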
def calculate_keep_len(text, lang):
    """Simple hack for short sentences"""
    if lang in ["ja", "zh-cn"]:
        return -1

    word_count = len(text.split())
    num_punct = text.count(".") + text.count("!") + text.count("?") + text.count(",")

    if word_count < 5:
        return 15000 * word_count + 2000 * num_punct
    elif word_count < 10:
        return 13000 * word_count + 2000 * num_punct
    return -1
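# XTTS generates audio at 24 kHz, so the caps above keep roughly 0.6 s of
# audio per word (15000 / 24000) plus ~0.08 s per punctuation mark -
# presumably to trim trailing artifacts the model tends to produce after
# very short prompts.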
def run_tts(lang, tts_text, speaker_audio_file, use_deepfilter, normalize_text):
    global filter_cache, conditioning_latents_cache, cache_queue

    if XTTS_MODEL is None:
        return "You need to run the previous step to load the model!", None
    if not speaker_audio_file:
        return "You need to provide reference audio!", None

    # Use the file name as the cache key, since it is supposed to be unique 💀
    speaker_audio_key = speaker_audio_file
    if speaker_audio_key not in cache_queue:
        cache_queue.append(speaker_audio_key)
        invalidate_cache()

    # Check if a denoised copy of the reference audio is cached
    if use_deepfilter and speaker_audio_key in filter_cache:
        print("Using filter cache...")
        speaker_audio_file = filter_cache[speaker_audio_key]
    elif use_deepfilter:
        print("Running filter...")
        subprocess.run(
            [
                "deepFilter",
                speaker_audio_file,
                "-o",
                os.path.dirname(speaker_audio_file),
            ]
        )
        filter_cache[speaker_audio_key] = speaker_audio_file.replace(
            ".wav", FILTER_SUFFIX
        )
        speaker_audio_file = filter_cache[speaker_audio_key]

    # Check if conditioning latents are cached
    cache_key = (
        speaker_audio_key,
        XTTS_MODEL.config.gpt_cond_len,
        XTTS_MODEL.config.max_ref_len,
        XTTS_MODEL.config.sound_norm_refs,
    )
    if cache_key in conditioning_latents_cache:
        print("Using conditioning latents cache...")
        gpt_cond_latent, speaker_embedding = conditioning_latents_cache[cache_key]
    else:
        print("Computing conditioning latents...")
        gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
            audio_path=speaker_audio_file,
            gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
            max_ref_length=XTTS_MODEL.config.max_ref_len,
            sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
        )
        conditioning_latents_cache[cache_key] = (gpt_cond_latent, speaker_embedding)

    if normalize_text and lang == "vi":
        tts_text = normalize_vietnamese_text(tts_text)

    # Split text by sentence
    if lang in ["ja", "zh-cn"]:
        sentences = tts_text.split("。")
    else:
        sentences = sent_tokenize(tts_text)
    pprint(sentences)

    wav_chunks = []
    for sentence in sentences:
        if sentence.strip() == "":
            continue
        wav_chunk = XTTS_MODEL.inference(
            text=sentence,
            language=lang,
            gpt_cond_latent=gpt_cond_latent,
            speaker_embedding=speaker_embedding,
            # The following values are carefully chosen for viXTTS
            temperature=0.3,
            length_penalty=1.0,
            repetition_penalty=10.0,
            top_k=30,
            top_p=0.85,
            enable_text_splitting=True,
        )

        keep_len = calculate_keep_len(sentence, lang)
        wav_chunk["wav"] = wav_chunk["wav"][:keep_len]
        wav_chunks.append(torch.tensor(wav_chunk["wav"]))

    out_wav = torch.cat(wav_chunks, dim=0).unsqueeze(0)
    gr_audio_id = os.path.basename(os.path.dirname(speaker_audio_file))
    out_path = os.path.join(OUTPUT_DIR, f"{get_file_name(tts_text)}_{gr_audio_id}.wav")
    print("Saving output to", out_path)
    torchaudio.save(out_path, out_wav, 24000)

    return "Speech generated!", out_path
# Logger that mirrors stdout/stderr into a file so the logs can be read back
class Logger:
    def __init__(self, filename="log.out"):
        self.log_file = filename
        self.terminal = sys.stdout
        self.log = open(self.log_file, "w")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        # Report non-interactive so libraries probing for a TTY treat this
        # stream as a plain file.
        return False
# Redirect stdout and stderr to a file
sys.stdout = Logger()
sys.stderr = sys.stdout

logging.basicConfig(
    level=logging.ERROR,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)


def read_logs():
    sys.stdout.flush()
    with open(sys.stdout.log_file, "r") as f:
        return f.read()
if __name__ == "__main__": | |
parser = argparse.ArgumentParser( | |
description="""XTTS VI inference demo\n\n""", | |
formatter_class=argparse.RawTextHelpFormatter, | |
) | |
parser.add_argument( | |
"--port", | |
type=int, | |
help="Port to run the gradio demo. Default: 5003", | |
default=5003, | |
) | |
parser.add_argument( | |
"--model_dir", | |
type=str, | |
help="Path to the checkpoint directory. This directory must contain 04 files: model.pth, config.json, vocab.json and speakers_xtts.pth", | |
default=None, | |
) | |
parser.add_argument( | |
"--reference_audio", | |
type=str, | |
help="Path to the reference audio file.", | |
default=None, | |
) | |
args = parser.parse_args() | |
if args.model_dir: | |
MODEL_DIR = os.path.abspath(args.model_dir) | |
REFERENCE_AUDIO = os.path.join(SCRIPT_DIR, "assets", "vi-man_kien-thuc-quan-su.wav") | |
if args.reference_audio: | |
REFERENCE_AUDIO = os.abspath(args.reference_audio) | |
    with gr.Blocks() as demo:
        intro = """
        # XTTS VI Inference Demo
        """
        gr.Markdown(intro)
        with gr.Row():
            with gr.Column() as col1:
                repo_id = gr.Textbox(
                    label="HuggingFace Repo ID",
                    value="tuandaodev/xtts-vi-vinai-100h-custom-dvae",
                )
                checkpoint_dir = gr.Textbox(
                    label="XTTS VI model directory",
                    value=MODEL_DIR,
                )
                use_deepspeed = gr.Checkbox(
                    value=True, label="Use DeepSpeed for faster inference"
                )
                progress_load = gr.Label(label="Progress:")
                load_btn = gr.Button(
                    value="Step 1 - Load XTTS VI model", variant="primary"
                )
            with gr.Column() as col2:
                speaker_reference_audio = gr.Audio(
                    label="Speaker reference audio:",
                    value=REFERENCE_AUDIO,
                    type="filepath",
                )
                tts_language = gr.Dropdown(
                    label="Language",
                    value="vi",
                    choices=[
                        "vi",
                        "en",
                        "es",
                        "fr",
                        "de",
                        "it",
                        "pt",
                        "pl",
                        "tr",
                        "ru",
                        "nl",
                        "cs",
                        "ar",
                        "zh-cn",
                        "hu",
                        "ko",
                        "ja",
                    ],
                )
                use_filter = gr.Checkbox(
                    label="Denoise Reference Audio",
                    value=True,
                )
                normalize_text = gr.Checkbox(
                    label="Normalize Input Text",
                    value=True,
                )
                # Default text: "Hello, I am a Vietnamese text-to-speech tool,
                # trained for a speech-processing course."
                tts_text = gr.Textbox(
                    label="Input Text",
                    value="Xin chào, tôi là một công cụ chuyển đổi văn bản thành giọng nói tiếng Việt, được huấn luyện trong môn học xử lý giọng nói.",
                )
                tts_btn = gr.Button(value="Step 2 - Inference", variant="primary")
            with gr.Column() as col3:
                progress_gen = gr.Label(label="Progress:")
                tts_output_audio = gr.Audio(label="Generated Audio")

        load_btn.click(
            fn=load_model,
            inputs=[checkpoint_dir, repo_id, use_deepspeed],
            outputs=[progress_load],
        )

        tts_btn.click(
            fn=run_tts,
            inputs=[
                tts_language,
                tts_text,
                speaker_reference_audio,
                use_filter,
                normalize_text,
            ],
            outputs=[progress_gen, tts_output_audio],
        )

    demo.launch(share=True, debug=False, server_port=args.port, server_name="0.0.0.0")