"""Gradio microphone demo for a fine-tuned Vietnamese wav2vec2 speech recognizer."""

import sys

sys.path.append("..")  # make the sibling `finetuning` package importable

import gradio
import numpy as np
import torch
import torchaudio
from transformers import (
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForPreTraining,
)

from finetuning.wav2vec2 import SpeechRecognizer

# wav2vec2 models are trained on 16 kHz audio.
TARGET_SAMPLE_RATE = 16_000


def load_model(ckpt_path: str) -> SpeechRecognizer:
    """Build a SpeechRecognizer from a Lightning checkpoint.

    The pretrained backbone, tokenizer and feature extractor come from the
    Hugging Face hub; the fine-tuned weights are restored from ``ckpt_path``
    onto CPU.
    """
    model_name = "nguyenvulebinh/wav2vec2-base-vietnamese-250h"
    wav2vec2 = Wav2Vec2ForPreTraining.from_pretrained(model_name)
    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_name)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
    model = SpeechRecognizer.load_from_checkpoint(
        ckpt_path,
        wav2vec2=wav2vec2,
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        map_location='cpu',
    )
    return model


model = load_model("checkpoints/last.ckpt")
model.eval()


def transcribe(audio):
    """Transcribe a Gradio ``(sample_rate, waveform)`` tuple to text.

    ``waveform`` is a numpy array of shape (samples,) or (samples, channels),
    typically integer PCM when coming from the microphone component.
    """
    sample_rate, waveform = audio
    # Down-mix multi-channel audio to mono (averaging is the conventional
    # mix; the original dropped all but the first channel).
    if waveform.ndim == 2:
        waveform = waveform.mean(axis=1)
    # BUGFIX: Gradio delivers integer PCM (usually int16); the model expects
    # floats in [-1, 1], so scale by the dtype's maximum value.
    if np.issubdtype(waveform.dtype, np.integer):
        scale = float(np.iinfo(waveform.dtype).max)
        waveform = waveform.astype(np.float32) / scale
    waveform = torch.from_numpy(np.ascontiguousarray(waveform)).float().unsqueeze(0)
    waveform = torchaudio.functional.resample(waveform, sample_rate, TARGET_SAMPLE_RATE)
    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        transcript = model.predict(waveform)[0]
    return transcript


if __name__ == "__main__":
    gradio.Interface(
        fn=transcribe,
        inputs=gradio.Audio(source="microphone", type="numpy"),
        outputs="textbox",
    ).launch()