jcho02 committed
Commit
a6d5ae5
1 Parent(s): f4739ca

Update app.py

Files changed (1)
  1. app.py +19 -84
app.py CHANGED
@@ -8,58 +8,14 @@ from datasets import load_dataset, DatasetDict, Audio
 from huggingface_hub import PyTorchModelHubMixin
 import numpy as np
 
-# Ensure you have the device setup (cuda or cpu)
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-# Define the config for your model
-config = {"encoder": "openai/whisper-base", "num_labels": 2}
-
-# Define data class
-class SpeechInferenceDataset(Dataset):
-    def __init__(self, audio_data, text_processor):
-        self.audio_data = audio_data
-        self.text_processor = text_processor
-
-    def __len__(self):
-        return len(self.audio_data)
-
-    def __getitem__(self, index):
-        inputs = self.text_processor(self.audio_data[index]["audio"]["array"],
-                                     return_tensors="pt",
-                                     sampling_rate=self.audio_data[index]["audio"]["sampling_rate"])
-        input_features = inputs.input_features
-        decoder_input_ids = torch.tensor([[1, 1]])  # Modify as per your model's requirements
-        return input_features, decoder_input_ids
-
-# Define model class
-class SpeechClassifier(nn.Module, PyTorchModelHubMixin):
-    def __init__(self, config):
-        super(SpeechClassifier, self).__init__()
-        self.encoder = WhisperModel.from_pretrained(config["encoder"])
-        self.classifier = nn.Sequential(
-            nn.Linear(self.encoder.config.hidden_size, 4096),
-            nn.ReLU(),
-            nn.Linear(4096, 2048),
-            nn.ReLU(),
-            nn.Linear(2048, 1024),
-            nn.ReLU(),
-            nn.Linear(1024, 512),
-            nn.ReLU(),
-            nn.Linear(512, config["num_labels"])
-        )
-
-    def forward(self, input_features, decoder_input_ids):
-        outputs = self.encoder(input_features, decoder_input_ids=decoder_input_ids)
-        pooled_output = outputs['last_hidden_state'][:, 0, :]
-        logits = self.classifier(pooled_output)
-        return logits
+# [Your existing code for device setup, config, SpeechInferenceDataset, SpeechClassifier]
 
 # Prepare data function
 def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-base"):
     feature_extractor = WhisperFeatureExtractor.from_pretrained(model_checkpoint)
     inputs = feature_extractor(audio_data, sampling_rate=sampling_rate, return_tensors="pt")
     input_features = inputs.input_features
-    decoder_input_ids = torch.tensor([[1, 1]])  # Modify as per your model's requirements
+    decoder_input_ids = torch.tensor([[1, 1]])
     return input_features.to(device), decoder_input_ids.to(device)
 
 # Prediction function
@@ -67,56 +23,35 @@ def predict(audio_data, sampling_rate, config):
     input_features, decoder_input_ids = prepare_data(audio_data, sampling_rate, config["encoder"])
 
     model = SpeechClassifier(config).to(device)
-    # Here we load the model from Hugging Face Hub
     model.load_state_dict(torch.hub.load_state_dict_from_url("https://huggingface.co/jcho02/whisper_cleft/resolve/main/pytorch_model.bin", map_location=device))
-
     model.eval()
+
     with torch.no_grad():
         logits = model(input_features, decoder_input_ids)
     predicted_ids = int(torch.argmax(logits, dim=-1))
     return predicted_ids
 
-# Gradio Interface functions
-def gradio_file_interface(uploaded_file):
-    # Assuming the uploaded_file is a filepath (str)
-    with open(uploaded_file, "rb") as f:
-        audio_data = np.frombuffer(f.read(), np.int16)
-    prediction = predict(audio_data, 16000, config)  # Assume 16kHz sample rate
-    label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
-    return label
-
-def gradio_mic_interface(mic_input):
-    # Assuming mic_input is a tuple with the audio data and sample rate
-    if isinstance(mic_input, tuple):
-        audio_data, sample_rate = mic_input
+# Unified Gradio interface function
+def gradio_interface(audio_input):
+    if isinstance(audio_input, tuple):
+        # If the input is a tuple, it's from the microphone
+        audio_data, sample_rate = audio_input
     else:
-        # If it's not a tuple, we will try to extract data as if it's a dict
-        audio_data = mic_input['data']
-        sample_rate = mic_input['sample_rate']
+        # Otherwise, it's an uploaded file
+        with open(audio_input, "rb") as f:
+            audio_data = np.frombuffer(f.read(), np.int16)
+        sample_rate = 16000  # Assume 16kHz sample rate for uploaded files
 
     prediction = predict(audio_data, sample_rate, config)
     label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
     return label
 
-# Initialize Blocks
-demo = gr.Blocks()
-
-# Define the interfaces inside the Blocks context
-with demo:
-    mic_transcribe = gr.Interface(
-        fn=gradio_mic_interface,
-        inputs=gr.Audio(type="numpy"),  # Use numpy for real-time audio like microphone
-        outputs=gr.Textbox(label="Prediction")
-    )
-
-    file_transcribe = gr.Interface(
-        fn=gradio_file_interface,
-        inputs=gr.Audio(type="filepath"),  # Use filepath for uploaded audio files
-        outputs=gr.Textbox(label="Prediction")
-    )
-
-    # Combine interfaces into a tabbed interface
-    gr.TabbedInterface([mic_transcribe, file_transcribe], ["Transcribe Microphone", "Transcribe Audio File"])
+# Create Gradio interface
+demo = gr.Interface(
+    fn=gradio_interface,
+    inputs=gr.Audio(type="numpy", label="Upload or Record Audio"),
+    outputs=gr.Textbox(label="Prediction")
+)
 
-# Launch the demo with debugging enabled
+# Launch the demo
 demo.launch(debug=True)
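
A quick way to sanity-check the unified gradio_interface outside the browser is to call it directly with a synthetic recording. The snippet below is a minimal sketch, not part of app.py: it assumes the device, config, and SpeechClassifier definitions elided by the placeholder comment are still in scope, and that the jcho02/whisper_cleft checkpoint is reachable. Note that gr.Audio(type="numpy") in Gradio 3.x delivers microphone input as a (sample_rate, data) tuple, so the tuple here is ordered to match the unpacking inside gradio_interface rather than Gradio's own convention; the live component may need that ordering checked.

# Hypothetical smoke test for gradio_interface; not part of app.py.
import numpy as np

sample_rate = 16000
t = np.linspace(0, 1.0, sample_rate, endpoint=False)
tone = (0.1 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)  # 1 s, 440 Hz test tone

# Ordered (audio_data, sample_rate) to match the unpacking in gradio_interface.
print(gradio_interface((tone, sample_rate)))  # prints the predicted label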
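
In the filepath branch, np.frombuffer(f.read(), np.int16) treats the raw bytes of the upload, including any WAV header, as samples. If header-aware decoding is wanted, one possible alternative is the standard-library wave module. The helper below (load_wav_as_int16 is a hypothetical name, not part of app.py) is a sketch under the assumption that uploads are 16-bit PCM WAV files.

# Hypothetical helper: decode a PCM WAV upload instead of reading raw bytes.
import wave
import numpy as np

def load_wav_as_int16(path):
    with wave.open(path, "rb") as wf:
        frames = wf.readframes(wf.getnframes())   # PCM payload only, header skipped
        samples = np.frombuffer(frames, dtype=np.int16)
        if wf.getnchannels() == 2:                # downmix stereo to mono
            samples = samples.reshape(-1, 2).mean(axis=1).astype(np.int16)
        return samples, wf.getframerate()

# Usage sketch inside the else branch:
# audio_data, sample_rate = load_wav_as_int16(audio_input)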