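"""End-to-end speech demo: transcribe audio with Wav2Vec2 (ASR), generate a
reply with a hosted text-generation model, and read the reply aloud with
MMS-based Swahili/English TTS voices."""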
import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from huggingface_hub import InferenceClient
from ttsmms import download, TTS
from langdetect import detect
# Load ASR Model
asr_model_name = "Futuresony/Future-sw_ASR-24-02-2025"
processor = Wav2Vec2Processor.from_pretrained(asr_model_name)
asr_model = Wav2Vec2ForCTC.from_pretrained(asr_model_name)
# Load Text Generation Model
client = InferenceClient("Futuresony/future_ai_12_10_2024.gguf")
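# Prompt template for the instruction-tuned model. Assumption: this
# "### User / ### Assistant" format matches what the model was fine-tuned on;
# adjust it if the model expects a different chat template.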
def format_prompt(user_input):
    return f"### User: {user_input}\n### Assistant:"
# Load TTS Models
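# download() fetches the MMS checkpoint for each language ("swh" = Swahili,
# "eng" = English) into ./data (assumption: files are cached and reused on later runs).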
swahili_dir = download("swh", "./data/swahili")
english_dir = download("eng", "./data/english")
swahili_tts = TTS(swahili_dir)
english_tts = TTS(english_dir)
# ASR Function
def transcribe(audio_file):
    speech_array, sample_rate = torchaudio.load(audio_file)
    # Downmix multi-channel recordings to mono; Wav2Vec2 expects a single channel.
    if speech_array.shape[0] > 1:
        speech_array = speech_array.mean(dim=0, keepdim=True)
    # Wav2Vec2 was trained on 16 kHz audio, so resample whatever Gradio delivers.
    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
    speech_array = resampler(speech_array).squeeze().numpy()
    input_values = processor(speech_array, sampling_rate=16000, return_tensors="pt").input_values
    with torch.no_grad():
        logits = asr_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]
    return transcription
# Text Generation Function
def generate_text(prompt):
    formatted_prompt = format_prompt(prompt)
    response = client.text_generation(formatted_prompt, max_new_tokens=250, temperature=0.7, top_p=0.95)
    return response.strip()
# TTS Function
def text_to_speech(text):
    # langdetect returns ISO 639-1 codes; anything other than Swahili falls back
    # to the English voice. detect() can raise on empty or ambiguous input, so
    # default to English in that case.
    try:
        lang = detect(text)
    except Exception:
        lang = "en"
    wav_path = "./output.wav"
    if lang == "sw":
        swahili_tts.synthesis(text, wav_path=wav_path)
    else:
        english_tts.synthesis(text, wav_path=wav_path)
    return wav_path
# Combined Processing Function
def process_audio(audio):
    transcription = transcribe(audio)
    generated_text = generate_text(transcription)
    speech = text_to_speech(generated_text)
    return transcription, generated_text, speech
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("<p align='center' style='font-size: 20px;'>End-to-End ASR, Text Generation, and TTS</p>")
    gr.HTML("<center>Upload or record audio. The model will transcribe, generate a response, and read it out.</center>")
    audio_input = gr.Audio(label="Input Audio", type="filepath")
    text_output = gr.Textbox(label="Transcription")
    generated_text_output = gr.Textbox(label="Generated Text")
    audio_output = gr.Audio(label="Output Speech")
    submit_btn = gr.Button("Submit")
    submit_btn.click(
        fn=process_audio,
        inputs=audio_input,
        outputs=[text_output, generated_text_output, audio_output]
    )
if __name__ == "__main__":
    demo.launch()
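# To run locally (assumed dependencies, inferred from the imports above; they
# are not pinned in this file):
#   pip install gradio torch torchaudio transformers huggingface_hub ttsmms langdetect
#   python app.py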