# ASR_TTS_DEMO / app.py
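"""Gradio demo for Twi and Ewe speech: VITS text-to-speech and Wav2Vec2 (CTC) speech recognition."""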
import gradio as gr
from transformers import AutoTokenizer, VitsModel, Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
import scipy.io.wavfile
from scipy import signal  # used to resample microphone/uploaded audio for ASR
from math import gcd
import numpy as np
# TTS Models
twi_model_id = "FarmerlineML/main_twi_TTS_V2"
twi_model = VitsModel.from_pretrained(twi_model_id)
twi_tokenizer = AutoTokenizer.from_pretrained(twi_model_id)
ewe_model_id = "FarmerlineML/main_ewe_TTS"
ewe_model = VitsModel.from_pretrained(ewe_model_id)
ewe_tokenizer = AutoTokenizer.from_pretrained(ewe_model_id)
# ASR Models
asr_models = {
    "Twi": "FarmerlineML/akan_ASR_alpha",
    "Ewe": "FarmerlineML/ewe_ASR_3.0",
}
# Initialize ASR models and processors
asr_processors = {}
asr_models_loaded = {}
for lang, model_name in asr_models.items():
    asr_processors[lang] = Wav2Vec2Processor.from_pretrained(model_name)
    asr_models_loaded[lang] = Wav2Vec2ForCTC.from_pretrained(model_name)
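
# All of the models above are loaded once at start-up, so the functions below only run inference.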
def generate_speech(text, language, noise_scale=0.8, noise_scale_duration=0.9, speaking_rate=1.0):
    if language == "Twi":
        model = twi_model
        tokenizer = twi_tokenizer
    else:  # Ewe
        model = ewe_model
        tokenizer = ewe_tokenizer
    inputs = tokenizer(text, return_tensors="pt")
    # VITS sampling controls: the noise scales add natural variation, speaking_rate adjusts tempo
    model.noise_scale = noise_scale
    model.noise_scale_duration = noise_scale_duration
    model.speaking_rate = speaking_rate
    with torch.no_grad():
        output = model(**inputs)
    return output.waveform[0].cpu().numpy(), model.config.sampling_rate

def tts_interface(text, language):
    if not text.strip():
        return None, "Please enter some text."
    audio, sampling_rate = generate_speech(text, language)
    # Save the audio to a file
    output_file = f"output_{language.lower()}.wav"
    scipy.io.wavfile.write(output_file, rate=sampling_rate, data=audio)
    return output_file, f"{language} audio generated successfully."
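
# ASR helpers: turn uploaded or recorded audio into text with the Wav2Vec2 models loaded above.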
def transcribe_audio(audio, language):
    processor = asr_processors[language]
    model = asr_models_loaded[language]
    # gr.Audio (type="numpy") provides a (sampling_rate, data) tuple; convert to mono float32 in [-1, 1]
    sampling_rate, data = audio
    data = data.astype(np.float32)
    if data.ndim > 1:
        data = data.mean(axis=1)
    if np.abs(data).max() > 1.0:
        data /= 32768.0
    # Resample to the rate the Wav2Vec2 feature extractor expects (16 kHz)
    target_rate = processor.feature_extractor.sampling_rate
    if sampling_rate != target_rate:
        g = gcd(target_rate, sampling_rate)
        data = signal.resample_poly(data, target_rate // g, sampling_rate // g)
    # Preprocess the audio
    input_values = processor(data, sampling_rate=target_rate, return_tensors="pt", padding="longest").input_values
    # Perform inference
    with torch.no_grad():
        logits = model(input_values).logits
    # Decode the output
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)
    return transcription[0]
def asr_interface(audio, language):
    if audio is None:
        return "Please upload or record audio."
    return transcribe_audio(audio, language)

def update_language(language):
    return (f"Selected language: {language}",  # status text; the rest relabel the inputs via gr.update
            gr.update(label=f"Enter {language} text to convert to speech"),
            gr.update(label=f"Upload or record {language} audio"))
with gr.Blocks() as demo:
    gr.Markdown("# Twi and Ewe ASR and TTS Demo")
    language_selector = gr.Dropdown(choices=["Twi", "Ewe"], label="Language", value="Twi")
    language_display = gr.Textbox(label="Current Language", value="Selected language: Twi")

    with gr.Tab("Text-to-Speech"):
        tts_input = gr.Textbox(label="Enter Twi text to convert to speech")
        tts_button = gr.Button("Generate Speech")
        tts_output = gr.Audio(label="Generated Speech")
        tts_message = gr.Textbox(label="Message")

    with gr.Tab("Automatic Speech Recognition"):
        # type="numpy" passes a (sampling_rate, data) tuple to the transcription function
        asr_input = gr.Audio(label="Upload or record Twi audio", type="numpy")
        asr_button = gr.Button("Transcribe")
        asr_output = gr.Textbox(label="Transcription")

    # Set up event handlers
    language_selector.change(
        update_language,
        inputs=[language_selector],
        outputs=[language_display, tts_input, asr_input]
    )
    tts_button.click(
        tts_interface,
        inputs=[tts_input, language_selector],
        outputs=[tts_output, tts_message]
    )
    asr_button.click(
        asr_interface,
        inputs=[asr_input, language_selector],
        outputs=[asr_output]
    )

# For Hugging Face Spaces, we use `demo.launch()` without any arguments
demo.launch()