import logging

import torch
import torchaudio
import speechbrain as sb
import gradio as gr
from speechbrain.inference.ASR import EncoderASR
from speechbrain.dataio.encoder import CTCTextEncoder
from pyctcdecode import build_ctcdecoder
from huggingface_hub import hf_hub_download

# Set up logging
logging.basicConfig(level=logging.INFO)

# Download the wav2vec2 checkpoint from the Hugging Face Hub
checkpoint_path = hf_hub_download(
    repo_id="brdhaker3/TunASR",
    filename="model/1234/save/CKPT+2024-05-27+00-52-30+00/wav2vec2.ckpt",  # Checkpoint path inside the repo
    local_dir="./",  # Save it to the current directory
)
logging.info(f"Checkpoint downloaded to: {checkpoint_path}")

# Load the ASR model
asr_model = EncoderASR.from_hparams(
    source="brdhaker3/TunASR",
    savedir="./model",
)

# Load the custom character-level tokenizer
encoder = CTCTextEncoder()
encoder.load_or_create(
    path=asr_model.hparams.encoder_file,
    from_didatasets=[[]],
    output_key="char_list",
    special_labels={"blank_label": 0, "unk_label": 1},
    sequence_input=True,
)
asr_model.tokenizer = encoder

# Prepare labels for the CTC decoder
vocab = asr_model.tokenizer.ind2lab
labels = [vocab[i] for i in range(len(vocab))]  # Extract labels from the tokenizer
labels = [""] + labels[1:-1] + ["1"]  # Adjust labels to match the CTC format (blank at index 0)

# Initialize the CTC decoder with an n-gram language model
decoder = build_ctcdecoder(
    labels,
    kenlm_model_path=asr_model.hparams.ngram_lm_path,  # Path to the KenLM language model
    alpha=0.5,  # LM weight
    beta=1.0,   # Word insertion penalty
)


class ASR(sb.core.Brain):
    def treat_wav(self, sig):
        """Process a waveform and return the transcribed text."""
        feats = self.modules.wav2vec2(sig.to("cpu"), torch.tensor([1]).to("cpu"))
        feats = self.modules.enc(feats)
        logits = self.modules.ctc_lin(feats)
        p_ctc = self.hparams.log_softmax(logits)
        predicted_words = []
        for logs in p_ctc:
            text = decoder.decode(logs.detach().cpu().numpy())
            predicted_words.append(text.split(" "))
        return " ".join(predicted_words[0])


# Initialize the ASR brain and restore the checkpoint
asr_brain = ASR(
    modules=asr_model.hparams.modules,
    hparams=vars(asr_model.hparams),
    run_opts={"device": "cpu"},
    checkpointer=asr_model.hparams.checkpointer,
)
asr_brain.tokenizer = encoder
asr_brain.checkpointer.recover_if_possible()
asr_brain.modules.eval()


def treat_wav_file(file_mic, file_upload, asr=asr_brain, device="cpu"):
    if file_mic is not None:
        wav = file_mic
    elif file_upload is not None:
        wav = file_upload
    else:
        return "ERROR: You have to either use the microphone or upload an audio file"

    # Read and preprocess the audio file
    info = torchaudio.info(wav)
    sr = info.sample_rate
    sig = sb.dataio.dataio.read_audio(wav)
    if len(sig.shape) > 1:
        sig = torch.mean(sig, dim=1)  # Downmix multi-channel audio to mono
    sig = torch.unsqueeze(sig, 0)
    tensor_wav = sig.to(device)
    resampled = torchaudio.functional.resample(tensor_wav, sr, 16000)  # The model expects 16 kHz input

    # Transcribe the audio
    sentence = asr.treat_wav(resampled)
    return sentence


# Test the function
# print(treat_wav_file("./audio.wav", "./audio.wav"))

# Gradio interface
title = "Tunisian Speech Recognition"
description = '''
This is a Tunisian ASR based on the **WavLM model**, fine-tuned on **2.5 hours** of data, reaching a **WER of 24%** and a **CER of 9%**. \n
Interested? Try it out!
'''
disclaimer = '''
> ⚠️ **Disclaimer:**
> This is a **demo model**; transcription accuracy is limited due to Hugging Face model storage constraints.
> For better performance, you can run the full model locally.
> Please check out the repository and follow the instructions: [Full Model Repo Link](https://huggingface.co/brdhaker3/TunASR)
'''

with gr.Blocks() as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)
    gr.Markdown(disclaimer)
    interface = gr.Interface(
        fn=treat_wav_file,
        inputs=[
            gr.Audio(sources="microphone", type="filepath", label="Record"),
            gr.Audio(sources="upload", type="filepath", label="Upload File"),
        ],
        outputs="text",
        title="",
        description="",
    )

demo.launch()
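
# Usage sketch (hypothetical, kept commented out like the test call above):
# batch-transcribe a local folder of WAV files without launching the Gradio UI,
# reusing treat_wav_file() defined earlier. "./clips" is an assumed example
# path, not part of the original app.
# from pathlib import Path
# for wav_path in sorted(Path("./clips").glob("*.wav")):
#     print(f"{wav_path.name}: {treat_wav_file(None, str(wav_path))}")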