import gradio as gr
import jyutping
import librosa
import numpy as np
import torch
from transformers import WhisperFeatureExtractor

from whisper_audio_classifier import WhisperAudioClassifier
# Load the feature extractor for the Cantonese-finetuned Whisper checkpoint,
# shortening its chunk length so inputs are padded/truncated to 3 s instead of 30 s
feature_extractor = WhisperFeatureExtractor.from_pretrained("alvanlii/whisper-small-cantonese")
feature_extractor.chunk_length = 3
# Instantiate the model
device = torch.device("cpu")
model = WhisperAudioClassifier().to(device)
# Load the state dict
state_dict = torch.load(
    "whisper-small-encoder-bisyllabic-jyutping/checkpoints/model_epoch_1_step_1800.pth",
    map_location=device,
)
# Load the state dict into the model
model.load_state_dict(state_dict)
# Set the model to evaluation mode
model.eval()
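
# For reference: WhisperAudioClassifier is defined in whisper_audio_classifier.py
# (not shown here). From how it is used below, it is assumed to wrap the Whisper
# encoder with eight classification heads and to return eight logit tensors,
# in this order: (initial, nucleus, coda, tone) for the first syllable, then
# the same four for the second syllable.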
def predict(audio):
    # Convert the raw 16 kHz waveform into the log-Mel features Whisper expects
    features = feature_extractor(audio, sampling_rate=16000)
    with torch.no_grad():
        inputs = torch.from_numpy(features['input_features'][0]).to(device)
        inputs = inputs.unsqueeze(0)  # Add extra batch dimension in front
        outs = model(inputs)
    # One logit tensor per head; convert each to a probability distribution
    return [torch.softmax(tensor.squeeze(), dim=0).tolist() for tensor in outs]
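
# A quick local test of predict (hypothetical file name; assumes a 16 kHz mono clip):
#   audio, _ = librosa.load("nei5hou2.wav", sr=16000)
#   probs = predict(audio)
#   len(probs)  # -> 8 probability lists, one per classification head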
# Each rank_* helper maps one head's probabilities to its top-k Jyutping symbols.
# jyutping.inflate_* returns '' for the index meaning "no initial/coda",
# which is displayed as ∅.
def rank_initials(preds, k=3):
    ranked = sorted([((jyutping.inflate_initial(i) if jyutping.inflate_initial(i) != '' else '∅'), p) for i, p in enumerate(preds)], key=lambda x: x[1], reverse=True)
    return dict(ranked[:k])

def rank_nuclei(preds, k=3):
    ranked = sorted([((jyutping.inflate_nucleus(i) if jyutping.inflate_nucleus(i) != '' else '∅'), p) for i, p in enumerate(preds)], key=lambda x: x[1], reverse=True)
    return dict(ranked[:k])

def rank_codas(preds, k=3):
    ranked = sorted([((jyutping.inflate_coda(i) if jyutping.inflate_coda(i) != '' else '∅'), p) for i, p in enumerate(preds)], key=lambda x: x[1], reverse=True)
    return dict(ranked[:k])

def rank_tones(preds, k=3):
    # Jyutping tones are numbered 1-6, so index i corresponds to tone i + 1
    ranked = sorted([(str(i + 1), p) for i, p in enumerate(preds)], key=lambda x: x[1], reverse=True)
    return dict(ranked[:k])
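
# Worked example: given six tone probabilities,
#   rank_tones([0.05, 0.6, 0.1, 0.05, 0.15, 0.05])
# returns the top three tones by probability:
#   {'2': 0.6, '5': 0.15, '3': 0.1}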
def classify_audio(audio):
    # Gradio's numpy audio input arrives as a (sampling_rate, samples) tuple
    sampling_rate, audio = audio
    # Normalize to [-1, 1], then resample to the 16 kHz rate the model expects
    audio = audio.astype(np.float32)
    audio /= np.max(np.abs(audio))
    audio_resampled = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
    preds = predict(torch.from_numpy(audio_resampled))
    # Heads 0-3 describe the first syllable; heads 4-7 describe the second
    return [
        rank_initials(preds[0]),
        rank_nuclei(preds[1]),
        rank_codas(preds[2]),
        rank_tones(preds[3]),
        rank_initials(preds[4]),
        rank_nuclei(preds[5]),
        rank_codas(preds[6]),
        rank_tones(preds[7]),
    ]
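
# The eight dicts above are wired, in order, to the output Labels below:
# initial/nucleus/coda/tone of syllable 1 on the left column, syllable 2 on the
# right. Illustrative shape of one entry: {'n': 0.71, 'l': 0.12, '∅': 0.05}.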
with gr.Blocks() as demo:
    with gr.Row():
        gr.Label(
            "Please say a Cantonese word with exactly 2 characters, like 你好, "
            "into the microphone and click submit to see model predictions.\n"
            "Note that the predictions are not very reliable currently."
        )
    with gr.Row():
        inputs = gr.Audio(sources=["microphone"], type="numpy", label="Input Audio")
        submit_btn = gr.Button("Submit")
    with gr.Row():
        # One column of predictions per syllable
        with gr.Column():
            outputs_left = [
                gr.Label(label="Initial 1"),
                gr.Label(label="Nucleus 1"),
                gr.Label(label="Coda 1"),
                gr.Label(label="Tone 1"),
            ]
        with gr.Column():
            outputs_right = [
                gr.Label(label="Initial 2"),
                gr.Label(label="Nucleus 2"),
                gr.Label(label="Coda 2"),
                gr.Label(label="Tone 2"),
            ]
    # Event listeners must be registered inside the Blocks context
    submit_btn.click(fn=classify_audio, inputs=inputs, outputs=outputs_left + outputs_right)

demo.launch()