Spaces:
Sleeping
Sleeping
import torch | |
import jyutping | |
from whisper_audio_classifier import WhisperAudioClassifier | |
import librosa | |
from transformers import WhisperFeatureExtractor | |
# Log-mel front end, presumably from a Cantonese fine-tuned Whisper-small
# checkpoint (going by the repo name).
feature_extractor = WhisperFeatureExtractor.from_pretrained("alvanlii/whisper-small-cantonese")
# NOTE(review): WhisperFeatureExtractor derives its padding length
# (n_samples / nb_max_frames) from chunk_length inside __init__; in the
# transformers versions I know of, assigning chunk_length afterwards does not
# update those, so clips may still be padded to the default 30 s. Left as-is
# because the checkpoint below was presumably trained with this same setup —
# confirm against the installed transformers version and the training code.
feature_extractor.chunk_length = 3

# Inference runs on CPU only.
device = torch.device("cpu")
model = WhisperAudioClassifier().to(device)

# Restore the fine-tuned weights from the bundled checkpoint file.
# NOTE(review): torch.load unpickles the file; acceptable for a checkpoint
# shipped with the app, but consider weights_only=True if the installed torch
# supports it.
state_dict = torch.load(
    "whisper-small-encoder-bisyllabic-jyutping/checkpoints/model_epoch_1_step_1800.pth",
    map_location=device,
)
model.load_state_dict(state_dict)
# Switch off training-mode behaviour (e.g. dropout) before serving predictions.
model.eval()
def predict(audio):
    """Run the classifier on a single 16 kHz clip.

    Args:
        audio: one waveform sampled at 16 kHz, in any form the Whisper
            feature extractor accepts (classify_audio passes a 1-D torch
            tensor — assumed mono; TODO confirm for other callers).

    Returns:
        A list with one entry per tensor the model returns; each entry is a
        plain Python list of softmax probabilities for that output head.
    """
    extracted = feature_extractor(audio, sampling_rate=16000)
    with torch.no_grad():
        # The extractor returns a batch; take this clip's features and put a
        # leading batch axis of size 1 back on before calling the model.
        mel = torch.from_numpy(extracted['input_features'][0]).to(device)
        head_outputs = model(mel[None, ...])
        probabilities = []
        for logits in head_outputs:
            # squeeze() drops the size-1 batch axis so softmax runs over
            # classes (assumes each head yields (1, n_classes) — TODO confirm).
            probabilities.append(torch.softmax(logits.squeeze(), dim=0).tolist())
        return probabilities
import gradio as gr | |
import numpy as np | |
def _rank(preds, label_of, k):
    """Return up to *k* distinct labels with their probabilities, best first.

    Index ``i`` of *preds* is mapped to a display label with ``label_of(i)``,
    which is called exactly once per index; an empty label is shown as '∅'.
    If several indices map to the same label, that label keeps its highest
    probability, so a duplicate can neither hide a better score nor use up a
    top-k slot.
    """
    labelled = []
    for i, p in enumerate(preds):
        label = label_of(i)
        labelled.append((label if label != '' else '∅', p))
    # list.sort is stable even with reverse=True, so equal probabilities keep
    # their original index order.
    labelled.sort(key=lambda pair: pair[1], reverse=True)
    top = {}
    for label, p in labelled:
        if len(top) >= k:
            break
        # First occurrence is the highest-probability one for this label.
        top.setdefault(label, p)
    return top


def rank_initials(preds, k=3):
    """Top-*k* syllable initials as ``{initial: probability}``.

    Class indices are converted with ``jyutping.inflate_initial``.
    """
    return _rank(preds, jyutping.inflate_initial, k)


def rank_nucli(preds, k=3):
    """Top-*k* syllable nuclei as ``{nucleus: probability}``.

    Class indices are converted with ``jyutping.inflate_nucleus``. (The
    function name's spelling is kept for compatibility with callers.)
    """
    return _rank(preds, jyutping.inflate_nucleus, k)


def rank_codas(preds, k=3):
    """Top-*k* syllable codas as ``{coda: probability}``.

    Class indices are converted with ``jyutping.inflate_coda``.
    """
    return _rank(preds, jyutping.inflate_coda, k)


def rank_tones(preds, k=3):
    """Top-*k* tones as ``{tone: probability}``; class index ``i`` is tone ``i + 1``."""
    return _rank(preds, lambda i: str(i + 1), k)
def _prepare_waveform(samples):
    """Return *samples* as a mono float32 waveform scaled to a peak magnitude of 1.

    A 2-D array is treated as (samples, channels), the layout gr.Audio uses for
    multi-channel recordings, and averaged to mono. All-zero input is returned
    unscaled instead of being divided by zero (which would produce NaNs). The
    caller's array is never modified.
    """
    waveform = np.asarray(samples, dtype=np.float32)
    if waveform.ndim > 1:
        # librosa.resample works along the last axis, so channels must be
        # folded away before resampling.
        waveform = waveform.mean(axis=1)
    peak = float(np.max(np.abs(waveform))) if waveform.size else 0.0
    if peak > 0:
        waveform = waveform / peak
    return waveform


def classify_audio(audio):
    """Gradio callback: classify a recorded two-syllable Cantonese word.

    Args:
        audio: ``(sampling_rate, samples)`` tuple from ``gr.Audio(type="numpy")``,
            or ``None`` when nothing was recorded.

    Returns:
        Eight ``{label: probability}`` dicts (top 3 each, per the rank_*
        defaults): initial, nucleus, coda and tone of syllable 1, then the same
        four for syllable 2 — the order of the eight output Labels.

    Raises:
        gr.Error: if no audio was recorded, so the UI shows a readable message.
    """
    if audio is None or np.size(audio[1]) == 0:
        raise gr.Error("No audio received. Please record a word with the microphone before clicking Submit.")
    sampling_rate, samples = audio
    waveform = _prepare_waveform(samples)
    # The feature extractor / model are fed 16 kHz audio.
    audio_resampled = librosa.resample(waveform, orig_sr=sampling_rate, target_sr=16000)
    preds = predict(torch.from_numpy(audio_resampled))
    return [
        rank_initials(preds[0]),
        rank_nucli(preds[1]),
        rank_codas(preds[2]),
        rank_tones(preds[3]),
        rank_initials(preds[4]),
        rank_nucli(preds[5]),
        rank_codas(preds[6]),
        rank_tones(preds[7]),
    ]
with gr.Blocks() as demo:
    with gr.Row():
        # Static instructions belong in a Markdown component: gr.Label is an
        # output component meant for {label: confidence} results. A blank line
        # is needed for Markdown to render the paragraph break.
        gr.Markdown(
            "Please say a Cantonese word with exactly 2 characters, like 你好, "
            "into the microphone and click submit to see model predictions.\n\n"
            "Note that the predictions are not very reliable currently."
        )
    with gr.Row():
        inputs = gr.Audio(sources=["microphone"], type="numpy", label="Input Audio")
        submit_btn = gr.Button("Submit")
    with gr.Row():
        with gr.Column():
            # First syllable: one Label per predicted component.
            outputs_left = [
                gr.Label(label="Initial 1"),
                gr.Label(label="Nucleus 1"),
                gr.Label(label="Coda 1"),
                gr.Label(label="Tone 1"),
            ]
        with gr.Column():
            # Second syllable, same layout.
            outputs_right = [
                gr.Label(label="Initial 2"),
                gr.Label(label="Nucleus 2"),
                gr.Label(label="Coda 2"),
                gr.Label(label="Tone 2"),
            ]
    # classify_audio returns eight dicts in exactly this left-then-right order.
    submit_btn.click(fn=classify_audio, inputs=inputs, outputs=outputs_left + outputs_right)

demo.launch()