# Hugging Face Spaces page banner captured together with this file: "Runtime error".
import functools

import gradio as gr
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import WhisperModel, WhisperFeatureExtractor
import datasets
from datasets import load_dataset, DatasetDict, Audio
from huggingface_hub import PyTorchModelHubMixin
import numpy as np
import librosa
# Ensure you have the device setup (cuda or cpu) | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
# Define the config for your model | |
config = {"encoder": "openai/whisper-base", "num_labels": 2} | |
# Define data class | |
class SpeechInferenceDataset(Dataset): | |
def __init__(self, audio_data, text_processor): | |
self.audio_data = audio_data | |
self.text_processor = text_processor | |
def __len__(self): | |
return len(self.audio_data) | |
def __getitem__(self, index): | |
inputs = self.text_processor(self.audio_data[index]["audio"]["array"], | |
return_tensors="pt", | |
sampling_rate=self.audio_data[index]["audio"]["sampling_rate"]) | |
input_features = inputs.input_features | |
decoder_input_ids = torch.tensor([[1, 1]]) # Modify as per your model's requirements | |
return input_features, decoder_input_ids | |
# Define model class | |
class SpeechClassifier(nn.Module, PyTorchModelHubMixin): | |
def __init__(self, config): | |
super(SpeechClassifier, self).__init__() | |
self.encoder = WhisperModel.from_pretrained(config["encoder"]) | |
self.classifier = nn.Sequential( | |
nn.Linear(self.encoder.config.hidden_size, 4096), | |
nn.ReLU(), | |
nn.Linear(4096, 2048), | |
nn.ReLU(), | |
nn.Linear(2048, 1024), | |
nn.ReLU(), | |
nn.Linear(1024, 512), | |
nn.ReLU(), | |
nn.Linear(512, config["num_labels"]) | |
) | |
def forward(self, input_features, decoder_input_ids): | |
outputs = self.encoder(input_features, decoder_input_ids=decoder_input_ids) | |
pooled_output = outputs['last_hidden_state'][:, 0, :] | |
logits = self.classifier(pooled_output) | |
return logits | |
# Prepare data function | |
def prepare_data(audio_data, sampling_rate, model_checkpoint="openai/whisper-base"): | |
# Resample audio data to 16000 Hz | |
audio_data_resampled = librosa.resample(audio_data, orig_sr=sampling_rate, target_sr=16000) | |
# Initialize the feature extractor | |
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_checkpoint) | |
# Use Dataset class | |
dataset = SpeechInferenceDataset([{"audio": {"array": audio_data_resampled, "sampling_rate": 16000}}], | |
text_processor=feature_extractor) | |
return dataset | |
# Prediction function | |
def predict(audio_data, sampling_rate, config): | |
input_features, decoder_input_ids = prepare_data(audio_data, sampling_rate, config["encoder"]) | |
model = SpeechClassifier(config).to(device) | |
# Here we load the model from Hugging Face Hub | |
model.load_state_dict(torch.hub.load_state_dict_from_url("https://huggingface.co/jcho02/whisper_cleft/resolve/main/pytorch_model.bin", map_location=device)) | |
model.eval() | |
with torch.no_grad(): | |
logits = model(input_features, decoder_input_ids) | |
predicted_ids = int(torch.argmax(logits, dim=-1)) | |
return predicted_ids | |
# Gradio Interface functions | |
def gradio_file_interface(uploaded_file): | |
# Assuming the uploaded_file is a filepath (str) | |
with open(uploaded_file, "rb") as f: | |
audio_data = np.frombuffer(f.read(), np.int16) | |
prediction = predict(audio_data, 16000, config) # Assume 16kHz sample rate | |
label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected" | |
return label | |
def gradio_mic_interface(mic_input): | |
# mic_input is a tuple with sample_rate and data as entries | |
# (44100, array([ 0, 0, 0, ..., -153, -140, -120], dtype=int16)) | |
prediction = predict(mic_input[1], mic_input[0], config) | |
label = "Hypernasality Detected" if prediction == 1 else "No Hypernasality Detected" | |
return label | |
# Define the interfaces inside the Blocks context | |
with gr.Blocks() as demo: | |
# File Upload Tab | |
with gr.Tab("Upload File"): | |
gr.Interface( | |
fn=gradio_file_interface, | |
inputs=gr.Audio(sources="upload", type="filepath"), # Use filepath for uploaded audio files | |
outputs=gr.Textbox(label="Prediction") | |
) | |
# Mic Tab | |
with gr.Tab("Record Using Microphone"): | |
gr.Interface( | |
fn=gradio_mic_interface, | |
inputs=gr.Audio(sources="microphone", type="numpy"), # Use numpy for real-time audio like microphone | |
outputs=gr.Textbox(label="Prediction") | |
) | |
# Launch the demo with debugging enabled | |
demo.launch(debug=True) |