jcho02 committed on
Commit
09aaa9c
1 Parent(s): 17c5efb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ # Import necessary libraries
4
+ import gradio as gr
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from transformers import WhisperModel, WhisperFeatureExtractor
9
+ import datasets
10
+ from datasets import load_dataset, DatasetDict, Audio
11
+ from huggingface_hub import PyTorchModelHubMixin
12
+
13
+ # Define data class
14
class SpeechInferenceDataset(Dataset):
    """Map-style dataset that converts audio records into model inputs.

    Records are indexed as ``audio_data[i]["audio"]["array"]`` and
    ``audio_data[i]["audio"]["sampling_rate"]``.
    """

    def __init__(self, audio_data, text_processor):
        """Keep references to the records and the featurizing callable.

        Args:
            audio_data: Sized, indexable collection of audio records.
            text_processor: Callable invoked as
                ``text_processor(array, return_tensors="pt", sampling_rate=sr)``;
                its result must expose an ``input_features`` attribute.
        """
        self.audio_data = audio_data
        self.text_processor = text_processor

    def __len__(self):
        """Return the number of audio records."""
        return len(self.audio_data)

    def __getitem__(self, index):
        """Return ``(input_features, decoder_input_ids)`` for one record."""
        # Index the backing collection exactly once: some stores (for example a
        # `datasets.Dataset` with an `Audio` column) may decode the audio file
        # on every access, so repeated indexing would repeat that work.
        audio = self.audio_data[index]["audio"]
        inputs = self.text_processor(
            audio["array"],
            return_tensors="pt",
            sampling_rate=audio["sampling_rate"],
        )
        # Fixed two-token decoder prompt; adjust to the model's requirements.
        decoder_input_ids = torch.tensor([[1, 1]])
        return inputs.input_features, decoder_input_ids
29
+
30
+ # Define model class
31
class SpeechClassifier(nn.Module, PyTorchModelHubMixin):
    """Whisper backbone followed by a feed-forward classification head."""

    def __init__(self, config):
        """Build the backbone and the classifier head.

        Args:
            config: Mapping with an ``"encoder"`` entry (checkpoint passed to
                ``WhisperModel.from_pretrained``) and a ``"num_labels"`` entry
                (width of the final linear layer).
        """
        super().__init__()
        self.encoder = WhisperModel.from_pretrained(config["encoder"])

        # Hidden widths of the head; each hidden Linear is followed by a ReLU.
        # The resulting module order (Linear, ReLU, ..., Linear) is the same as
        # writing the nine layers out by hand, so state-dict keys are unchanged.
        widths = [self.encoder.config.hidden_size, 4096, 2048, 1024, 512]
        head = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            head.append(nn.Linear(fan_in, fan_out))
            head.append(nn.ReLU())
        head.append(nn.Linear(widths[-1], config["num_labels"]))
        self.classifier = nn.Sequential(*head)

    def forward(self, input_features, decoder_input_ids):
        """Return classification logits for a batch.

        The head is applied to position 0 along the second axis of the
        backbone's ``last_hidden_state``.
        """
        backbone_out = self.encoder(
            input_features, decoder_input_ids=decoder_input_ids
        )
        first_position = backbone_out['last_hidden_state'][:, 0, :]
        return self.classifier(first_position)
52
+
53
+ # Prepare data function
54
def prepare_data(audio_file_path, model_checkpoint="openai/whisper-base", device="cpu"):
    """Load one audio file and turn it into tensors ready for the classifier.

    Args:
        audio_file_path: Path of the audio file to featurize.
        model_checkpoint: Checkpoint name used to load the
            ``WhisperFeatureExtractor``.
        device: Target device for the returned tensors (anything accepted by
            ``Tensor.to``). Defaults to ``"cpu"``; previously the function
            referenced an undefined global ``device`` and raised ``NameError``.

    Returns:
        Tuple ``(input_features, decoder_input_ids)`` located on ``device``.
    """
    feature_extractor = WhisperFeatureExtractor.from_pretrained(model_checkpoint)

    # Build a single-row dataset and cast the column to audio resampled to 16 kHz.
    inference_data = datasets.Dataset.from_dict(
        {"path": [audio_file_path], "audio": [audio_file_path]}
    ).cast_column("audio", Audio(sampling_rate=16_000))

    inference_dataset = SpeechInferenceDataset(inference_data, feature_extractor)
    inference_loader = DataLoader(inference_dataset, batch_size=1, shuffle=False)
    input_features, decoder_input_ids = next(iter(inference_loader))

    # The DataLoader adds a batch axis in front of the per-item leading axis;
    # drop the inner one so each tensor has a single batch dimension.
    input_features = input_features.squeeze(1).to(device)
    decoder_input_ids = decoder_input_ids.squeeze(1).to(device)
    return input_features, decoder_input_ids
64
+
65
+ # Prediction function
66
def predict(audio_file_path, config=None):
    """Classify one audio file with the fine-tuned model from the Hub.

    Args:
        audio_file_path: Path of the audio file to classify.
        config: Optional model configuration with ``"encoder"`` and
            ``"num_labels"`` entries. When omitted, the ``openai/whisper-base``
            encoder with two labels is used.

    Returns:
        The index of the highest-scoring label as an ``int``.
    """
    # Avoid a shared mutable default argument; build the default per call.
    if config is None:
        config = {"encoder": "openai/whisper-base", "num_labels": 2}

    input_features, decoder_input_ids = prepare_data(audio_file_path)

    # Load the fine-tuned weights through the Hub mixin's `from_pretrained`.
    # The previous code called a non-existent `push_from_hub` method.
    model = SpeechClassifier.from_pretrained("jcho02/whisper_cleft", config=config)
    model.eval()

    # Run the inputs on whichever device the model parameters live on.
    model_device = next(model.parameters()).device
    with torch.no_grad():
        logits = model(
            input_features.to(model_device),
            decoder_input_ids.to(model_device),
        )
    return int(torch.argmax(logits, dim=-1))
78
+
79
+ # Gradio Interface function
80
def gradio_interface(uploaded_file):
    """Gradio callback: classify an uploaded audio file and return a label.

    Args:
        uploaded_file: Value supplied by the file-upload component. This may be
            an object exposing the on-disk path as ``.name``, a plain path
            string, or ``None`` when nothing was uploaded.

    Returns:
        A human-readable result string.
    """
    if uploaded_file is None:
        return "No audio file provided."

    # The upload already exists on disk. Re-opening that same path with "wb"
    # (as the previous code did) truncates it before it can be read, so the
    # path is passed to the predictor directly instead.
    audio_path = getattr(uploaded_file, "name", uploaded_file)

    prediction = predict(audio_path)
    if prediction == 1:
        return "Hypernasality Detected"
    return "No Hypernasality Detected"
86
+
87
+ # Create and launch Gradio Interface with File upload input
88
# Build the UI with the top-level `gr.File` component; the legacy
# `gr.inputs.*` namespace is deprecated and absent from newer Gradio releases.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.File(label="Upload Audio File"),
    outputs="text",
)
iface.launch()