Spaces:
Runtime error
Update app.py
app.py CHANGED
@@ -1,7 +1,119 @@
 import gradio as gr
-
-
-
-
+from transformers import AutoTokenizer, VitsModel, Wav2Vec2Processor, Wav2Vec2ForCTC
+import torch
+import scipy.io.wavfile
+import scipy.signal  # for resampling microphone input to the ASR sampling rate
+import numpy as np

+# TTS Models
+twi_model_id = "FarmerlineML/main_twi_TTS_V2"
+twi_model = VitsModel.from_pretrained(twi_model_id)
+twi_tokenizer = AutoTokenizer.from_pretrained(twi_model_id)

+ewe_model_id = "FarmerlineML/main_ewe_TTS"
+ewe_model = VitsModel.from_pretrained(ewe_model_id)
+ewe_tokenizer = AutoTokenizer.from_pretrained(ewe_model_id)
+
+# ASR Models
+asr_models = {
+    "Twi": "FarmerlineML/akan_ASR_alpha",
+    "Ewe": "FarmerlineML/ewe_ASR_3.0"
+}
+
+# Initialize ASR models and processors
+asr_processors = {}
+asr_models_loaded = {}
+for lang, model_name in asr_models.items():
+    asr_processors[lang] = Wav2Vec2Processor.from_pretrained(model_name)
+    asr_models_loaded[lang] = Wav2Vec2ForCTC.from_pretrained(model_name)
+
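+# Note: all four checkpoints (two VITS TTS models, two Wav2Vec2 ASR models) are
+# downloaded and held in memory at startup, so the first boot of the Space can take
+# a while; loading lazily on first request would be an alternative trade-off.
+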
+def generate_speech(text, language, noise_scale=0.8, noise_scale_duration=0.9, speaking_rate=1.0):
+    if language == "Twi":
+        model = twi_model
+        tokenizer = twi_tokenizer
+    else:  # Ewe
+        model = ewe_model
+        tokenizer = ewe_tokenizer
+
+    inputs = tokenizer(text, return_tensors="pt")
+    model.noise_scale = noise_scale
+    model.noise_scale_duration = noise_scale_duration
+    model.speaking_rate = speaking_rate
+    with torch.no_grad():
+        output = model(**inputs)
+    return output.waveform[0].cpu().numpy(), model.config.sampling_rate
+
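+# generate_speech note: noise_scale, noise_scale_duration and speaking_rate are the
+# sampling knobs that transformers' VitsModel reads as instance attributes at
+# inference time; the defaults above (0.8, 0.9, 1.0) are starting points, not values
+# prescribed by the checkpoints.
+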
+def tts_interface(text, language):
+    if not text.strip():
+        return None, "Please enter some text."
+
+    audio, sampling_rate = generate_speech(text, language)
+
+    # Save the audio to a file
+    output_file = f"output_{language.lower()}.wav"
+    scipy.io.wavfile.write(output_file, rate=sampling_rate, data=audio)
+
+    return output_file, f"{language} audio generated successfully."
+
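+# tts_interface note: scipy.io.wavfile.write accepts the float32 waveform as-is and
+# writes a 32-bit float WAV, which gr.Audio can serve back as the output file.
+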
+def transcribe_audio(audio, language):
+    processor = asr_processors[language]
+    model = asr_models_loaded[language]
+
+    # Preprocess the audio: gr.Audio(type="numpy") yields a (sampling_rate, array) tuple
+    sampling_rate, array = audio
+    array = array.astype(np.float32)
+    if array.ndim > 1:  # downmix stereo recordings to mono
+        array = array.mean(axis=1)
+    if np.abs(array).max() > 1.0:  # int16 PCM -> [-1, 1] float range
+        array = array / 32768.0
+    target_rate = processor.feature_extractor.sampling_rate  # 16 kHz for Wav2Vec2
+    if sampling_rate != target_rate:
+        array = scipy.signal.resample(array, int(len(array) * target_rate / sampling_rate))
+    input_values = processor(array, sampling_rate=target_rate, return_tensors="pt", padding="longest").input_values
+
+    # Perform inference
+    with torch.no_grad():
+        logits = model(input_values).logits
+
+    # Decode the output
+    predicted_ids = torch.argmax(logits, dim=-1)
+    transcription = processor.batch_decode(predicted_ids)
+
+    return transcription[0]
+
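+# transcribe_audio note: argmax over the logits followed by batch_decode is greedy
+# CTC decoding; the tokenizer collapses repeated tokens and blanks. A language-model
+# backed decoder could improve accuracy but is not used here.
+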
+def asr_interface(audio, language):
+    if audio is None:
+        return "Please upload or record audio."
+    return transcribe_audio(audio, language)
+
+def update_language(language):
+    # Return gr.update(...) for the two input components so only their labels change;
+    # plain strings here would be interpreted as new component values (and a bare
+    # string is not a valid value for the Audio input).
+    return (
+        f"Selected language: {language}",
+        gr.update(label=f"Enter {language} text to convert to speech"),
+        gr.update(label=f"Upload or record {language} audio"),
+    )
+
+with gr.Blocks() as demo:
+    gr.Markdown("# Twi and Ewe ASR and TTS Demo")
+
+    language_selector = gr.Dropdown(choices=["Twi", "Ewe"], label="Language", value="Twi")
+    language_display = gr.Textbox(label="Current Language", value="Selected language: Twi")
+
+    with gr.Tab("Text-to-Speech"):
+        tts_input = gr.Textbox(label="Enter Twi text to convert to speech")
+        tts_button = gr.Button("Generate Speech")
+        tts_output = gr.Audio(label="Generated Speech")
+        tts_message = gr.Textbox(label="Message")
+
+    with gr.Tab("Automatic Speech Recognition"):
+        asr_input = gr.Audio(label="Upload or record Twi audio", type="numpy")  # numpy -> (sampling_rate, array)
+        asr_button = gr.Button("Transcribe")
+        asr_output = gr.Textbox(label="Transcription")
+
+    # Set up event handlers
+    language_selector.change(
+        update_language,
+        inputs=[language_selector],
+        outputs=[language_display, tts_input, asr_input]
+    )
+
+    tts_button.click(
+        tts_interface,
+        inputs=[tts_input, language_selector],
+        outputs=[tts_output, tts_message]
+    )
+
+    asr_button.click(
+        asr_interface,
+        inputs=[asr_input, language_selector],
+        outputs=[asr_output]
+    )
+
+# For Hugging Face Spaces, we use `demo.launch()` without any arguments
+demo.launch()