import torch
from transformers import AutoFeatureExtractor, WhisperForAudioClassification
import librosa
# Default checkpoint: Whisper-medium fine-tuned for FLEURS language ID.
_DEFAULT_LANG_ID_MODEL = (
    "/home/investigacion/disco4TB/workspace_pablo/firvox_whisper_research/"
    "whisper-medium-fleurs-lang-id/lang_identification_models_noFirVox_audios"
)

# Cache of loaded (feature_extractor, model) pairs keyed by (model_path, device)
# so repeated calls do not reload the checkpoint from disk every time.
_LANG_ID_CACHE = {}


def _load_lang_id(model_path, device):
    """Load (and memoize) the feature extractor and classifier for *model_path*."""
    key = (model_path, device)
    if key not in _LANG_ID_CACHE:
        extractor = AutoFeatureExtractor.from_pretrained(model_path)
        model = WhisperForAudioClassification.from_pretrained(model_path).to(device)
        model.eval()  # inference only; disables dropout/batch-norm updates
        _LANG_ID_CACHE[key] = (extractor, model)
    return _LANG_ID_CACHE[key]


def get_language(audio_path, model_path=_DEFAULT_LANG_ID_MODEL, device=None):
    """Identify the spoken language of an audio file.

    Parameters
    ----------
    audio_path : str
        Path to an audio file readable by librosa.
    model_path : str, optional
        Directory of a Hugging Face Whisper audio-classification checkpoint.
        Defaults to the project's FLEURS language-ID model.
    device : str, optional
        Torch device string. Defaults to ``"cuda"`` when available, else
        ``"cpu"`` (the original hard-coded ``"cuda"`` crashed on CPU-only hosts).

    Returns
    -------
    str
        Predicted language label from the model's ``id2label`` mapping.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    feature_extractor, model = _load_lang_id(model_path, device)

    # Whisper feature extractors expect 16 kHz input; librosa resamples on load.
    audio, sr = librosa.load(audio_path, sr=16000)
    inputs = feature_extractor(audio, sampling_rate=sr, return_tensors="pt")
    input_features = inputs.input_features.to(device)

    with torch.no_grad():
        logits = model(input_features).logits

    # Single-item batch: argmax over the class dimension yields the class id.
    predicted_class_id = torch.argmax(logits, dim=-1).item()
    return model.config.id2label[predicted_class_id]