lilyhof committed
Commit 5ed82c5
1 Parent(s): 603d981

Update app.py

Files changed (1)
  1. app.py +37 -11
app.py CHANGED
@@ -28,7 +28,7 @@ class SpeechInferenceDataset(Dataset):
         inputs = self.text_processor(self.audio_data[index]["audio"]["array"],
                                      return_tensors="pt",
                                      sampling_rate=self.audio_data[index]["audio"]["sampling_rate"])
-        input_features = inputs.input_features
+        input_features = inputs.input_features.squeeze(0)
         decoder_input_ids = torch.tensor([[1, 1]]) # Modify as per your model's requirements
         return input_features, decoder_input_ids
 
@@ -58,6 +58,9 @@ class SpeechClassifier(nn.Module, PyTorchModelHubMixin):
 # Prepare data function
 def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-base"):
 
+    # Convert audio data to float32
+    audio_data = audio_data.astype(np.float32)
+
     # Resample audio data to 16000 Hz
     audio_data_resampled = librosa.resample(audio_data, orig_sr=sampling_rate, target_sr=16000)
 
@@ -68,12 +71,15 @@ def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-bas
     dataset = SpeechInferenceDataset([{"audio": {"array": audio_data_resampled, "sampling_rate": 16000}}],
                                      text_processor=feature_extractor)
 
-    return dataset
+    dataloader = DataLoader(dataset, batch_size=1)
+
+    return dataloader
+    # return dataset
 
 
 # Prediction function
 def predict(audio_data, sampling_rate, config):
-    input_features, decoder_input_ids = prepare_data(audio_data, sampling_rate, config["encoder"])
+    dataloader = prepare_data(audio_data, sampling_rate, config["encoder"])
 
     model = SpeechClassifier(config).to(device)
     # Here we load the model from Hugging Face Hub
@@ -81,23 +87,43 @@ def predict(audio_data, sampling_rate, config):
 
     model.eval()
     with torch.no_grad():
-        logits = model(input_features, decoder_input_ids)
-        predicted_ids = int(torch.argmax(logits, dim=-1))
+        for input_features, decoder_input_ids in dataloader:
+            input_features = input_features.to(device)
+            decoder_input_ids = decoder_input_ids.to(device)
+            logits = model(input_features, decoder_input_ids)
+            predicted_ids = int(torch.argmax(logits, dim=-1))
         return predicted_ids
+
+    # input_features, decoder_input_ids = prepare_data(audio_data, sampling_rate, config["encoder"])
+
+    # model = SpeechClassifier(config).to(device)
+    # # Here we load the model from Hugging Face Hub
+    # model.load_state_dict(torch.hub.load_state_dict_from_url("https://huggingface.co/jcho02/whisper_cleft/resolve/main/pytorch_model.bin", map_location=device))
+
+    # model.eval()
+    # with torch.no_grad():
+    #     logits = model(input_features, decoder_input_ids)
+    #     predicted_ids = int(torch.argmax(logits, dim=-1))
+    #     return predicted_ids
 
 # Gradio Interface functions
 def gradio_file_interface(uploaded_file):
     # Assuming the uploaded_file is a filepath (str)
-    with open(uploaded_file, "rb") as f:
-        audio_data = np.frombuffer(f.read(), np.int16)
-    prediction = predict(audio_data, 16000, config) # Assume 16kHz sample rate
+    audio_data, sampling_rate = librosa.load(uploaded_file, sr=None)
+    prediction = predict(audio_data, sampling_rate, config)
     label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
     return label
 
+    # with open(uploaded_file, "rb") as f:
+    #     audio_data = np.frombuffer(f.read(), np.int16)
+    # prediction = predict(audio_data, 16000, config) # Assume 16kHz sample rate
+    # label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
+    # return label
+
 def gradio_mic_interface(mic_input):
     # mic_input is a tuple with sample_rate and data as entries
     # (44100, array([ 0, 0, 0, ..., -153, -140, -120], dtype=int16))
-    prediction = predict(mic_input[1], mic_input[0], config)
+    prediction = predict(mic_input[1].astype(np.float32), mic_input[0], config)
     label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
     return label
 
@@ -119,5 +145,5 @@ with gr.Blocks() as demo:
         outputs=gr.Textbox(label="Prediction")
     )
 
-    # Launch the demo
-    demo.launch()
+    # Launch the demo with debugging enabled
+    demo.launch(debug=True)
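
For context on the `.squeeze(0)` and `DataLoader(batch_size=1)` changes above, a minimal standalone sketch (not part of the commit, and independent of the rest of app.py) showing the tensor shapes involved; it assumes only torch, numpy, and transformers are installed, and mirrors the shapes that `predict` now receives from the dataloader.

# Standalone shape check; local names here are illustrative only.
import numpy as np
import torch
from torch.utils.data import DataLoader
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")

# One second of silence at 16 kHz stands in for real audio.
audio = np.zeros(16000, dtype=np.float32)

inputs = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")
print(inputs.input_features.shape)             # torch.Size([1, 80, 3000]) -- already batched
print(inputs.input_features.squeeze(0).shape)  # torch.Size([80, 3000]) -- per-item shape

# Collating the squeezed tensor restores a single batch dimension, giving the
# (1, 80, 3000) input the encoder expects rather than (1, 1, 80, 3000).
loader = DataLoader([(inputs.input_features.squeeze(0), torch.tensor([[1, 1]]))], batch_size=1)
for input_features, decoder_input_ids in loader:
    print(input_features.shape, decoder_input_ids.shape)  # torch.Size([1, 80, 3000]) torch.Size([1, 1, 2])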