File size: 2,805 Bytes
4cc3c9c
7b1a576
 
 
4cc3c9c
 
 
 
7b1a576
 
 
 
 
 
d7c7caa
4cc3c9c
7b1a576
 
 
 
4cc3c9c
7b1a576
4cc3c9c
 
 
 
7b1a576
 
 
 
 
 
 
 
 
 
 
 
 
4cc3c9c
7b1a576
f618a35
7b1a576
4cc3c9c
7b1a576
4cc3c9c
7b1a576
4cc3c9c
7b1a576
4cc3c9c
7b1a576
4cc3c9c
 
 
7b1a576
4cc3c9c
7b1a576
4cc3c9c
 
 
 
 
 
 
 
7b1a576
5e021b3
4cc3c9c
 
 
 
7b1a576
4cc3c9c
 
 
 
 
 
 
7b1a576
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from huggingface_hub import InferenceClient
from ttsmms import download, TTS
from langdetect import detect

# Load ASR Model
# Wav2Vec2 CTC checkpoint used by transcribe(). The processor is used both for
# feature extraction (audio -> input_values) and for decoding predicted token
# ids back to text. Loaded once at import time.
# NOTE(review): "sw" in the repo name suggests a Swahili model — confirm.
asr_model_name = "Futuresony/Future-sw_ASR-24-02-2025"
processor = Wav2Vec2Processor.from_pretrained(asr_model_name)
asr_model = Wav2Vec2ForCTC.from_pretrained(asr_model_name)

# Load Text Generation Model
# Hugging Face InferenceClient used by generate_text(); generation requests are
# sent to the hosted model rather than run locally.
# NOTE(review): the repo id ends in ".gguf" — confirm this model is actually
# served for text_generation, otherwise every request will fail at call time.
client = InferenceClient("Futuresony/future_ai_12_10_2024.gguf")

def format_prompt(user_input):
    """Wrap *user_input* in the "### User / ### Assistant" chat template.

    The result ends immediately after the assistant tag so the language
    model continues the text from there as its reply.
    """
    template = "### User: {}\n### Assistant:"
    return template.format(user_input)

# Load TTS Models
# ttsmms.download(lang_code, dir) fetches the MMS-TTS voice for an ISO 639-3
# language code ("swh" = Swahili, "eng" = English); the returned path is what
# TTS() is constructed from. Runs at import time, so the first start-up
# presumably needs network access — confirm caching behaviour of ./data/*.
swahili_dir = download("swh", "./data/swahili")
english_dir = download("eng", "./data/english")

# One synthesizer per language; text_to_speech() picks between them.
swahili_tts = TTS(swahili_dir)
english_tts = TTS(english_dir)

# ASR Function
def transcribe(audio_file):
    """Transcribe an audio file to text with the Wav2Vec2 CTC model.

    Args:
        audio_file: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        The greedy (argmax) CTC transcription as a string.
    """
    target_rate = 16000  # sampling rate the processor is told the audio has
    speech, sample_rate = torchaudio.load(audio_file)
    # torchaudio returns (channels, frames). Average the channels so stereo
    # input becomes one mono signal; otherwise the processor would treat each
    # channel as a separate batch item and only the first would be decoded.
    speech = speech.mean(dim=0)
    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)
        speech = resampler(speech)
    input_values = processor(speech.numpy(), sampling_rate=target_rate, return_tensors="pt").input_values
    with torch.no_grad():
        logits = asr_model(input_values).logits
    # Greedy CTC decoding: most likely token per frame; the processor
    # collapses repeats/blanks when decoding.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)[0]

# Text Generation Function
def generate_text(prompt):
    """Ask the remote language model to reply to *prompt*.

    The prompt is wrapped with format_prompt() and sent through the module's
    InferenceClient with fixed sampling settings; surrounding whitespace is
    trimmed from the reply.
    """
    reply = client.text_generation(
        format_prompt(prompt),
        max_new_tokens=250,
        temperature=0.7,
        top_p=0.95,
    )
    return reply.strip()

# TTS Function
def text_to_speech(text):
    """Synthesize *text* to a WAV file, choosing the voice by detected language.

    Text detected as Swahili (langdetect code ``"sw"``) uses the Swahili
    voice; everything else, including text whose language cannot be
    detected, uses the English voice.

    Returns:
        Path of the written WAV file (always ``"./output.wav"``).
    """
    from langdetect.lang_detect_exception import LangDetectException

    # NOTE(review): fixed output path is overwritten on every call; safe only
    # while requests are processed one at a time.
    wav_path = "./output.wav"
    try:
        lang = detect(text)
    except LangDetectException:
        # detect() raises when the text has no detectable features (e.g.
        # empty, digits or punctuation only); treat it as non-Swahili.
        lang = None
    tts = swahili_tts if lang == "sw" else english_tts
    tts.synthesis(text, wav_path=wav_path)
    return wav_path

# Combined Processing Function
def process_audio(audio):
    """Run the full pipeline: speech recognition -> text generation -> TTS.

    Args:
        audio: File path from the ``gr.Audio(type="filepath")`` input, or
            ``None`` when the user submits without providing audio.

    Returns:
        ``(transcription, generated_text, speech_path)``. ``speech_path`` is
        ``None`` when the model produced no text to synthesize.

    Raises:
        gr.Error: if no audio was provided.
    """
    if audio is None:
        # Show a readable message in the UI instead of a torchaudio traceback.
        raise gr.Error("Please upload or record audio before submitting.")
    transcription = transcribe(audio)
    generated_text = generate_text(transcription)
    if not generated_text:
        # Nothing to speak; leave the audio output empty.
        return transcription, generated_text, None
    speech = text_to_speech(generated_text)
    return transcription, generated_text, speech

# Gradio Interface
# Single-page UI: a title, an audio input, the pipeline's three outputs and a
# Submit button wired to process_audio(). Components appear in the order they
# are created inside the Blocks context, so statement order defines the layout.
with gr.Blocks() as demo:
    gr.Markdown("<p align='center' style='font-size: 20px;'>End-to-End ASR, Text Generation, and TTS</p>")
    gr.HTML("<center>Upload or record audio. The model will transcribe, generate a response, and read it out.</center>")
    
    # type="filepath" makes Gradio hand process_audio() a file path, which
    # transcribe() passes to torchaudio.load.
    audio_input = gr.Audio(label="Input Audio", type="filepath")
    text_output = gr.Textbox(label="Transcription")
    generated_text_output = gr.Textbox(label="Generated Text")
    # Displays the WAV file path returned by text_to_speech().
    audio_output = gr.Audio(label="Output Speech")
    submit_btn = gr.Button("Submit")
    
    # The outputs list must stay in the same order as process_audio()'s
    # return tuple: (transcription, generated_text, speech).
    submit_btn.click(
        fn=process_audio,
        inputs=audio_input,
        outputs=[text_output, generated_text_output, audio_output]
    )

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not when imported.
    demo.launch()