File size: 5,963 Bytes
09aaa9c
 
 
 
 
 
 
 
4485862
3ce86ea
09aaa9c
73b065a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ed82c5
73b065a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09aaa9c
f93d945
4485862
603d981
5ed82c5
 
 
3ce86ea
 
 
 
09aaa9c
603d981
 
3ce86ea
 
603d981
5ed82c5
 
 
 
3ce86ea
09aaa9c
f93d945
 
5ed82c5
3ce86ea
4485862
a9b6797
f93d945
a6d5ae5
a9b6797
f93d945
5ed82c5
 
 
 
 
f93d945
5ed82c5
 
 
 
 
 
 
 
 
 
 
 
09aaa9c
a9b6797
 
 
5ed82c5
 
a9b6797
 
 
5ed82c5
 
 
 
 
 
a9b6797
3ce86ea
 
5ed82c5
2c19de2
 
 
c779966
3ce86ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c19de2
5ed82c5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import functools

import gradio as gr
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import WhisperModel, WhisperFeatureExtractor
import datasets
from datasets import load_dataset, DatasetDict, Audio
from huggingface_hub import PyTorchModelHubMixin
import numpy as np
import librosa

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Model settings: Whisper checkpoint used as the backbone, and the number of
# output classes produced by SpeechClassifier's head.
config = dict(encoder="openai/whisper-base", num_labels=2)

# Define data class
class SpeechInferenceDataset(Dataset):
    """Map-style dataset that turns raw waveforms into model inputs.

    Each element of ``audio_data`` must be indexable as
    ``item["audio"]["array"]`` and ``item["audio"]["sampling_rate"]``.
    """

    def __init__(self, audio_data, text_processor):
        # ``text_processor`` is called on each waveform; despite its name it is
        # used as an audio feature extractor by this file.
        self.audio_data = audio_data
        self.text_processor = text_processor

    def __len__(self):
        return len(self.audio_data)

    def __getitem__(self, index):
        """Return ``(input_features, decoder_input_ids)`` for item ``index``."""
        clip = self.audio_data[index]["audio"]
        processed = self.text_processor(
            clip["array"],
            return_tensors="pt",
            sampling_rate=clip["sampling_rate"],
        )
        # Remove the leading dimension so the DataLoader can re-batch items
        # (presumably the processor's batch dim — TODO confirm).
        features = processed.input_features.squeeze(0)
        # Fixed two-token decoder prompt (id 1 twice). Left unchanged since
        # the published weights were presumably trained with it — TODO confirm.
        prompt = torch.tensor([[1, 1]])
        return features, prompt

# Define model class
class SpeechClassifier(nn.Module, PyTorchModelHubMixin):
    """Whisper backbone followed by a five-layer MLP classification head.

    Args:
        config: mapping with ``"encoder"`` (Hub id passed to
            ``WhisperModel.from_pretrained``) and ``"num_labels"`` (size of
            the output logits).
    """

    def __init__(self, config):
        super().__init__()
        self.encoder = WhisperModel.from_pretrained(config["encoder"])

        # Linear/ReLU pairs of decreasing width, then a final Linear to the
        # label count. The layer order (and therefore the Sequential indices
        # 0..8 used as state_dict keys) matches the published weights.
        head = []
        width_in = self.encoder.config.hidden_size
        for width_out in (4096, 2048, 1024, 512):
            head.extend([nn.Linear(width_in, width_out), nn.ReLU()])
            width_in = width_out
        head.append(nn.Linear(width_in, config["num_labels"]))
        self.classifier = nn.Sequential(*head)

    def forward(self, input_features, decoder_input_ids):
        """Return classification logits for a batch of input features."""
        hidden = self.encoder(
            input_features, decoder_input_ids=decoder_input_ids
        )["last_hidden_state"]
        # Use the first sequence position as the pooled representation. For a
        # full WhisperModel this is presumably the decoder output — confirm.
        return self.classifier(hidden[:, 0, :])

# Prepare data function
@functools.lru_cache(maxsize=4)
def _get_feature_extractor(model_checkpoint):
    """Load the Whisper feature extractor for ``model_checkpoint`` once and reuse it."""
    return WhisperFeatureExtractor.from_pretrained(model_checkpoint)


def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-base"):
    """Wrap a single waveform in a DataLoader of Whisper-ready inputs.

    Args:
        audio_data: waveform (anything ``np.asarray`` accepts); converted to
            float32. Expected to be mono — ``librosa.resample`` works along
            the last axis.
        sampling_rate: sample rate of ``audio_data`` in Hz.
        model_checkpoint: Hub id whose feature extractor is used.

    Returns:
        A ``DataLoader`` (``batch_size=1``) over a one-item
        ``SpeechInferenceDataset``.
    """
    audio_data = np.asarray(audio_data, dtype=np.float32)

    # Cached: loading the extractor on every request was needless work.
    feature_extractor = _get_feature_extractor(model_checkpoint)

    # Resample to the rate the extractor expects (16 kHz for Whisper) instead
    # of a hard-coded constant; skip the work when it already matches.
    target_sr = feature_extractor.sampling_rate
    if sampling_rate != target_sr:
        audio_data = librosa.resample(
            audio_data, orig_sr=sampling_rate, target_sr=target_sr
        )

    dataset = SpeechInferenceDataset(
        [{"audio": {"array": audio_data, "sampling_rate": target_sr}}],
        text_processor=feature_extractor,
    )
    return DataLoader(dataset, batch_size=1)


# Prediction function
# Location of the fine-tuned classifier weights on the Hugging Face Hub.
_WEIGHTS_URL = "https://huggingface.co/jcho02/whisper_cleft/resolve/main/pytorch_model.bin"


@functools.lru_cache(maxsize=4)
def _load_model(config_items):
    """Build a SpeechClassifier, load the Hub weights, and return it in eval mode.

    ``config_items`` is ``tuple(sorted(config.items()))`` so the result can be
    memoized: one model is kept per distinct config instead of rebuilding the
    Whisper backbone and reloading the state dict on every request.
    """
    model = SpeechClassifier(dict(config_items)).to(device)
    model.load_state_dict(
        torch.hub.load_state_dict_from_url(_WEIGHTS_URL, map_location=device)
    )
    model.eval()
    return model


def predict(audio_data, sampling_rate, config):
    """Classify one waveform and return the predicted class index.

    Args:
        audio_data: waveform array passed to ``prepare_data``.
        sampling_rate: sample rate of ``audio_data`` in Hz.
        config: model config with ``"encoder"`` and ``"num_labels"``; its
            values must be hashable (they form the model cache key).

    Returns:
        The argmax class index as an ``int`` (for the last batch, of which
        there is exactly one).

    Raises:
        ValueError: if ``prepare_data`` yields no batches.
    """
    dataloader = prepare_data(audio_data, sampling_rate, config["encoder"])
    model = _load_model(tuple(sorted(config.items())))

    predicted_ids = None
    with torch.no_grad():
        for input_features, decoder_input_ids in dataloader:
            logits = model(input_features.to(device), decoder_input_ids.to(device))
            predicted_ids = int(torch.argmax(logits, dim=-1))
    # Previously an empty loader left predicted_ids unbound (UnboundLocalError).
    if predicted_ids is None:
        raise ValueError("prepare_data produced no batches to classify")
    return predicted_ids

# Gradio Interface functions
def gradio_file_interface(uploaded_file):
    """Classify an uploaded audio file for hypernasality.

    Args:
        uploaded_file: filesystem path to the upload (the component uses
            ``type="filepath"``), or ``None`` when nothing was uploaded.

    Returns:
        A human-readable prediction label.
    """
    if uploaded_file is None:
        # Gradio passes None when the user submits without a file; previously
        # this crashed inside librosa.load.
        return "No audio provided"
    # sr=None keeps the file's native rate; predict resamples before
    # feature extraction.
    audio_data, sampling_rate = librosa.load(uploaded_file, sr=None)
    prediction = predict(audio_data, sampling_rate, config)
    return "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"

def gradio_mic_interface(mic_input):
    """Classify a microphone recording for hypernasality.

    Args:
        mic_input: ``(sample_rate, samples)`` tuple from the numpy-typed
            audio component, e.g. ``(44100, array([0, 0, ..., -120],
            dtype=int16))``; ``None`` when nothing was recorded.

    Returns:
        A human-readable prediction label.
    """
    if mic_input is None:
        # Gradio passes None when the user submits without recording.
        return "No audio provided"
    sampling_rate, samples = mic_input
    samples = np.asarray(samples)

    if np.issubdtype(samples.dtype, np.signedinteger):
        # Integer PCM (e.g. int16) must be scaled to [-1, 1) to match the
        # float waveforms librosa.load produces on the file-upload path. A
        # bare float cast left values ~32768x too large.
        samples = samples.astype(np.float32) / -np.iinfo(samples.dtype).min
    else:
        samples = samples.astype(np.float32)

    if samples.ndim > 1:
        # Multi-channel recordings: presumably (samples, channels) per the
        # Gradio Audio docs — confirm. Mix down to mono.
        samples = samples.mean(axis=1)

    prediction = predict(samples, sampling_rate, config)
    return "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"

# Two-tab UI: one tab classifies an uploaded file, the other a mic recording.
with gr.Blocks() as demo:
    # File Upload Tab
    with gr.Tab("Upload File"):
        gr.Interface(
            fn=gradio_file_interface,
            inputs=gr.Audio(sources="upload", type="filepath"),  # Handler receives a file path
            outputs=gr.Textbox(label="Prediction"),
        )

    # Mic Tab
    with gr.Tab("Record Using Microphone"):
        gr.Interface(
            fn=gradio_mic_interface,
            inputs=gr.Audio(sources="microphone", type="numpy"),  # Handler receives (rate, samples)
            outputs=gr.Textbox(label="Prediction"),
        )

if __name__ == "__main__":
    # Launch only when run as a script, so importing this module (e.g. from
    # tests or another app) does not start a server.
    demo.launch(debug=True)