import gradio as gr
from transformers import AutomaticSpeechRecognitionPipeline, AutoProcessor, AutoModelForSpeechSeq2Seq
import torch
import torchaudio

# Hugging Face model IDs of the fine-tuned Nepali Whisper checkpoints
model_urls = [
    "kiranpantha/whisper-tiny-ne",
    "kiranpantha/whisper-base-ne",
    "kiranpantha/whisper-small-np",
    "kiranpantha/whisper-medium-nepali",
    "kiranpantha/whisper-large-v3-nepali",
    "kiranpantha/whisper-large-v3-turbo-nepali",
]

# Each fine-tuned checkpoint uses the processor of its base OpenAI model
processor_mappings = {
    "kiranpantha/whisper-tiny-ne": "openai/whisper-tiny",
    "kiranpantha/whisper-base-ne": "openai/whisper-base",
    "kiranpantha/whisper-small-np": "openai/whisper-small",
    "kiranpantha/whisper-medium-nepali": "openai/whisper-medium",
    "kiranpantha/whisper-large-v3-nepali": "openai/whisper-large-v3",
    "kiranpantha/whisper-large-v3-turbo-nepali": "openai/whisper-large-v3",
}

# Cache models and processors so each checkpoint is only loaded once
model_cache = {}

def load_model(model_name):
    """Loads and caches the model and processor with proper device management."""
    if model_name not in model_cache:
        # Fall back to the model's own repo if no processor mapping exists
        processor_name = processor_mappings.get(model_name, model_name)
        processor = AutoProcessor.from_pretrained(processor_name)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name).to(device)
        model_cache[model_name] = (processor, model, device)
    return model_cache[model_name]
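
# Possible memory optimization (a sketch under the assumption of a GPU Space, not part of
# the original app): the larger checkpoints could be loaded in half precision, e.g.
#
#     model = AutoModelForSpeechSeq2Seq.from_pretrained(
#         model_name,
#         torch_dtype=torch.float16,   # assumes fp16-capable CUDA hardware
#         low_cpu_mem_usage=True,
#     ).to(device)
#
# torch_dtype and low_cpu_mem_usage are standard from_pretrained options; whether they
# help depends on the Space's hardware, so this is left commented out.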

def create_pipeline(model_name):
    """Creates an ASR pipeline with proper configuration."""
    processor, model, device = load_model(model_name)
    return AutomaticSpeechRecognitionPipeline(
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        device=0 if device.type == "cuda" else -1,  # device index for GPU, -1 for CPU
        generate_kwargs={"task": "transcribe", "language": "ne"},  # force Nepali transcription
    )
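
# Roughly equivalent high-level alternative (a sketch, not what this app uses): the
# transformers pipeline() factory can build the same ASR pipeline straight from a model ID,
# provided the checkpoint repo ships its own tokenizer and feature extractor, e.g.
#
#     from transformers import pipeline
#     asr = pipeline(
#         "automatic-speech-recognition",
#         model="kiranpantha/whisper-tiny-ne",
#         device=0 if torch.cuda.is_available() else -1,
#         generate_kwargs={"task": "transcribe", "language": "ne"},
#     )
#
# The explicit load_model/create_pipeline split above is kept so processors can be
# remapped to the base OpenAI checkpoints and reused from the cache.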

def process_audio(model_url, audio_chunk):
    """Transcribes audio from gr.Audio, with error handling."""
    try:
        # gr.Audio(type="numpy") yields a (sample_rate, audio_array) tuple of int16 PCM
        sample_rate, audio_array = audio_chunk
        # Scale int16 PCM to float32 in [-1, 1] for the Whisper feature extractor
        audio_array = audio_array.astype("float32") / 32768.0
        # Convert stereo (samples, channels) to mono
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)
        # Resample to the 16 kHz rate Whisper expects
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            audio_array = resampler(torch.from_numpy(audio_array).unsqueeze(0)).squeeze(0).numpy()
        # Build the pipeline for the selected model and transcribe
        asr_pipeline = create_pipeline(model_url)
        transcription = asr_pipeline({"raw": audio_array, "sampling_rate": 16000})["text"]
        return transcription
    except Exception as e:
        return f"Error: {str(e)}"

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# Nepali Speech Recognition with Whisper Models")
    model_dropdown = gr.Dropdown(choices=model_urls, label="Select Model", value=model_urls[0])
    audio_input = gr.Audio(type="numpy", label="Input Audio")
    output_text = gr.Textbox(label="Transcription")
    transcribe_button = gr.Button("Transcribe")
    transcribe_button.click(
        fn=process_audio,
        inputs=[model_dropdown, audio_input],
        outputs=output_text,
    )

demo.launch()
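
# Optional launch tweaks when running outside Spaces (assumptions, not used by this Space):
#
#     demo.launch(share=True, server_name="0.0.0.0", server_port=7860)
#
# share=True creates a temporary public URL; server_name/server_port control the local bind.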