Spaces:
Running
Running
import logging

import scipy.io.wavfile
import torch
from espnet2.bin.tts_inference import Text2Speech
from parallel_wavegan.utils import load_model
from transformers import VitsModel, AutoTokenizer

import util
from turkicTTS_utils import normalization
# Run every model on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Load processor and model | |
models_info = { | |
"IS2AI-TurkicTTS": None, | |
"Meta-MMS": { | |
"processor": AutoTokenizer.from_pretrained("facebook/mms-tts-uig-script_arabic"), | |
"model": VitsModel.from_pretrained("facebook/mms-tts-uig-script_arabic"), | |
"arabic_script": True | |
}, | |
"Ixxan-FineTuned-MMS": { | |
"processor": AutoTokenizer.from_pretrained("ixxan/mms-tts-uig-script_arabic-UQSpeech"), | |
"model": VitsModel.from_pretrained("ixxan/mms-tts-uig-script_arabic-UQSpeech"), | |
"arabic_script": True | |
} | |
} | |
# Trained ParallelWaveGAN vocoder checkpoint (path is relative to the working directory).
vocoder_checkpoint = "parallelwavegan_male2_checkpoint/checkpoint-400000steps.pkl"

vocoder = load_model(vocoder_checkpoint)
vocoder = vocoder.to(device)  # same device the ESPnet acoustic model is built on
vocoder = vocoder.eval()  # evaluation mode: this module is only used for inference
# Strip the weight-normalization reparameterization before running inference.
vocoder.remove_weight_norm()
# Acoustic model checkpoint (Transformer / Tacotron 2 / FastSpeech) and its training config.
config_file = "exp/tts_train_raw_char/config.yaml"
model_path = "exp/tts_train_raw_char/train.loss.ave_5best.pth"

# Decoding options that only take effect for a Tacotron 2 model.
_TACOTRON2_DECODE_KWARGS = {
    "threshold": 0.5,
    "minlenratio": 0.0,
    "maxlenratio": 10.0,
    "use_att_constraint": True,
    "backward_window": 1,
    "forward_window": 3,
}
# Decoding options that only take effect for a FastSpeech / FastSpeech2 model.
_FASTSPEECH_DECODE_KWARGS = {
    "speed_control_alpha": 1.0,
}

text2speech = Text2Speech(
    config_file,
    model_path,
    device=device,
    **_TACOTRON2_DECODE_KWARGS,
    **_FASTSPEECH_DECODE_KWARGS,
)
# Disable Griffin-Lim: synthesize_turkic_tts converts the generated features
# to audio with the ParallelWaveGAN vocoder instead.
text2speech.spc2wav = None
def synthesize(text, model_id):
    """Synthesize speech for ``text`` with the model registered as ``model_id``.

    ``"IS2AI-TurkicTTS"`` is delegated to ``synthesize_turkic_tts``. Every
    other id must be a key of ``models_info`` whose entry holds a tokenizer
    (``"processor"``), a VITS model (``"model"``) and an ``"arabic_script"``
    flag.

    Args:
        text: Input text. When the entry's ``"arabic_script"`` flag is set it
            is converted with ``util.ug_latn_to_arab`` before tokenization.
        model_id: Key into ``models_info``.

    Returns:
        Path of the written WAV file, ``"tts_output.wav"`` (overwritten on
        every call).

    Raises:
        KeyError: If ``model_id`` is not a key of ``models_info``.
    """
    # Was a bare print(); logging keeps the trace available without dumping
    # every request's text to stdout unconditionally.
    logging.getLogger(__name__).info("synthesize(model_id=%s): %r", model_id, text)

    if model_id == "IS2AI-TurkicTTS":
        return synthesize_turkic_tts(text)

    try:
        entry = models_info[model_id]
    except KeyError:
        raise KeyError(
            f"Unknown model_id {model_id!r}; expected one of {list(models_info)}"
        ) from None

    if entry["arabic_script"]:
        text = util.ug_latn_to_arab(text)

    processor = entry["processor"]
    model = entry["model"].to(device)
    inputs = processor(text, return_tensors="pt").to(device)
    with torch.no_grad():
        # A single string was tokenized, so take batch item 0; move to CPU
        # because numpy/scipy cannot read device tensors.
        waveform = model(**inputs).waveform.cpu().numpy()[0]

    output_path = "tts_output.wav"
    scipy.io.wavfile.write(output_path, rate=model.config.sampling_rate, data=waveform)
    return output_path
def synthesize_turkic_tts(text):
    """Synthesize speech with the ESPnet TurkicTTS model and the ParallelWaveGAN vocoder.

    The input is converted with ``util.ug_arab_to_latn`` and then normalized
    with ``normalization(..., 'uyghur')``. ``text2speech`` turns it into
    features (its ``'feat_gen'`` output), and ``vocoder.inference`` turns those
    features into a waveform.

    Returns:
        Path of the written WAV file, ``"tts_output.wav"`` (overwritten on
        every call), written at a fixed 22050 Hz.
    """
    latin_text = util.ug_arab_to_latn(text)
    prepared = normalization(latin_text, 'uyghur')

    with torch.no_grad():
        features = text2speech(prepared)['feat_gen']
        samples = vocoder.inference(features).view(-1).cpu().numpy()

    wav_path = "tts_output.wav"
    # NOTE(review): 22050 Hz is hard-coded; presumably it matches the vocoder
    # checkpoint's sampling rate. Confirm against the checkpoint's config.
    scipy.io.wavfile.write(wav_path, rate=22050, data=samples)
    return wav_path