import os
# NOTE(review): original import order preserved; `spaces` (HF ZeroGPU) is
# presumably meant to be imported before torch — confirm before reordering.
import spaces
import torch
from transformers import pipeline, WhisperTokenizer
import torchaudio
import gradio as gr

# Please note that the import below overrides the Whisper LANGUAGES table to add Bambara.
# This is not the best way to do it, but at least it works. For more info, check the
# bambara_utils code. The name itself is unused; the import is kept for its side effect.
from bambara_utils import BambaraWhisperTokenizer  # noqa: F401

# Determine the appropriate device (GPU or CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define the model checkpoint and language
# model_checkpoint = "oza75/whisper-bambara-asr-002"
# revision = "831cd15ed74a554caac9f304cf50dc773841ba1b"
model_checkpoint = "oza75/whisper-bambara-asr-005"
revision = "6a92cd0f19985d12739c2f6864607627115e015d"  # first good checkpoint for bambara
# revision = "fb69a5750182933868397543366dbb63747cf40c"  # this only translates into english
# revision = "129f9e68ead6cc854e7754b737b93aa78e0e61e1"  # supports transcription and translation
# revision = "cb8e351b35d6dc524066679d9646f4a947300b27"
# revision = "5f143f6070b64412a44fea08e912e1b7312e9ae9"  # this checkpoint supports both tasks without overfitting

# model_checkpoint = "oza75/whisper-bambara-asr-006"
# revision = "96535debb4ce0b7af7c9c186d09d088825f63840"
# revision = "4549778c08f29ed2e033cc9a497a187488b6bf56"

# language = "bambara"
# We use icelandic because the model was trained to replace icelandic with bambara.
language = "icelandic"

# Load the tokenizer and the ASR pipeline.
# The tokenizer is pinned to the same `revision` as the model so its vocabulary and
# special tokens match the checkpoint the pipeline loads.
# tokenizer = BambaraWhisperTokenizer.from_pretrained(model_checkpoint, language=language, device=device)
tokenizer = WhisperTokenizer.from_pretrained(model_checkpoint, revision=revision, language=language)
pipe = pipeline(model=model_checkpoint, tokenizer=tokenizer, device=device, revision=revision)


def resample_audio(audio_path, target_sample_rate=16000):
    """
    Load an audio file and resample it to the target sampling rate.

    Args:
        audio_path (str): Path to the audio file.
        target_sample_rate (int): The desired sample rate in Hz (default 16000).

    Returns:
        tuple: ``(waveform, target_sample_rate)`` where ``waveform`` is the tensor
        returned by ``torchaudio.load`` (channels first), resampled if needed.
    """
    waveform, original_sample_rate = torchaudio.load(audio_path)
    if original_sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=original_sample_rate, new_freq=target_sample_rate
        )
        waveform = resampler(waveform)
    return waveform, target_sample_rate


@spaces.GPU()
def transcribe(audio, task_type):
    """
    Transcribe an audio file with the configured ASR pipeline.

    Args:
        audio (str | None): Path to the audio file to transcribe (Gradio passes
            ``None`` when the user submits without recording/uploading).
        task_type (str): Whisper generation task passed to ``generate_kwargs``
            (the UI only offers ``"transcribe"``).

    Returns:
        str: The transcribed text.

    Raises:
        gr.Error: If no audio was provided.
    """
    if audio is None:
        # Show a readable message in the UI instead of a torchaudio traceback.
        raise gr.Error("Please record or upload an audio clip first.")

    # Convert the audio to 16000 Hz
    waveform, sample_rate = resample_audio(audio)

    # The ASR pipeline only accepts a 1-D (single-channel) array. Averaging over the
    # channel axis downmixes stereo/multi-channel files; for mono input this is
    # equivalent to the previous squeeze().
    mono = waveform.mean(dim=0)

    # Use the pipeline to perform transcription
    sample = {"array": mono.numpy(), "sampling_rate": sample_rate}
    text = pipe(sample, generate_kwargs={"task": task_type, "language": language})["text"]
    return text


def get_wav_files(directory):
    """
    Build Gradio example rows from the .wav files in a directory.

    Args:
        directory (str): The directory to search for .wav files.

    Returns:
        list: A list of ``[absolute_path, "transcribe"]`` rows, sorted by file name.
        Returns an empty list if the directory does not exist.
    """
    try:
        # Sort so the example list (and the default audio) is deterministic.
        names = sorted(os.listdir(directory))
    except FileNotFoundError:
        return []

    wav_paths = [
        os.path.abspath(os.path.join(directory, name))
        for name in names
        if name.lower().endswith(".wav") and os.path.isfile(os.path.join(directory, name))
    ]
    return [[path, "transcribe"] for path in wav_paths]


def main():
    """Build and launch the Gradio interface for the Bambara ASR demo."""
    # Get a list of all .wav files in the examples directory
    example_files = get_wav_files("./examples")
    # Pre-load the first example if there is one; otherwise start with an empty input.
    default_audio = example_files[0][0] if example_files else None

    # Setup Gradio interface
    iface = gr.Interface(
        fn=transcribe,
        inputs=[
            gr.Audio(type="filepath", value=default_audio),
            gr.Radio(choices=["transcribe"], label="Task Type", value="transcribe"),
        ],
        outputs="text",
        title="Bambara Automatic Speech Recognition",
        description="Realtime demo for Bambara speech recognition based on a fine-tuning of the Whisper model.",
        # Pass None rather than an empty list when there are no examples.
        examples=example_files or None,
        cache_examples="lazy",
    )

    # Launch the interface
    iface.launch(share=False)


if __name__ == "__main__":
    main()