import functools

import gradio as gr
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import WhisperModel, WhisperFeatureExtractor
import datasets
from datasets import load_dataset, DatasetDict, Audio
from huggingface_hub import PyTorchModelHubMixin
import numpy as np
import librosa

# Ensure you have the device setup (cuda or cpu)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the config for your model
config = {"encoder": "openai/whisper-base", "num_labels": 2}

# Fine-tuned classifier weights hosted on the Hugging Face Hub.
MODEL_WEIGHTS_URL = "https://huggingface.co/jcho02/whisper_cleft/resolve/main/pytorch_model.bin"

# All audio is resampled to this rate before feature extraction.
TARGET_SAMPLING_RATE = 16000


class SpeechInferenceDataset(Dataset):
    """Expose raw audio clips as (input_features, decoder_input_ids) pairs.

    Each element of ``audio_data`` must look like
    ``{"audio": {"array": <1-D waveform>, "sampling_rate": <int>}}``.
    """

    def __init__(self, audio_data, text_processor):
        self.audio_data = audio_data
        # Despite its name, this is the audio feature extractor passed in by
        # prepare_data (a WhisperFeatureExtractor).
        self.text_processor = text_processor

    def __len__(self):
        return len(self.audio_data)

    def __getitem__(self, index):
        audio = self.audio_data[index]["audio"]
        inputs = self.text_processor(
            audio["array"],
            return_tensors="pt",
            sampling_rate=audio["sampling_rate"],
        )
        # The extractor adds a batch dim of 1; drop it so the DataLoader can
        # collate items into its own batch dimension.
        input_features = inputs.input_features.squeeze(0)
        # 1-D per item so default collation yields shape (batch, 2).
        decoder_input_ids = torch.tensor([1, 1])  # Modify as per your model's requirements
        return input_features, decoder_input_ids


class SpeechClassifier(nn.Module, PyTorchModelHubMixin):
    """Whisper encoder-decoder followed by an MLP classification head."""

    def __init__(self, config):
        super().__init__()
        self.encoder = WhisperModel.from_pretrained(config["encoder"])
        self.classifier = nn.Sequential(
            nn.Linear(self.encoder.config.hidden_size, 4096),
            nn.ReLU(),
            nn.Linear(4096, 2048),
            nn.ReLU(),
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, config["num_labels"]),
        )

    def forward(self, input_features, decoder_input_ids):
        """Return classification logits of shape (batch, num_labels)."""
        outputs = self.encoder(input_features, decoder_input_ids=decoder_input_ids)
        # Classify from the hidden state at the first decoder position; this
        # must match how the published weights were trained.
        pooled_output = outputs['last_hidden_state'][:, 0, :]
        logits = self.classifier(pooled_output)
        return logits


@functools.lru_cache(maxsize=8)
def _get_feature_extractor(model_checkpoint):
    """Load the Whisper feature extractor once per checkpoint and reuse it."""
    return WhisperFeatureExtractor.from_pretrained(model_checkpoint)


# Loaded classifiers keyed by the config values SpeechClassifier reads, so the
# encoder and fine-tuned weights are built once instead of on every request.
_MODEL_CACHE = {}


def _get_model(config):
    """Return an eval-mode SpeechClassifier with the hub weights loaded (cached)."""
    key = (config["encoder"], config["num_labels"])
    model = _MODEL_CACHE.get(key)
    if model is None:
        model = SpeechClassifier(config).to(device)
        state_dict = torch.hub.load_state_dict_from_url(MODEL_WEIGHTS_URL, map_location=device)
        model.load_state_dict(state_dict)
        model.eval()
        _MODEL_CACHE[key] = model
    return model


def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-base"):
    """Resample a mono waveform to 16 kHz and wrap it in a single-item DataLoader.

    Args:
        audio_data: 1-D waveform (array-like); converted to float32.
        sampling_rate: Sampling rate of ``audio_data`` in Hz.
        model_checkpoint: Hub id of the Whisper checkpoint whose feature
            extractor should be used.

    Returns:
        A DataLoader (batch_size=1) yielding (input_features, decoder_input_ids).
    """
    audio_data = np.asarray(audio_data, dtype=np.float32)
    audio_data_resampled = librosa.resample(
        audio_data, orig_sr=sampling_rate, target_sr=TARGET_SAMPLING_RATE
    )
    feature_extractor = _get_feature_extractor(model_checkpoint)
    dataset = SpeechInferenceDataset(
        [{"audio": {"array": audio_data_resampled, "sampling_rate": TARGET_SAMPLING_RATE}}],
        text_processor=feature_extractor,
    )
    return DataLoader(dataset, batch_size=1)


def predict(audio_data, sampling_rate, config):
    """Classify one mono waveform and return the predicted class index (int)."""
    dataloader = prepare_data(audio_data, sampling_rate, config["encoder"])
    model = _get_model(config)
    with torch.no_grad():
        # prepare_data always builds exactly one item, so one batch suffices.
        input_features, decoder_input_ids = next(iter(dataloader))
        logits = model(input_features.to(device), decoder_input_ids.to(device))
        return int(torch.argmax(logits, dim=-1))


def _label_for(prediction):
    """Map a predicted class index to the user-facing label."""
    return "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"


def _to_mono_float32(samples):
    """Convert raw microphone samples to a mono float32 waveform.

    Integer PCM (e.g. int16) is rescaled to [-1, 1) so it matches the scale
    librosa.load produces for uploaded files; float input is left unscaled.
    2-D input is down-mixed by averaging channels.
    """
    samples = np.asarray(samples)
    if np.issubdtype(samples.dtype, np.integer):
        info = np.iinfo(samples.dtype)
        # e.g. int16: midpoint 0, half-range 32768; uint8: midpoint 128, half-range 128.
        midpoint = (int(info.max) + int(info.min) + 1) / 2.0
        half_range = (int(info.max) - int(info.min) + 1) / 2.0
        samples = (samples.astype(np.float64) - midpoint) / half_range
    samples = samples.astype(np.float32)
    if samples.ndim == 2:
        # Gradio delivers multi-channel audio as (n_samples, n_channels).
        samples = samples.mean(axis=1)
    return samples


# Gradio Interface functions
def gradio_file_interface(uploaded_file):
    """Classify an uploaded audio file given as a filepath (str)."""
    if uploaded_file is None:
        raise gr.Error("No audio file provided. Please upload a file first.")
    # librosa.load returns a mono float waveform at the file's native rate.
    audio_data, sampling_rate = librosa.load(uploaded_file, sr=None)
    prediction = predict(audio_data, sampling_rate, config)
    return _label_for(prediction)


def gradio_mic_interface(mic_input):
    """Classify a microphone recording.

    ``mic_input`` is a (sample_rate, data) tuple, e.g.
    (44100, array([0, 0, 0, ..., -153, -140, -120], dtype=int16)).
    """
    if mic_input is None:
        raise gr.Error("No audio recorded. Please record a clip first.")
    sampling_rate, samples = mic_input
    prediction = predict(_to_mono_float32(samples), sampling_rate, config)
    return _label_for(prediction)


# Define the interfaces inside the Blocks context
with gr.Blocks() as demo:
    # File Upload Tab
    with gr.Tab("Upload File"):
        gr.Interface(
            fn=gradio_file_interface,
            inputs=gr.Audio(sources="upload", type="filepath"),  # Use filepath for uploaded audio files
            outputs=gr.Textbox(label="Prediction"),
        )

    # Mic Tab
    with gr.Tab("Record Using Microphone"):
        gr.Interface(
            fn=gradio_mic_interface,
            inputs=gr.Audio(sources="microphone", type="numpy"),  # Use numpy for real-time audio like microphone
            outputs=gr.Textbox(label="Prediction"),
        )

# Launch the demo with debugging enabled
demo.launch(debug=True)