"""Gradio demo: predict the Jyutping components of a spoken two-character Cantonese word.

The classifier returns eight output heads, which this demo interprets as the
initial, nucleus, coda and tone of each of the two syllables. The UI shows the
top-k candidates for every head.
"""
import gradio as gr
import librosa
import numpy as np
import torch
from transformers import WhisperFeatureExtractor

import jyutping
from whisper_audio_classifier import WhisperAudioClassifier

# Sampling rate used for feature extraction; microphone audio is resampled to it.
TARGET_SAMPLING_RATE = 16000

# Display label used when an index inflates to the empty string.
EMPTY_LABEL = "∅"

feature_extractor = WhisperFeatureExtractor.from_pretrained("alvanlii/whisper-small-cantonese")
# NOTE(review): WhisperFeatureExtractor derives its padding length (n_samples)
# from chunk_length inside __init__, so assigning chunk_length afterwards may
# not change how inputs are padded. Confirm this matches the training pipeline
# before changing it.
feature_extractor.chunk_length = 3

# Instantiate the model and load the fine-tuned weights.
device = torch.device("cpu")
model = WhisperAudioClassifier().to(device)
state_dict = torch.load(
    "whisper-small-encoder-bisyllabic-jyutping/checkpoints/model_epoch_1_step_1800.pth",
    map_location=device,
)
model.load_state_dict(state_dict)
# Put layers such as dropout into inference behaviour.
model.eval()


def predict(audio):
    """Run the classifier on a mono waveform sampled at TARGET_SAMPLING_RATE.

    Args:
        audio: 1-D waveform (e.g. a NumPy float array).

    Returns:
        One list of probabilities per model output, each the softmax of that
        output (after squeezing away size-1 dimensions).
    """
    features = feature_extractor(audio, sampling_rate=TARGET_SAMPLING_RATE)
    # Take the single feature matrix and add a leading batch dimension.
    inputs = torch.from_numpy(features["input_features"][0]).unsqueeze(0).to(device)
    with torch.no_grad():
        outs = model(inputs)
    return [torch.softmax(tensor.squeeze(), dim=0).tolist() for tensor in outs]


def _top_k(labelled_probs, k):
    """Return up to k distinct labels with the highest probabilities as {label: prob}.

    If several classes map to the same label, only the most probable one is
    kept, so a shared label neither hides a better score nor uses up a slot.
    """
    top = {}
    for label, prob in sorted(labelled_probs, key=lambda pair: pair[1], reverse=True):
        if len(top) >= k:
            break
        top.setdefault(label, prob)
    return top


def _inflated_label(inflate, index):
    """Map a class index to its display label, using EMPTY_LABEL for ''."""
    label = inflate(index)
    return label if label != "" else EMPTY_LABEL


def rank_initials(preds, k=3):
    """Top-k initials as {label: probability}."""
    return _top_k(((_inflated_label(jyutping.inflate_initial, i), p) for i, p in enumerate(preds)), k)


def rank_nucli(preds, k=3):
    """Top-k nuclei as {label: probability}."""
    return _top_k(((_inflated_label(jyutping.inflate_nucleus, i), p) for i, p in enumerate(preds)), k)


def rank_codas(preds, k=3):
    """Top-k codas as {label: probability}."""
    return _top_k(((_inflated_label(jyutping.inflate_coda, i), p) for i, p in enumerate(preds)), k)


def rank_tones(preds, k=3):
    """Top-k tones as {label: probability}; class index i is shown as tone i + 1."""
    return _top_k(((str(i + 1), p) for i, p in enumerate(preds)), k)


def classify_audio(audio):
    """Gradio callback: rank the Jyutping components of both syllables.

    Args:
        audio: ``(sampling_rate, samples)`` tuple from ``gr.Audio(type="numpy")``,
            or None when nothing was recorded.

    Returns:
        Eight {label: probability} dicts in UI order: initial, nucleus, coda,
        tone for syllable 1, then the same four for syllable 2.

    Raises:
        gr.Error: if no audio (or an empty recording) was submitted.
    """
    if audio is None:
        raise gr.Error("Please record a word before clicking Submit.")
    sampling_rate, samples = audio
    samples = np.asarray(samples, dtype=np.float32)
    if samples.size == 0:
        raise gr.Error("Please record a word before clicking Submit.")
    # Gradio returns multi-channel audio as (samples, channels); downmix to mono
    # so resampling runs along the time axis rather than the channel axis.
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    # Peak-normalise to [-1, 1]; skip silent input to avoid dividing by zero.
    peak = np.max(np.abs(samples))
    if peak > 0:
        samples = samples / peak
    resampled = librosa.resample(samples, orig_sr=sampling_rate, target_sr=TARGET_SAMPLING_RATE)
    preds = predict(resampled)
    # The eight heads repeat the same four component kinds for each syllable.
    rankers = (rank_initials, rank_nucli, rank_codas, rank_tones) * 2
    return [rank(preds[i]) for i, rank in enumerate(rankers)]


_COMPONENT_NAMES = ("Initial", "Nucleus", "Coda", "Tone")

with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(
            "Please say a Cantonese word with exactly 2 characters, like 你好, "
            "into the microphone and click submit to see model predictions.\n\n"
            "Note that the predictions are not very reliable currently."
        )
    with gr.Row():
        inputs = gr.Audio(sources=["microphone"], type="numpy", label="Input Audio")
        submit_btn = gr.Button("Submit")
    with gr.Row():
        with gr.Column():
            outputs_left = [gr.Label(label=f"{name} 1") for name in _COMPONENT_NAMES]
        with gr.Column():
            outputs_right = [gr.Label(label=f"{name} 2") for name in _COMPONENT_NAMES]
    submit_btn.click(fn=classify_audio, inputs=inputs, outputs=outputs_left + outputs_right)

demo.launch()