import torch
import numpy as np
import librosa
import gradio as gr
import jyutping
from transformers import WhisperFeatureExtractor
from whisper_audio_classifier import WhisperAudioClassifier

# Feature extractor for the Cantonese Whisper checkpoint; shorten the chunk
# length from the default 30 s to 3 s to match short bisyllabic recordings.
feature_extractor = WhisperFeatureExtractor.from_pretrained("alvanlii/whisper-small-cantonese")
feature_extractor.chunk_length = 3

# Instantiate the model
device = torch.device("cpu")
model = WhisperAudioClassifier().to(device)

# Load the state dict
state_dict = torch.load("whisper-small-encoder-bisyllabic-jyutping/checkpoints/model_epoch_1_step_1800.pth", map_location=device)

# Load the state dict into the model
model.load_state_dict(state_dict)

# Set the model to evaluation mode
model.eval()

def predict(audio):
    # Run the encoder-classifier on one recording and return 8 softmaxed
    # probability lists: initial, nucleus, coda and tone for each of the
    # two syllables.
    features = feature_extractor(audio, sampling_rate=16000)
    with torch.no_grad():
        inputs = torch.from_numpy(features['input_features'][0]).to(device)
        inputs = inputs.unsqueeze(0)  # Add extra batch dimension in front
        outs = model(inputs)
        return [torch.softmax(tensor.squeeze(), dim=0).tolist() for tensor in outs]
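
# Example (hypothetical, for local testing): "sample.wav" is a placeholder path.
# librosa.load already resamples to 16 kHz here, so the returned array can go
# straight into predict().
#
# audio, _ = librosa.load("sample.wav", sr=16000)
# preds = predict(audio)
# print(len(preds))  # -> 8 probability lists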


# Each rank_* helper maps class indices to Jyutping labels and keeps the k most
# probable ones; '∅' marks the empty (null) initial/nucleus/coda class.
def rank_initials(preds, k=3):
    ranked = sorted([((jyutping.inflate_initial(i) if jyutping.inflate_initial(i) != '' else '∅'), p) for i, p in enumerate(preds)], key=lambda x: x[1], reverse=True)
    return dict(ranked[:k])

def rank_nucli(preds, k=3):
    ranked = sorted([((jyutping.inflate_nucleus(i) if jyutping.inflate_nucleus(i) != '' else '∅'), p) for i, p in enumerate(preds)], key=lambda x: x[1], reverse=True)
    return dict(ranked[:k])

def rank_codas(preds, k=3):
    ranked = sorted([((jyutping.inflate_coda(i) if jyutping.inflate_coda(i) != '' else '∅'), p) for i, p in enumerate(preds)], key=lambda x: x[1], reverse=True)
    return dict(ranked[:k])

def rank_tones(preds, k=3):
    ranked = sorted([(str(i + 1), p) for i, p in enumerate(preds)], key=lambda x: x[1], reverse=True)
    return dict(ranked[:k])
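
# Toy illustration (hypothetical numbers): given six tone probabilities,
# rank_tones keeps the three most likely tones as a {label: probability} dict.
#
# rank_tones([0.05, 0.1, 0.5, 0.03, 0.2, 0.12])
# # -> {'3': 0.5, '5': 0.2, '6': 0.12}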

def classify_audio(audio):
    # Gradio supplies (sample_rate, samples); convert to float, peak-normalise,
    # and resample to the 16 kHz rate the Whisper feature extractor expects.
    sampling_rate, audio = audio
    audio = audio.astype(np.float32)
    peak = np.max(np.abs(audio))
    if peak > 0:  # avoid dividing by zero on silent input
        audio /= peak
    audio_resampled = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
    preds = predict(torch.from_numpy(audio_resampled))
    return [
        # Syllable 1
        rank_initials(preds[0]),
        rank_nucli(preds[1]),
        rank_codas(preds[2]),
        rank_tones(preds[3]),
        # Syllable 2
        rank_initials(preds[4]),
        rank_nucli(preds[5]),
        rank_codas(preds[6]),
        rank_tones(preds[7]),
    ]
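
# Example (hypothetical): gr.Audio(type="numpy") passes microphone input as a
# (sample_rate, samples) tuple, which is the shape classify_audio expects.
# The random noise below is only a stand-in for a real recording.
#
# sr = 48000
# samples = np.random.randint(-2**15, 2**15 - 1, size=sr, dtype=np.int16)
# results = classify_audio((sr, samples))  # 8 dicts, ordered as in the return above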

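# Gradio UI: one microphone input and two columns of Label outputs,
# one column per syllable (initial, nucleus, coda, tone).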
with gr.Blocks() as demo:
    with gr.Row():
        gr.Label("Say a Cantonese word with exactly two characters, such as 你好, into the microphone, then click Submit to see the model's predictions.\nNote that the predictions are currently not very reliable.")

    with gr.Row():
        inputs = gr.Audio(sources=["microphone"], type="numpy", label="Input Audio")
        submit_btn = gr.Button("Submit")

    with gr.Row():
        with gr.Column():
            outputs_left = [
                gr.Label(label="Initial 1"),
                gr.Label(label="Nucleus 1"),
                gr.Label(label="Coda 1"),
                gr.Label(label="Tone 1"),
            ]
            
        with gr.Column():
            outputs_right = [
                gr.Label(label="Initial 2"),
                gr.Label(label="Nucleus 2"),
                gr.Label(label="Coda 2"),
                gr.Label(label="Tone 2"),
            ]
    
    submit_btn.click(fn=classify_audio, inputs=inputs, outputs=outputs_left+outputs_right)

demo.launch()