import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
from transformers import AutoConfig, AutoFeatureExtractor, AutoModelForAudioClassification

DEFAULT_SR = 16_000
DEFAULT_BACKBONE = "MIT/ast-finetuned-audioset-10-10-0.4593"
DEFAULT_N_CLASSES = 728

MODEL_STR = "dima806/bird_sounds_classification"  # alternative: "facebook/wav2vec2-base-960h"
RATE_HZ = 16000

# Maximum audio interval length to consider, in seconds
MAX_SECONDS = 10

# Maximum audio interval length in samples (sampling rate * seconds)
MAX_LENGTH = RATE_HZ * MAX_SECONDS

# Feature extractor matching the backbone model
FEATURE_EXTRACTOR = AutoFeatureExtractor.from_pretrained(MODEL_STR)


def birdvec_preprocess(audio_array, sr=DEFAULT_SR):
    """
    Preprocess an audio array for the BirdAST model.

    audio_array: np.array, audio array of the recording, shape (n_samples,)
    sr: int, sampling rate of the audio array (default: 16_000)

    Notes:
    1. The audio array should be normalized to [-1, 1].
    2. The audio length should be 10 seconds (or 10.24 seconds); longer audio will be truncated.
    """
    # Extract features from the raw waveform
    features = FEATURE_EXTRACTOR(
        audio_array,
        sampling_rate=sr,
        max_length=MAX_LENGTH,
        truncation=True,
        return_tensors="pt",
    )
    return features.input_values


def birdvec_inference(
    model_weights,
    spectrogram,
    device='cpu',
    backbone_name=None,
    n_classes=728,
    activation=None,
    n_mlp_layers=None,
):
    """
    Run inference with an ensemble of BirdSongClassifier checkpoints.

    model_weights: list, paths to model checkpoints (one per ensemble member)
    spectrogram: torch.Tensor, spectrogram tensor, shape (batch_size, n_frames, n_mels)
    device: str, device to run inference on (default: 'cpu')
    backbone_name: str, name of the backbone model (unused here; default: 'MIT/ast-finetuned-audioset-10-10-0.4593')
    n_classes: int, number of classes (default: 728)
    activation: str, activation function (unused here)
    n_mlp_layers: int, number of MLP layers (unused here)

    Returns:
    predictions: np.array, array of predictions, shape (n_models, batch_size, n_classes)
    """
    predict_collects = []
    for _weights in model_weights:
        # Restore each checkpoint into a fresh LightningModule
        # (alternative: model.load_state_dict(torch.load(_weights, map_location=device)['state_dict']))
        model = BirdSongClassifier.load_from_checkpoint(_weights, map_location=device, class_weights=None)

        if device != 'cpu':
            model.to(device)

        model.eval()
        with torch.no_grad():
            if device != 'cpu':
                spectrogram = spectrogram.to(device)
            output = model(spectrogram)
            logits = output['logits']
            probs = F.softmax(logits, dim=-1)
            predict_collects.append(probs)

    if device != 'cpu':
        predict_collects = [pred.cpu() for pred in predict_collects]

    # Stack per-model probabilities into shape (n_models, batch_size, n_classes)
    predict_collects = torch.stack(predict_collects, dim=0).numpy()

    return predict_collects


class BirdSongClassifier(pl.LightningModule):
    def __init__(self, class_weights):
        super().__init__()
        # class_weights is accepted for checkpoint compatibility but not used here
        config = AutoConfig.from_pretrained(MODEL_STR)
        config.num_labels = DEFAULT_N_CLASSES
        self.model = AutoModelForAudioClassification.from_config(config)

    def forward(self, x):
        return self.model(x)
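
# --- Usage sketch (illustrative, not part of the original pipeline) ---
# A minimal end-to-end example of preprocessing a waveform and running ensemble
# inference. The checkpoint path 'checkpoints/fold_0.ckpt' is a hypothetical
# placeholder; point it at real Lightning checkpoints saved from BirdSongClassifier.
if __name__ == "__main__":
    import numpy as np

    # 10 seconds of silence as a stand-in for a real recording normalized to [-1, 1]
    dummy_audio = np.zeros(RATE_HZ * MAX_SECONDS, dtype=np.float32)

    # Convert the raw waveform into the tensor expected by the classifier
    inputs = birdvec_preprocess(dummy_audio, sr=DEFAULT_SR)

    # Hypothetical checkpoint list; one entry per ensemble member
    checkpoint_paths = ["checkpoints/fold_0.ckpt"]

    probs = birdvec_inference(checkpoint_paths, inputs, device="cpu")
    print(probs.shape)  # (n_models, batch_size, n_classes)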