import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
from transformers import AutoConfig, AutoFeatureExtractor, AutoModelForAudioClassification

DEFAULT_SR = 16_000
DEFAULT_BACKBONE = "MIT/ast-finetuned-audioset-10-10-0.4593"
DEFAULT_N_CLASSES = 728
MODEL_STR = "dima806/bird_sounds_classification"  # "facebook/wav2vec2-base-960h"
# Sampling rate in Hz used to size MAX_LENGTH (same value as DEFAULT_SR)
RATE_HZ = DEFAULT_SR
# Define the maximum audio interval length to consider in seconds
MAX_SECONDS = 10
# Calculate the maximum audio interval length in samples by multiplying the rate and seconds
MAX_LENGTH = RATE_HZ * MAX_SECONDS

# Create an instance of the feature extractor for audio.
FEATURE_EXTRACTOR = AutoFeatureExtractor.from_pretrained(MODEL_STR)



def birdvec_preprocess(audio_array, sr=DEFAULT_SR):
    """
    Preprocess an audio array for the bird sound classification model.

    audio_array: np.array, audio samples of the recording, shape (n_samples,)
    sr: int, sampling rate of the audio array in Hz (default: 16_000)

    Note:
    1. The audio array should be normalized to [-1, 1].
    2. Audio longer than MAX_SECONDS (10 seconds) is truncated to MAX_LENGTH samples.

    Returns:
    torch.Tensor, the feature extractor's input_values for the recording
    """
    # Pass the caller's sampling rate through instead of always assuming DEFAULT_SR
    features = FEATURE_EXTRACTOR(
        audio_array,
        sampling_rate=sr,
        max_length=MAX_LENGTH,
        truncation=True,
        return_tensors="pt",
    )

    return features.input_values
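

# --- Illustrative sketch (not part of the original module) --------------------
# birdvec_preprocess expects audio normalized to [-1, 1] but does not say how to
# get there. Peak normalization, shown below, is one common option. The helper
# name `peak_normalize` and the `eps` threshold are assumptions made for this
# example; other schemes (e.g. dividing int16 samples by 32768) may suit the
# recordings better.
import numpy as np


def peak_normalize(audio_array, eps=1e-9):
    """Scale a waveform so that its largest absolute sample is 1.0."""
    audio = np.asarray(audio_array, dtype=np.float32)
    peak = float(np.max(np.abs(audio))) if audio.size else 0.0
    if peak < eps:
        # Silent or empty input: return it unchanged to avoid dividing by zero
        return audio
    return audio / peak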


def birdvec_inference(
    model_weights,
    spectrogram,
    device='cpu',
    backbone_name=None,
    n_classes=728,
    activation=None,
    n_mlp_layers=None,
):
    """
    Run inference with one or more BirdSongClassifier checkpoints.

    model_weights: list, paths to the model checkpoints
    spectrogram: torch.Tensor, batched model input as returned by birdvec_preprocess
    device: str, device to run inference on (default: 'cpu')
    backbone_name, n_classes, activation, n_mlp_layers: accepted but currently unused;
        the architecture is fixed by BirdSongClassifier

    Returns:
    predictions: np.array, softmax probabilities concatenated over checkpoints,
        shape (n_models * batch_size, n_classes)
    """
    # Move the input once instead of on every loop iteration
    spectrogram = spectrogram.to(device)

    predict_collects = []
    for _weights in model_weights:
        # class_weights is required by the constructor but is not needed for inference
        model = BirdSongClassifier.load_from_checkpoint(_weights, map_location=device, class_weights=None)
        model.to(device)
        model.eval()

        with torch.no_grad():
            output = model(spectrogram)
            logits = output['logits']
            probs = F.softmax(logits, dim=-1)
            # Collect on the CPU so the results can be converted to NumPy
            predict_collects.append(probs.cpu())

    # Concatenate along the first axis: rows are grouped checkpoint by checkpoint
    predict_collects = torch.cat(predict_collects, dim=0).numpy()

    return predict_collects


    
    
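class BirdSongClassifier(pl.LightningModule):
    """Lightning wrapper around the audio classification model, built from its config."""

    def __init__(self, class_weights):
        super().__init__()
        # class_weights is accepted because callers pass it (see birdvec_inference);
        # it is not used in this class.
        config = AutoConfig.from_pretrained(MODEL_STR)
        config.num_labels = DEFAULT_N_CLASSES
        # from_config builds the architecture without loading pretrained weights
        self.model = AutoModelForAudioClassification.from_config(config)

    def forward(self, x):
        return self.model(x)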
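

# --- Illustrative sketch (not part of the original module) --------------------
# birdvec_inference concatenates the per-checkpoint probabilities along the
# first axis, so n_models checkpoints applied to a batch of batch_size inputs
# give an array of shape (n_models * batch_size, n_classes). The hypothetical
# helper below shows one way a caller might undo that concatenation and average
# the checkpoints. The name `average_ensemble` and the use of a plain mean are
# assumptions made for this example, not the authors' ensembling method.
import numpy as np


def average_ensemble(stacked_probs, n_models):
    """Average the per-checkpoint probabilities returned by birdvec_inference."""
    stacked_probs = np.asarray(stacked_probs)
    if n_models <= 0 or stacked_probs.shape[0] % n_models != 0:
        raise ValueError("first axis must be a positive multiple of n_models")
    # Rows are grouped checkpoint by checkpoint, so the model axis comes first
    per_model = stacked_probs.reshape(n_models, -1, stacked_probs.shape[-1])
    return per_model.mean(axis=0)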