# bambara-asr / app.py — Hugging Face Space by oza75
# Commit abdd451: "switch back to v002" (file size 2.51 kB)
import os
import spaces
import torch
from transformers import pipeline
import gradio as gr
# NOTE: importing bambara_utils has a side effect: it overrides Whisper's LANGUAGES
# table to add Bambara. This is a workaround rather than a clean extension, but it
# works; see the bambara_utils code for details.
from bambara_utils import BambaraWhisperTokenizer
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face Hub checkpoint, pinned to an exact commit so the demo keeps using
# the same weights even if the repository's default branch moves on.
model_checkpoint = "oza75/whisper-bambara-asr-002"
revision = "831cd15ed74a554caac9f304cf50dc773841ba1b"
# Alternative (earlier "-001") checkpoint; swap these in to switch models.
# model_checkpoint = "oza75/whisper-bambara-asr-001"
# revision = "3578bcb14a42a5d2c58a436fb2c38341898e7885"
language = "bambara"

# Load the custom Bambara tokenizer and the ASR pipeline from the SAME pinned
# revision. Loading the tokenizer without `revision` would fetch it from the
# default branch, which can silently pair the pinned model with a tokenizer
# from a different commit.
tokenizer = BambaraWhisperTokenizer.from_pretrained(
    model_checkpoint,
    revision=revision,
    language=language,
    device=device,
)
pipe = pipeline(
    model=model_checkpoint,
    tokenizer=tokenizer,
    device=device,
    revision=revision,
)
@spaces.GPU()
def transcribe(audio):
    """Transcribe one audio file into text with the configured ASR pipeline.

    Args:
        audio: Path to the audio file to transcribe, as delivered by the
            ``gr.Audio(type="filepath")`` input. It may be ``None`` when the
            user submits without recording or uploading anything.

    Returns:
        The transcribed text (the ``"text"`` field of the pipeline output).

    Raises:
        gr.Error: If no audio was provided. Gradio shows this message to the
            user instead of an opaque pipeline traceback.
    """
    if audio is None:
        raise gr.Error("Please record or upload an audio clip first.")
    # The pipeline returns a dict; only the decoded text is shown in the UI.
    return pipe(audio)["text"]
def get_wav_files(directory):
    """Return absolute paths to all ``.wav`` files directly inside *directory*.

    Only regular files are returned; sub-directories are ignored even if their
    name ends in ``.wav``. The extension match is case-insensitive, so
    ``clip.WAV`` is included too.

    Args:
        directory (str): The directory to search (not recursive).

    Returns:
        list: Absolute paths to the ``.wav`` files, sorted so the order (and
        therefore "the first example") is deterministic across platforms.
        ``os.listdir``/``os.scandir`` order is arbitrary.

    Raises:
        FileNotFoundError: If *directory* does not exist.
    """
    with os.scandir(directory) as entries:
        wav_files = [
            os.path.abspath(entry.path)
            for entry in entries
            if entry.is_file() and entry.name.lower().endswith(".wav")
        ]
    return sorted(wav_files)
def main():
    """Build the Gradio transcription interface and launch it.

    Example clips are the ``.wav`` files in the ``examples`` folder next to
    this script. If that folder is missing or empty, the demo still starts,
    just with no pre-filled input and no examples.
    """
    # Resolve relative to this file rather than the current working directory,
    # so launching the script from elsewhere still finds the examples.
    examples_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "examples")
    try:
        example_files = get_wav_files(examples_dir)
    except FileNotFoundError:
        example_files = []

    # Pre-fill the input with the first example when one exists; indexing an
    # empty list here would crash the app at startup.
    default_audio = example_files[0] if example_files else None

    iface = gr.Interface(
        fn=transcribe,
        inputs=gr.Audio(type="filepath", value=default_audio),
        outputs="text",
        title="Bambara Automatic Speech Recognition",
        description="Realtime demo for Bambara speech recognition based on a fine-tuning of the Whisper model.",
        examples=example_files or None,
        cache_examples="lazy",
    )
    # Serve locally only; no public gradio.live share link.
    iface.launch(share=False)


if __name__ == "__main__":
    main()