import functools
import wave

import gradio as gr
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import WhisperModel, WhisperFeatureExtractor
import datasets
from datasets import load_dataset, DatasetDict, Audio
from huggingface_hub import PyTorchModelHubMixin
import numpy as np

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model configuration: Whisper checkpoint used as the backbone and the number
# of output classes of the classification head.
config = {"encoder": "openai/whisper-base", "num_labels": 2}

# Fine-tuned SpeechClassifier weights hosted on the Hugging Face Hub.
MODEL_WEIGHTS_URL = "https://huggingface.co/jcho02/whisper_cleft/resolve/main/pytorch_model.bin"

# Whisper checkpoints are trained on 16 kHz audio; all input is resampled to this rate.
TARGET_SAMPLING_RATE = 16000


class SpeechInferenceDataset(Dataset):
    """Wrap a sequence of audio records for batched inference.

    Each record is indexed as ``record["audio"]["array"]`` and
    ``record["audio"]["sampling_rate"]`` (the layout of a Hugging Face
    ``datasets`` row with an ``Audio`` column).
    """

    def __init__(self, audio_data, text_processor):
        # ``text_processor`` is the feature extractor applied to each clip,
        # presumably a WhisperFeatureExtractor.
        self.audio_data = audio_data
        self.text_processor = text_processor

    def __len__(self):
        return len(self.audio_data)

    def __getitem__(self, index):
        audio = self.audio_data[index]["audio"]
        inputs = self.text_processor(
            audio["array"],
            return_tensors="pt",
            sampling_rate=audio["sampling_rate"],
        )
        input_features = inputs.input_features
        # Fixed decoder prompt fed to Whisper's decoder; modify as per your
        # model's requirements (it must match how the classifier was trained).
        decoder_input_ids = torch.tensor([[1, 1]])
        return input_features, decoder_input_ids


class SpeechClassifier(nn.Module, PyTorchModelHubMixin):
    """Whisper encoder-decoder backbone followed by an MLP classification head.

    Do not reorder the layers of ``classifier``: the pretrained state dict
    loaded from MODEL_WEIGHTS_URL is keyed by their ``nn.Sequential`` indices.
    """

    def __init__(self, config):
        super().__init__()
        self.encoder = WhisperModel.from_pretrained(config["encoder"])
        self.classifier = nn.Sequential(
            nn.Linear(self.encoder.config.hidden_size, 4096),
            nn.ReLU(),
            nn.Linear(4096, 2048),
            nn.ReLU(),
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, config["num_labels"]),
        )

    def forward(self, input_features, decoder_input_ids):
        """Return class logits of shape (batch, num_labels)."""
        outputs = self.encoder(input_features, decoder_input_ids=decoder_input_ids)
        # WhisperModel's ``last_hidden_state`` is the decoder output; the hidden
        # state at the first decoder position serves as the pooled representation.
        pooled_output = outputs["last_hidden_state"][:, 0, :]
        return self.classifier(pooled_output)


@functools.lru_cache(maxsize=4)
def _load_feature_extractor(model_checkpoint):
    """Load the Whisper feature extractor once per checkpoint and reuse it."""
    return WhisperFeatureExtractor.from_pretrained(model_checkpoint)


@functools.lru_cache(maxsize=2)
def _load_model(encoder, num_labels):
    """Build the classifier, load the fine-tuned weights and cache the result.

    Keyed on plain str/int values because ``config`` dicts are not hashable.
    Caching avoids re-downloading and re-initialising Whisper on every request.
    """
    model = SpeechClassifier({"encoder": encoder, "num_labels": num_labels}).to(device)
    # NOTE(security): this unpickles a checkpoint downloaded from a remote URL;
    # only point MODEL_WEIGHTS_URL at a repository you trust.
    state_dict = torch.hub.load_state_dict_from_url(MODEL_WEIGHTS_URL, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-base"):
    """Turn a waveform into Whisper model inputs placed on ``device``.

    Args:
        audio_data: 1-D waveform with float samples in [-1, 1].
        sampling_rate: Sampling rate of ``audio_data`` in Hz.
        model_checkpoint: Whisper checkpoint whose feature extractor is used.

    Returns:
        ``(input_features, decoder_input_ids)`` tensors on ``device``.
    """
    feature_extractor = _load_feature_extractor(model_checkpoint)
    inputs = feature_extractor(audio_data, sampling_rate=sampling_rate, return_tensors="pt")
    input_features = inputs.input_features
    decoder_input_ids = torch.tensor([[1, 1]])  # Modify as per your model's requirements
    return input_features.to(device), decoder_input_ids.to(device)


def predict(audio_data, sampling_rate, config):
    """Classify one waveform and return the predicted class index as an int."""
    input_features, decoder_input_ids = prepare_data(audio_data, sampling_rate, config["encoder"])
    model = _load_model(config["encoder"], config["num_labels"])
    with torch.no_grad():
        logits = model(input_features, decoder_input_ids)
    return int(torch.argmax(logits, dim=-1))


def _label_for(prediction):
    """Map a predicted class index to the text shown in the UI."""
    return "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected"


def _to_model_audio(samples, sample_rate):
    """Convert PCM samples into the mono float32 16 kHz waveform Whisper expects.

    Integer samples are scaled to [-1, 1); input shaped (n_samples, n_channels)
    is averaged to mono; other sampling rates are linearly resampled to
    TARGET_SAMPLING_RATE.

    Returns:
        ``(waveform, TARGET_SAMPLING_RATE)``.

    Raises:
        gr.Error: if the clip contains no samples.
    """
    samples = np.asarray(samples)
    if samples.size == 0:
        raise gr.Error("The audio clip is empty.")
    if np.issubdtype(samples.dtype, np.integer):
        info = np.iinfo(samples.dtype)
        if info.min == 0:
            # Unsigned PCM (e.g. 8-bit WAV) is centred on the mid-point of its range.
            midpoint = (int(info.max) + 1) / 2
            samples = (samples.astype(np.float32) - midpoint) / midpoint
        else:
            samples = samples.astype(np.float32) / -float(info.min)
    else:
        samples = samples.astype(np.float32)
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    if sample_rate != TARGET_SAMPLING_RATE:
        # The feature extractor rejects other rates, so resample (linear interpolation).
        n_out = max(1, int(round(len(samples) * TARGET_SAMPLING_RATE / sample_rate)))
        src_times = np.arange(len(samples)) / float(sample_rate)
        dst_times = np.arange(n_out) / float(TARGET_SAMPLING_RATE)
        samples = np.interp(dst_times, src_times, samples).astype(np.float32)
    return samples, TARGET_SAMPLING_RATE


def _read_wav(path):
    """Decode a PCM WAV file into ``(samples, sample_rate)``.

    ``samples`` is an integer array shaped (n_frames, n_channels).

    Raises:
        gr.Error: if the file is not a readable PCM WAV or uses an unsupported sample width.
    """
    try:
        with wave.open(path, "rb") as wav_file:
            sample_rate = wav_file.getframerate()
            n_channels = wav_file.getnchannels()
            sample_width = wav_file.getsampwidth()
            raw = wav_file.readframes(wav_file.getnframes())
    except (wave.Error, EOFError) as err:
        raise gr.Error(f"Could not read the uploaded audio as a PCM WAV file: {err}") from err

    if sample_width == 3:
        # 24-bit PCM has no numpy dtype: left-pad each little-endian sample with a
        # zero byte, giving int32 values with the same sign and relative scale.
        triplets = np.frombuffer(raw, dtype=np.uint8).reshape(-1, 3)
        padded = np.zeros((triplets.shape[0], 4), dtype=np.uint8)
        padded[:, 1:] = triplets
        pcm = padded.view("<i4").reshape(-1)
    else:
        dtypes = {1: np.uint8, 2: "<i2", 4: "<i4"}
        if sample_width not in dtypes:
            raise gr.Error(f"Unsupported WAV sample width: {sample_width} bytes.")
        pcm = np.frombuffer(raw, dtype=dtypes[sample_width])
    return pcm.reshape(-1, n_channels), sample_rate


def gradio_file_interface(uploaded_file):
    """Gradio callback for the upload tab: classify the WAV file at ``uploaded_file``."""
    if uploaded_file is None:
        raise gr.Error("Please upload an audio file.")
    samples, sample_rate = _read_wav(uploaded_file)
    waveform, sample_rate = _to_model_audio(samples, sample_rate)
    return _label_for(predict(waveform, sample_rate, config))


def gradio_mic_interface(mic_input):
    """Gradio callback for microphone input from ``gr.Audio(type="numpy")``.

    Accepts Gradio's ``(sample_rate, samples)`` tuple as well as a
    ``{"data": ..., "sample_rate": ...}`` dict.
    """
    if mic_input is None:
        raise gr.Error("No audio was recorded.")
    if isinstance(mic_input, dict):
        sample_rate, samples = mic_input["sample_rate"], mic_input["data"]
    else:
        sample_rate, samples = mic_input
    waveform, sample_rate = _to_model_audio(samples, sample_rate)
    return _label_for(predict(waveform, sample_rate, config))


# Build the UI. Only the file-upload interface is enabled. NOTE: to add a
# microphone tab, create a second gr.Interface with fn=gradio_mic_interface and
# inputs=gr.Audio(type="numpy"), then combine both with gr.TabbedInterface.
demo = gr.Blocks()

with demo:
    file_transcribe = gr.Interface(
        fn=gradio_file_interface,
        # Ask Gradio to convert uploads to WAV so the stdlib ``wave`` module can
        # decode them whatever format the user uploaded.
        inputs=gr.Audio(type="filepath", format="wav"),
        outputs=gr.Textbox(label="Prediction"),
    )

# Launch the demo with debugging enabled.
demo.launch(debug=True)