Futuresony committed on
Commit 5d421f3 · verified · 1 Parent(s): c60bd76

Create app.py

Files changed (1):
  1. app.py +81 -0
app.py ADDED
@@ -0,0 +1,81 @@
+ import gradio as gr
+ import torch
+ import torchaudio
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+ from huggingface_hub import InferenceClient
+ from ttsmms import download, TTS
+ from langdetect import detect
+
+ # Load ASR Model
+ asr_model_name = "Futuresony/Future-sw_ASR-24-02-2025"
+ processor = Wav2Vec2Processor.from_pretrained(asr_model_name)
+ asr_model = Wav2Vec2ForCTC.from_pretrained(asr_model_name)
+
+ # Load Text Generation Model
+ client = InferenceClient("Futuresony/future_ai_12_10_2024.gguf")
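+ # Note: InferenceClient talks to the Hugging Face Inference API, so this
+ # GGUF repo is assumed to be backed by a running text-generation endpoint.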
+
+ def format_prompt(user_input):
+     return f"{user_input}"
+
+ # Load TTS Models
+ swahili_dir = download("swh", "./data/swahili")
+ english_dir = download("eng", "./data/english")
+
+ swahili_tts = TTS(swahili_dir)
+ english_tts = TTS(english_dir)
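+ # Note: download() is assumed to fetch each MMS voice into the given
+ # directory on first launch and reuse the cached files afterwards.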
+
+ # ASR Function
+ def transcribe(audio_file):
+     speech_array, sample_rate = torchaudio.load(audio_file)
+     # Downmix to mono; a stereo waveform would otherwise be processed as a batch of two
+     if speech_array.shape[0] > 1:
+         speech_array = speech_array.mean(dim=0, keepdim=True)
+     resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
+     speech_array = resampler(speech_array).squeeze().numpy()
+     input_values = processor(speech_array, sampling_rate=16000, return_tensors="pt").input_values
+     with torch.no_grad():
+         logits = asr_model(input_values).logits
+     predicted_ids = torch.argmax(logits, dim=-1)
+     transcription = processor.batch_decode(predicted_ids)[0]
+     return transcription
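+ # The argmax over logits is greedy CTC decoding; batch_decode collapses
+ # repeated tokens and blanks into the final text.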
+
+ # Text Generation Function
+ def generate_text(prompt):
+     formatted_prompt = format_prompt(prompt)
+     response = client.text_generation(formatted_prompt, max_new_tokens=250, temperature=0.7, top_p=0.95)
+     return response.strip()
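+ # max_new_tokens caps the reply length; temperature and top_p control how
+ # much the decoding samples rather than picking the most likely token.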
+
+ # TTS Function
+ def text_to_speech(text):
+     lang = detect(text)
+     wav_path = "./output.wav"
+     if lang == "sw":
+         swahili_tts.synthesis(text, wav_path=wav_path)
+     else:
+         english_tts.synthesis(text, wav_path=wav_path)
+     return wav_path
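+ # Note: langdetect is statistical, so very short replies may be misdetected
+ # and routed to the English voice.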
+
+ # Combined Processing Function
+ def process_audio(audio):
+     transcription = transcribe(audio)
+     generated_text = generate_text(transcription)
+     speech = text_to_speech(generated_text)
+     return transcription, generated_text, speech
+
+ # Gradio Interface
+ with gr.Blocks() as demo:
+     gr.Markdown("<p align='center' style='font-size: 20px;'>End-to-End ASR, Text Generation, and TTS</p>")
+     gr.HTML("<center>Upload or record audio. The model will transcribe, generate a response, and read it out.</center>")
+
+     audio_input = gr.Audio(label="Input Audio", type="filepath")
+     text_output = gr.Textbox(label="Transcription")
+     generated_text_output = gr.Textbox(label="Generated Text")
+     audio_output = gr.Audio(label="Output Speech")
+     submit_btn = gr.Button("Submit")
+
+     submit_btn.click(
+         fn=process_audio,
+         inputs=audio_input,
+         outputs=[text_output, generated_text_output, audio_output]
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
+