File size: 1,954 Bytes
e18a750
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import numpy as np

import skorch
import torch
import torch.nn as nn

import gradio as gr

import librosa

from joblib import dump, load

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder

from resnet import ResNet
from gradio_utils import load_as_librosa, predict_gradio
from dataloading import uniformize, to_numpy
from preprocessing import MfccTransformer, TorchTransform


# Seed every RNG the app touches (NumPy and PyTorch) so inference and any
# stochastic preprocessing are reproducible across runs.
# PEP 8 fix: no space before the colon in an annotated assignment.
SEED: int = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Deserialize the fitted classifier, its preprocessing transforms, and the
# hyper-parameters that were persisted at training time (joblib artifacts).
# NOTE(review): "mffc" and "LENGHT" are misspellings of "mfcc"/"LENGTH", but
# they match the artifact filenames on disk, so both the paths and the
# variable names are kept as-is to stay consistent with the saved files.
model = load('./model/model.joblib') 
only_mffc_transform = load('./model/only_mffc_transform.joblib') 
label_encoder = load('./model/label_encoder.joblib') 
# Training-time constants reused at inference so preprocessing matches exactly.
SAMPLE_RATE = load("./model/SAMPLE_RATE.joblib")
METHOD = load("./model/METHOD.joblib")
MAX_TIME = load("./model/MAX_TIME.joblib")
N_MFCC = load("./model/N_MFCC.joblib")
HOP_LENGHT = load("./model/HOP_LENGHT.joblib")

# Full inference pipeline: raw waveform -> MFCC features -> trained model.
# (Re-indented to top level; the original carried stray nested indentation.)
sklearn_model = Pipeline(
    steps=[
        ("mfcc", only_mffc_transform),
        ("model", model),
    ]
)


def uniform_lambda(y, sr):
    """Normalize signal ``y`` (sampled at ``sr``) to a fixed duration.

    Thin wrapper binding the persisted METHOD/MAX_TIME training
    hyper-parameters to ``uniformize``. Defined with ``def`` instead of an
    assigned ``lambda`` (PEP 8 E731); the historical name is kept so the
    Gradio callback below is unaffected.
    """
    return uniformize(y, sr, METHOD, MAX_TIME)

# Static copy for the Gradio interface: window title, HTML description shown
# above the widgets, and a markdown footer citing the ResNet paper.
title = r"ResNet 9"

description = r"""
<center>
The resnet9 model was trained to classify drone speech command.
<img src="http://zeus.blanchon.cc/dropshare/modia.png" width=200px>
</center>
"""
article = r"""
- [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385)
"""

# Gradio UI: record audio from the microphone as a (sample_rate, ndarray)
# pair, run it through the sklearn pipeline, and show the predicted label.
# PEP 8 fix: no spaces around "=" in keyword arguments (E251); trailing
# whitespace removed. Behavior is unchanged.
# NOTE(review): `source=` was renamed to `sources=[...]` in Gradio 4.x —
# left as-is because the installed Gradio version is not visible here.
demo_men = gr.Interface(
    title=title,
    description=description,
    article=article,
    fn=lambda data: predict_gradio(
        data=data,
        uniform_lambda=uniform_lambda,
        sklearn_model=sklearn_model,
        label_transform=label_encoder,
        target_sr=SAMPLE_RATE),
    inputs=gr.Audio(source="microphone", type="numpy"),
    outputs=gr.Label(),
    # allow_flagging = "manual",
    # flagging_options = ['recule', 'tournedroite', 'arretetoi', 'tournegauche', 'gauche', 'avance', 'droite'],
    # flagging_dir = "./flag/men"
)

# Start the local Gradio server (blocking call).
demo_men.launch()