lilyhof committed on
Commit
3ce86ea
1 Parent(s): c779966

Update app.py

Files changed (1)
  1. app.py +33 -29
app.py CHANGED
@@ -7,6 +7,7 @@ import datasets
 from datasets import load_dataset, DatasetDict, Audio
 from huggingface_hub import PyTorchModelHubMixin
 import numpy as np
+import librosa
 
 # Ensure you have the device setup (cuda or cpu)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -56,16 +57,24 @@ class SpeechClassifier(nn.Module, PyTorchModelHubMixin):
 
 # Prepare data function
 def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-base"):
+
+    # Resample audio data to 16000 Hz
+    audio_data_resampled = librosa.resample(audio_data, orig_sr=sampling_rate, target_sr=16000)
+
+    # Initialize the feature extractor
     feature_extractor = WhisperFeatureExtractor.from_pretrained(model_checkpoint)
-    inputs = feature_extractor(audio_data, sampling_rate=sampling_rate, return_tensors="pt")
-    input_features = inputs.input_features
-    decoder_input_ids = torch.tensor([[1, 1]])  # Modify as per your model's requirements
-    return input_features.to(device), decoder_input_ids.to(device)
+
+    # Use Dataset class
+    dataset = SpeechInferenceDataset([{"audio": {"array": audio_data_resampled, "sampling_rate": 16000}}],
+                                     text_processor=feature_extractor)
+
+    return dataset
+
 
 # Prediction function
 def predict(audio_data, sampling_rate, config):
     input_features, decoder_input_ids = prepare_data(audio_data, sampling_rate, config["encoder"])
-
+
     model = SpeechClassifier(config).to(device)
     # Here we load the model from Hugging Face Hub
     model.load_state_dict(torch.hub.load_state_dict_from_url("https://huggingface.co/jcho02/whisper_cleft/resolve/main/pytorch_model.bin", map_location=device))
@@ -86,34 +95,29 @@ def gradio_file_interface(uploaded_file):
     return label
 
 def gradio_mic_interface(mic_input):
-    # mic_input is a dictionary with 'data' and 'sample_rate' keys
-    prediction = predict(mic_input['data'], mic_input['sample_rate'], config)
+    # mic_input is a tuple with sample_rate and data as entries
+    # (44100, array([   0,    0,    0, ..., -153, -140, -120], dtype=int16))
+    prediction = predict(mic_input[1], mic_input[0], config)
     label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"
     return label
 
-# Initialize Blocks
-demo = gr.Blocks()
-
-# Tab 1: Upload File
-file_interface = gr.Interface(
-    fn=gradio_file_interface,
-    inputs=gr.Audio(sources="upload", type="filepath"),  # Use filepath for uploaded audio files
-    outputs=gr.Textbox(label="Prediction")
-)
-
-# Tab 2: Record with Mic
-mic_interface = gr.Interface(
-    fn=gradio_mic_interface,
-    inputs=gr.Audio(sources="microphone", type="numpy"),  # Use numpy for real-time audio like microphone
-    outputs=gr.Textbox(label="Prediction")
-)
-
 # Define the interfaces inside the Blocks context
-with demo:
-    gr.TabbedInterface(
-        [file_interface, mic_interface],
-        ["Upload File", "Record Using Microphone"]
-    )
+with gr.Blocks() as demo:
+    # File Upload Tab
+    with gr.Tab("Upload File"):
+        gr.Interface(
+            fn=gradio_file_interface,
+            inputs=gr.Audio(sources="upload", type="filepath"),  # Use filepath for uploaded audio files
+            outputs=gr.Textbox(label="Prediction")
+        )
+
+    # Mic Tab
+    with gr.Tab("Record Using Microphone"):
+        gr.Interface(
+            fn=gradio_mic_interface,
+            inputs=gr.Audio(sources="microphone", type="numpy"),  # Use numpy for real-time audio like microphone
+            outputs=gr.Textbox(label="Prediction")
+        )
 
 # Launch the demo with debugging enabled
 demo.launch(debug=True)
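A note on the new preprocessing path: librosa.resample expects floating-point audio, while Gradio's numpy microphone input arrives as int16 (see the sample shown in the gradio_mic_interface comment). Below is a minimal standalone sketch of the same resample-then-extract flow with the dtype conversion made explicit; the helper name extract_features and the int16 normalization are illustrative, not part of this commit.

import numpy as np
import librosa
from transformers import WhisperFeatureExtractor

def extract_features(audio_data, sampling_rate, checkpoint="openai/whisper-base"):
    # Gradio's numpy microphone input is int16; librosa.resample requires
    # float audio, so normalize to [-1.0, 1.0] first.
    if audio_data.dtype == np.int16:
        audio_data = audio_data.astype(np.float32) / 32768.0
    # Whisper feature extractors expect 16 kHz input, hence the fixed target_sr
    resampled = librosa.resample(audio_data, orig_sr=sampling_rate, target_sr=16000)
    extractor = WhisperFeatureExtractor.from_pretrained(checkpoint)
    inputs = extractor(resampled, sampling_rate=16000, return_tensors="pt")
    return inputs.input_features  # log-mel features, (1, 80, 3000) for whisper-base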
 
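One thing the diff leaves open: prepare_data now returns a SpeechInferenceDataset, but the unchanged line in predict still unpacks input_features, decoder_input_ids from its return value. Presumably each dataset item yields that pair; the class definition sits outside this diff, so the following bridge is a sketch under that assumption.

# Assumption: SpeechInferenceDataset[i] yields an (input_features, decoder_input_ids)
# pair; verify against the class definition elsewhere in app.py before relying on this.
dataset = prepare_data(audio_data, sampling_rate, config["encoder"])
input_features, decoder_input_ids = dataset[0]
input_features = input_features.to(device)
decoder_input_ids = decoder_input_ids.to(device)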
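Worth noting while reading predict: since SpeechClassifier inherits PyTorchModelHubMixin, a checkpoint that was pushed with the mixin's save_pretrained/push_to_hub could be reloaded with from_pretrained instead of the manual torch.hub.load_state_dict_from_url call. Whether that applies to jcho02/whisper_cleft depends on how the weights were uploaded, so treat this as a sketch, not a drop-in replacement.

# Sketch, assuming the Hub checkpoint was saved via the mixin's save_pretrained()
# or push_to_hub(); otherwise keep the explicit load_state_dict_from_url call.
model = SpeechClassifier.from_pretrained("jcho02/whisper_cleft")
model.to(device)
model.eval()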