whisper_tiny_fleurs / utils /get_language.py
Pablex's picture
Training in progress, epoch 1
b7c959f verified
raw
history blame
958 Bytes
import functools

import librosa
import torch
from transformers import AutoFeatureExtractor, WhisperForAudioClassification
# Checkpoint directory used when the caller does not pass ``model_dir``.
# NOTE(review): machine-specific absolute path; consider making it configurable
# (env var / CLI flag) for use on other hosts.
DEFAULT_MODEL_DIR = "/home/investigacion/disco4TB/workspace_pablo/firvox_whisper_research/whisper-medium-fleurs-lang-id/lang_identification_models_noFirVox_audios"


def _resolve_device(device=None):
    """Return ``device`` as given, or "cuda" if available, else "cpu"."""
    if device is not None:
        return device
    return "cuda" if torch.cuda.is_available() else "cpu"


@functools.lru_cache(maxsize=2)
def _load_model(model_dir, device):
    """Load and memoize the feature extractor and classifier for ``model_dir``.

    Loading a Whisper checkpoint is expensive, so the pair is cached per
    ``(model_dir, device)`` instead of being reloaded on every call.
    """
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_dir)
    model = WhisperForAudioClassification.from_pretrained(model_dir).to(device)
    # Make inference mode explicit (disables dropout etc.).
    model.eval()
    return feature_extractor, model


def get_language(audio_path, model_dir=DEFAULT_MODEL_DIR, device=None):
    """Predict the language label of an audio file with a Whisper classifier.

    Args:
        audio_path: Path to an audio file readable by ``librosa.load``.
        model_dir: Directory of the fine-tuned ``WhisperForAudioClassification``
            checkpoint and its feature extractor. Defaults to
            ``DEFAULT_MODEL_DIR``.
        device: Torch device string for inference. ``None`` (default) picks
            "cuda" when available, otherwise "cpu".

    Returns:
        The label string from ``model.config.id2label`` for the highest-scoring
        class.
    """
    device = _resolve_device(device)
    feature_extractor, model = _load_model(model_dir, device)

    # Resample to whatever rate the feature extractor expects (16 kHz for
    # Whisper) rather than hard-coding it.
    audio, sr = librosa.load(audio_path, sr=feature_extractor.sampling_rate)
    inputs = feature_extractor(audio, sampling_rate=sr, return_tensors="pt")
    input_features = inputs.input_features.to(device)

    with torch.no_grad():
        logits = model(input_features).logits

    # A single clip is a batch of one: take the argmax over its class scores.
    predicted_class_id = logits[0].argmax(dim=-1).item()
    return model.config.id2label[predicted_class_id]