import argparse
import shutil
import tempfile
from typing import List

import torch
from speechbrain.pretrained import EncoderDecoderASR


def asr_model_inference(model: EncoderDecoderASR, audios: List[str]) -> List[str]:
    """Transcribe each audio in *audios* and return the transcripts in order.

    A temporary directory is created for the duration of the call and handed
    to ``process_audio`` as its ``savedir``. The directory is removed before
    returning, including when a transcription raises.

    Args:
        model: ASR model used for loading and transcribing the audio.
        audios: Audio locations, each passed unchanged to ``process_audio``.

    Returns:
        One transcript per entry of *audios*, in the same order.
    """
    tmp_dir = tempfile.mkdtemp()
    try:
        return [process_audio(model, audio, tmp_dir) for audio in audios]
    finally:
        # Clean up even if a transcription fails, so the directory is not
        # left behind on disk.
        shutil.rmtree(tmp_dir)


def process_audio(model: EncoderDecoderASR, audio: str, savedir: str) -> str:
    """Transcribe a single audio and return the predicted words.

    Args:
        model: ASR model providing ``load_audio`` and ``transcribe_batch``.
        audio: Audio location, forwarded to ``model.load_audio``.
        savedir: Directory forwarded to ``model.load_audio`` as ``savedir``.

    Returns:
        The first (and only) entry of the words predicted for the batch.
    """
    waveform = model.load_audio(audio, savedir=savedir)

    # The model transcribes batches, so wrap the single waveform in a batch
    # of size one; a relative length of 1.0 marks it as unpadded.
    batch = waveform.unsqueeze(0)
    rel_length = torch.tensor([1.0])

    # The second element of the returned pair (the tokens) is not needed here.
    predicted_words, _ = model.transcribe_batch(batch, rel_length)
    return predicted_words[0]


def main() -> None:
    """Parse the command line, load the ASR model and print the transcript."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-I",
        dest="audio_file",
        required=True,
        help="audio file to transcribe",
    )
    args = parser.parse_args()

    asr_model = EncoderDecoderASR.from_hparams(
        source="./inference",
        hparams_file="hyperparams.yaml",
        savedir="inference",
        run_opts={"device": "cpu"},
    )
    print(asr_model_inference(asr_model, [args.audio_file]))


if __name__ == "__main__":
    main()