# Hugging Face Spaces status (page residue from scraping): Running on A10G
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import os | |
import re | |
import tempfile | |
import torch | |
import sys | |
import gradio as gr | |
import numpy as np | |
from huggingface_hub import hf_hub_download | |
# Setup TTS env: append the local "vits" directory to the module search path,
# at most once (re-running this code does not add a duplicate entry).
# NOTE(review): presumably needed so modules inside vits/ can import their
# siblings by bare name — confirm.
if sys.path.count("vits") == 0:
    sys.path.append("vits")
from vits import commons, utils | |
from vits.models import SynthesizerTrn | |
# ISO code -> human-readable language name, read from data/tts/all_langs.tsv.
# Despite the .tsv extension, each entry is split on its FIRST space
# ("<iso> <name>"); maxsplit=1 keeps names that contain spaces intact.
TTS_LANGUAGES = {}
with open("data/tts/all_langs.tsv", encoding="utf-8") as f:
    for line in f:
        # Drop the line ending (and stray surrounding whitespace) so names do
        # not carry a trailing "\n" into the dict values.
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. an empty last line) instead of
            # failing the two-way unpack below.
            continue
        iso, name = line.split(" ", 1)
        TTS_LANGUAGES[iso] = name
class TextMapper(object):
    """Map text to the integer symbol IDs of a TTS vocabulary file.

    The vocabulary file lists one symbol per line. A symbol's ID is its
    0-based line index. The vocabulary must contain a " " (space) symbol.
    """

    def __init__(self, vocab_file):
        """Load the symbol table from ``vocab_file`` (UTF-8, one symbol per line).

        Raises:
            ValueError: if the vocabulary has no " " (space) symbol.
        """
        # A context manager closes the handle promptly; the previous
        # open(...).readlines() left it open until garbage collection.
        with open(vocab_file, encoding="utf-8") as f:
            # Remove only the newline: " " itself is a valid symbol.
            self.symbols = [line.replace("\n", "") for line in f]
        self.SPACE_ID = self.symbols.index(" ")
        self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
        self._id_to_symbol = dict(enumerate(self.symbols))

    def text_to_sequence(self, text, cleaner_names):
        """Convert a string of text to the list of its symbol IDs.

        Args:
            text: string to convert; leading/trailing whitespace is stripped.
            cleaner_names: accepted for interface compatibility; not used.

        Returns:
            List of integers, one per character of the stripped text.

        Raises:
            KeyError: if a character is not in the vocabulary (run
                ``filter_oov`` first to avoid this).
        """
        return [self._symbol_to_id[symbol] for symbol in text.strip()]

    def uromanize(self, text, uroman_pl, iso="xxx"):
        """Romanize ``text`` by piping it through the uroman Perl script.

        Args:
            text: input text in any script.
            uroman_pl: path to ``uroman.pl``.
            iso: value passed to uroman's ``-l`` option. Defaults to "xxx",
                the value this method previously hard-coded.

        Returns:
            The romanized text with every whitespace run (including line
            breaks) collapsed to one space and the ends stripped. Multi-line
            input is kept in full rather than truncated to its first line.

        Raises:
            subprocess.CalledProcessError: if perl exits with non-zero status.
        """
        import subprocess

        # An argument list (no shell) keeps paths with spaces or shell
        # metacharacters safe. Text is exchanged over stdin/stdout as UTF-8
        # rather than through temp files in the platform default encoding.
        # stderr is not captured, so uroman's diagnostics still reach the console.
        result = subprocess.run(
            ["perl", uroman_pl, "-l", iso],
            input=text,
            stdout=subprocess.PIPE,
            encoding="utf-8",
            check=True,
        )
        return re.sub(r"\s+", " ", result.stdout).strip()

    def get_text(self, text, hps):
        """Encode ``text`` as a 1-D ``torch.LongTensor`` of symbol IDs.

        ``hps`` is the hyperparameter object (as loaded via
        ``utils.get_hparams_from_file``). Only ``hps.data.text_cleaners`` and
        ``hps.data.add_blank`` are read. When ``add_blank`` is set, the
        sequence goes through ``commons.intersperse(seq, 0)``; presumably
        this inserts ID 0 between symbols (confirm in vits/commons.py).
        """
        text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
        if hps.data.add_blank:
            text_norm = commons.intersperse(text_norm, 0)
        return torch.LongTensor(text_norm)

    def filter_oov(self, text, lang=None):
        """Return ``text`` with every character not in the vocabulary removed.

        ``text`` first goes through ``preprocess_char`` with ``lang``.
        """
        text = self.preprocess_char(text, lang=lang)
        return "".join(ch for ch in text if ch in self._symbol_to_id)

    def preprocess_char(self, text, lang=None):
        """Apply language-specific character substitutions to ``text``.

        ``lang`` may be a bare ISO code ("ron") or a label whose first
        whitespace-separated token is the code ("ron (Romanian)"). Callers
        have passed both forms.

        Romanian: "\u021b" (U+021B, t with comma below) becomes "\u0163"
        (U+0163, t with cedilla).
        """
        tokens = lang.split() if isinstance(lang, str) else []
        lang_code = tokens[0] if tokens else lang
        if lang_code == "ron":
            text = text.replace("\u021b", "\u0163")
            print(f"{lang} (\u021b -> \u0163): {text}")
        return text
def _select_device():
    """Return the preferred torch device: CUDA, else Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if (
        hasattr(torch.backends, "mps")
        and torch.backends.mps.is_available()
        and torch.backends.mps.is_built()
    ):
        return torch.device("mps")
    return torch.device("cpu")


def _float_to_int16(wave):
    """Convert a float waveform (nominally in [-1, 1]) to int16 PCM.

    Samples are scaled by 32768 as before, then clipped to the int16 range.
    Without the clip, a +1.0 sample (32768) or anything beyond +/-1.0 is out
    of int16 range, and the cast wraps it to a large value of the opposite
    sign (an audible click).
    """
    return np.clip(wave * 32768, -32768, 32767).astype(np.int16)


def synthesize(text, lang, speed=None):
    """Synthesize speech for ``text`` with the facebook/mms-tts model for ``lang``.

    Args:
        text: input text.
        lang: language label whose first whitespace-separated token is the
            model code, e.g. "eng (English)" or
            "azj-script_latin (Azerbaijani, North)".
        speed: speaking-rate multiplier (> 0, default 1.0). It is passed to
            the model as ``length_scale = 1 / speed``.

    Returns:
        ``((sampling_rate, waveform), processed_text)``. ``waveform`` is an
        int16 numpy array. ``processed_text`` is the text actually fed to
        the model, after optional romanization, lowercasing and
        out-of-vocabulary filtering.

    Raises:
        ValueError: if ``speed`` is not positive, or if no synthesizable
            characters remain after filtering.
        FileNotFoundError: if the config file or the local "uroman"
            directory is missing.
    """
    if speed is None:
        speed = 1.0
    if speed <= 0:
        raise ValueError(f"speed must be positive, got {speed}")
    lang_code = lang.split()[0]

    def _fetch(filename):
        # Downloads (or reuses the local hub cache for) one model file.
        return hf_hub_download(
            repo_id="facebook/mms-tts",
            filename=filename,
            subfolder=f"models/{lang_code}",
        )

    vocab_file = _fetch("vocab.txt")
    config_file = _fetch("config.json")
    g_pth = _fetch("G_100000.pth")

    device = _select_device()
    print(f"Run inference with {device}")

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not os.path.isfile(config_file):
        raise FileNotFoundError(f"{config_file} doesn't exist")
    hps = utils.get_hparams_from_file(config_file)
    text_mapper = TextMapper(vocab_file)
    # NOTE(review): the model is rebuilt and its checkpoint reloaded on every
    # call; consider a small bounded cache keyed on (lang_code, device).
    net_g = SynthesizerTrn(
        len(text_mapper.symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model,
    )
    net_g.to(device)
    net_g.eval()
    utils.load_checkpoint(g_pth, net_g, None)

    # Models whose training file list ends in ".uroman" expect romanized input.
    is_uroman = hps.data.training_files.split(".")[-1] == "uroman"
    if is_uroman:
        uroman_dir = "uroman"
        if not os.path.exists(uroman_dir):
            raise FileNotFoundError(f"uroman directory not found: {uroman_dir}")
        uroman_pl = os.path.join(uroman_dir, "bin", "uroman.pl")
        text = text_mapper.uromanize(text, uroman_pl)
    text = text.lower()
    # Pass the bare code: `lang` is the full label (e.g. "ron (Romanian)"),
    # which TextMapper.preprocess_char's `lang == "ron"` check would not match.
    text = text_mapper.filter_oov(text, lang=lang_code)
    if not text.strip():
        raise ValueError(
            "No characters supported by this language's vocabulary remain "
            "in the input text"
        )

    stn_tst = text_mapper.get_text(text, hps)
    with torch.no_grad():
        x_tst = stn_tst.unsqueeze(0).to(device)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
        audio = net_g.infer(
            x_tst,
            x_tst_lengths,
            noise_scale=0.667,
            noise_scale_w=0.8,
            length_scale=1.0 / speed,
        )[0][0, 0]
        wave = audio.cpu().float().numpy()

    return (hps.data.sampling_rate, _float_to_int16(wave)), text
# Example [input text, language label] rows. Presumably these feed the Gradio
# demo's examples list (confirm at the call site). Each label begins with the
# model code that `synthesize` extracts via `lang.split()[0]`.
TTS_EXAMPLES = [
    list(row)
    for row in (
        ("I am going to the store.", "eng (English)"),
        ("안녕하세요.", "kor (Korean)"),
        ("क्या मुझे पीने का पानी मिल सकता है?", "hin (Hindi)"),
        ("Tanış olmağıma çox şadam", "azj-script_latin (Azerbaijani, North)"),
        ("Mu zo murna a cikin ƙasar.", "hau (Hausa)"),
    )
]