# whisper-nepali / app.py
import gradio as gr
import numpy as np
import torch
import torchaudio
from transformers import AutomaticSpeechRecognitionPipeline, AutoProcessor, AutoModelForSpeechSeq2Seq
# Fine-tuned Whisper checkpoints for Nepali (Hugging Face Hub repo IDs)
model_urls = [
"kiranpantha/whisper-tiny-ne",
"kiranpantha/whisper-base-ne",
"kiranpantha/whisper-small-np",
"kiranpantha/whisper-medium-nepali",
"kiranpantha/whisper-large-v3-nepali",
"kiranpantha/whisper-large-v3-turbo-nepali",
]
# Map each fine-tuned checkpoint to the base model whose processor it uses
processor_mappings = {
"kiranpantha/whisper-tiny-ne": "openai/whisper-tiny",
"kiranpantha/whisper-base-ne": "openai/whisper-base",
"kiranpantha/whisper-small-np": "openai/whisper-small",
"kiranpantha/whisper-medium-nepali": "openai/whisper-medium",
"kiranpantha/whisper-large-v3-nepali": "openai/whisper-large-v3",
"kiranpantha/whisper-large-v3-turbo-nepali": "openai/whisper-large-v3",
}
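
# Example resolution (illustrative): checkpoints absent from the mapping fall
# back to their own repo ID inside load_model() below.
#     processor_mappings.get("kiranpantha/whisper-small-np")  # -> "openai/whisper-small"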
# Cache models and processors
model_cache = {}
def load_model(model_name):
    """Loads and caches the model and processor with proper device management."""
    if model_name not in model_cache:
        # Fall back to the checkpoint's own repo if no base processor is mapped
        processor_name = processor_mappings.get(model_name, model_name)
        processor = AutoProcessor.from_pretrained(processor_name)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name).to(device)
        model_cache[model_name] = (processor, model, device)
    return model_cache[model_name]
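
# Caching sketch (comments only, not executed): the first call for a checkpoint
# downloads weights and moves them to the device; later calls hit model_cache.
#     load_model("kiranpantha/whisper-tiny-ne")  # cold: from_pretrained + .to(device)
#     load_model("kiranpantha/whisper-tiny-ne")  # warm: returned from model_cache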
def create_pipeline(model_name):
    """Creates an ASR pipeline with proper configuration."""
    processor, model, device = load_model(model_name)
    return AutomaticSpeechRecognitionPipeline(
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        # torch.device("cuda").index is None; pipelines want a GPU index or -1 for CPU
        device=0 if device.type == "cuda" else -1,
        # Whisper's generate() accepts the ISO code "ne" (and the full name "nepali")
        generate_kwargs={"task": "transcribe", "language": "ne"},
    )
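
# For audio longer than Whisper's 30 s window, the pipeline can chunk the input
# itself; a hedged variant (chunk_length_s is a standard ASR-pipeline kwarg):
#     AutomaticSpeechRecognitionPipeline(..., chunk_length_s=30)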
def process_audio(model_url, audio_chunk):
    """Processes audio and returns transcription with error handling."""
    try:
        # gr.Audio(type="numpy") yields (sample_rate, audio_array)
        sample_rate, audio_array = audio_chunk
        # Gradio usually delivers int16 PCM; normalize to float32 in [-1, 1]
        if np.issubdtype(audio_array.dtype, np.integer):
            audio_array = audio_array.astype(np.float32) / np.iinfo(audio_array.dtype).max
        else:
            audio_array = audio_array.astype(np.float32)
        # Stereo arrives as (samples, channels); average the channel axis to mono
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)
        # Resample to the 16 kHz rate Whisper expects
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            audio_array = resampler(torch.from_numpy(audio_array).unsqueeze(0)).squeeze(0).numpy()
        # Create pipeline and transcribe
        asr_pipeline = create_pipeline(model_url)
        transcription = asr_pipeline({"raw": audio_array, "sampling_rate": 16000})["text"]
        return transcription
    except Exception as e:
        return f"Error: {str(e)}"
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# Nepali Speech Recognition with Whisper Models")
    model_dropdown = gr.Dropdown(choices=model_urls, label="Select Model", value=model_urls[0])
    audio_input = gr.Audio(type="numpy", label="Input Audio")
    output_text = gr.Textbox(label="Transcription")
    transcribe_button = gr.Button("Transcribe")
    transcribe_button.click(
        fn=process_audio,
        inputs=[model_dropdown, audio_input],
        outputs=output_text,
    )

demo.launch()
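# On Spaces, launch() needs no arguments; for local testing one could pass
# e.g. demo.launch(server_name="0.0.0.0") or demo.launch(share=True).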