Irpan
asr
499b2c1
from transformers import VitsModel, AutoTokenizer
import torch
import scipy.io.wavfile
from parallel_wavegan.utils import load_model
from espnet2.bin.tts_inference import Text2Speech
from turkicTTS_utils import normalization
import util
# Run inference on the GPU when torch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


def _load_mms_entry(repo_id):
    """Load a VITS tokenizer/model pair from a Hugging Face repo id.

    The tokenizer is loaded before the model. The returned dict has the
    "processor", "model" and "arabic_script" keys that `synthesize` reads.
    """
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    vits = VitsModel.from_pretrained(repo_id)
    return {"processor": tokenizer, "model": vits, "arabic_script": True}


# Registry of selectable TTS backends, keyed by the model id passed to
# `synthesize`. "IS2AI-TurkicTTS" carries no per-model objects; `synthesize`
# routes that id to the ESPnet + vocoder pipeline instead. "arabic_script"
# marks models whose input is converted to Arabic script before tokenizing.
models_info = {
    "IS2AI-TurkicTTS": None,
    "Meta-MMS": _load_mms_entry("facebook/mms-tts-uig-script_arabic"),
    "Ixxan-FineTuned-MMS": _load_mms_entry("ixxan/mms-tts-uig-script_arabic-UQSpeech"),
}
# Parallel WaveGAN vocoder: renders the mel features produced by `text2speech`
# into a waveform. Weights are read from this checkpoint path.
vocoder_checkpoint = "parallelwavegan_male2_checkpoint/checkpoint-400000steps.pkl"
vocoder = load_model(vocoder_checkpoint)
vocoder = vocoder.to(device).eval()
vocoder.remove_weight_norm()

# ESPnet acoustic model (transformer / tacotron2 / fastspeech) and its config.
config_file = "exp/tts_train_raw_char/config.yaml"
model_path = "exp/tts_train_raw_char/train.loss.ave_5best.pth"

# Decoding options that only Tacotron 2 uses.
_TACOTRON2_DECODING = dict(
    threshold=0.5,
    minlenratio=0.0,
    maxlenratio=10.0,
    use_att_constraint=True,
    backward_window=1,
    forward_window=3,
)
# Duration scaling that only FastSpeech / FastSpeech2 uses.
_FASTSPEECH_DECODING = dict(speed_control_alpha=1.0)

# `device` is the module-level choice made from torch.cuda availability,
# so the acoustic model and the vocoder share a device.
text2speech = Text2Speech(
    config_file,
    model_path,
    device=device,
    **_TACOTRON2_DECODING,
    **_FASTSPEECH_DECODING,
)
# Disable the Griffin-Lim fallback; waveforms come from the neural vocoder.
text2speech.spc2wav = None
def synthesize(text, model_id):
    """Synthesize Uyghur speech for `text` with the backend named `model_id`.

    Args:
        text: Input text. For backends whose registry entry has
            ``arabic_script`` set, it is first converted with
            ``util.ug_latn_to_arab``.
        model_id: A key of ``models_info``. ``'IS2AI-TurkicTTS'`` is delegated
            to ``synthesize_turkic_tts``. Any other key must map to a dict
            with ``processor``, ``model`` and ``arabic_script`` entries.

    Returns:
        The path of the written WAV file, always ``"tts_output.wav"``. Each
        call overwrites the previous result.

    Raises:
        KeyError: If ``model_id`` is not a key of ``models_info``.
    """
    # Echo the request text to stdout.
    print(text)
    if model_id == 'IS2AI-TurkicTTS':
        return synthesize_turkic_tts(text)

    entry = models_info[model_id]
    if entry["arabic_script"]:
        text = util.ug_latn_to_arab(text)

    tokenizer = entry["processor"]
    vits = entry["model"].to(device)
    batch = tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        # Index 0 drops the batch dimension (a single string was tokenized);
        # the data is moved to the CPU so scipy can write it.
        audio = vits(**batch).waveform.cpu().numpy()[0]

    wav_path = "tts_output.wav"
    scipy.io.wavfile.write(wav_path, rate=vits.config.sampling_rate, data=audio)
    return wav_path
def synthesize_turkic_tts(text):
    """Synthesize speech with the IS2AI TurkicTTS pipeline.

    Processing steps:
        1. Convert the text from Arabic to Latin script (``util.ug_arab_to_latn``).
        2. Normalize it for Uyghur.
        3. Map it to mel features with the ESPnet ``text2speech`` model.
        4. Render the features to audio with the ``vocoder``.

    Returns:
        The path of the written 22050 Hz WAV file, ``"tts_output.wav"``. It is
        overwritten on every call.
    """
    latin = normalization(util.ug_arab_to_latn(text), 'uyghur')
    with torch.no_grad():
        mel = text2speech(latin)['feat_gen']
        generated = vocoder.inference(mel)
    # Flatten to a 1-D sample array on the CPU for scipy.
    audio = generated.view(-1).cpu().numpy()

    wav_path = "tts_output.wav"
    scipy.io.wavfile.write(wav_path, rate=22050, data=audio)
    return wav_path