# voj/birdvec.py
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
from transformers import AutoConfig, AutoFeatureExtractor, AutoModelForAudioClassification

DEFAULT_SR = 16_000
DEFAULT_BACKBONE = "MIT/ast-finetuned-audioset-10-10-0.4593"
DEFAULT_N_CLASSES = 728
MODEL_STR = "dima806/bird_sounds_classification"  # "facebook/wav2vec2-base-960h"
RATE_HZ = 16_000

# Maximum audio interval length to consider, in seconds.
MAX_SECONDS = 10
# Maximum audio interval length in samples (sampling rate * seconds).
MAX_LENGTH = RATE_HZ * MAX_SECONDS

# Feature extractor matching the backbone checkpoint.
FEATURE_EXTRACTOR = AutoFeatureExtractor.from_pretrained(MODEL_STR)

def birdvec_preprocess(audio_array, sr=DEFAULT_SR):
    """
    Preprocess an audio array for the BirdAST model.

    audio_array: np.array, audio array of the recording, shape (n_samples,)
    sr: int, sampling rate of the audio array (default: 16_000)

    Note:
    1. The audio array should be normalized to [-1, 1].
    2. The audio length should be 10 seconds (or 10.24 seconds). Longer audio will be truncated.
    """
    # Extract features; pass the caller's sampling rate so a mismatch with the
    # extractor's expected rate is caught instead of silently ignored.
    features = FEATURE_EXTRACTOR(audio_array, sampling_rate=sr, max_length=MAX_LENGTH, truncation=True, return_tensors="pt")
    return features.input_values
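
# Example (illustrative sketch, not part of the pipeline): preparing a clip for
# the model. The synthetic sine tone stands in for a real recording loaded with
# e.g. librosa or soundfile; see the __main__ block at the bottom for a fuller demo.
#
#   import numpy as np
#   t = np.arange(RATE_HZ * 5) / RATE_HZ
#   clip = np.sin(2 * np.pi * 440.0 * t).astype(np.float32)   # already in [-1, 1]
#   inputs = birdvec_preprocess(clip, sr=RATE_HZ)              # torch.Tensor for the model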

def birdvec_inference(
    model_weights,
    spectrogram,
    device='cpu',
    backbone_name=None,
    n_classes=728,
    activation=None,
    n_mlp_layers=None,
):
    """
    Perform ensemble inference with the BirdAST model.

    model_weights: list, paths to the model checkpoints to ensemble
    spectrogram: torch.Tensor, spectrogram tensor, shape (batch_size, n_frames, n_mels)
    device: str, device to run inference on (default: 'cpu')
    backbone_name: str, name of the backbone model (unused in this implementation)
    n_classes: int, number of classes (unused in this implementation; default: 728)
    activation: str, activation function (unused in this implementation)
    n_mlp_layers: int, number of MLP layers (unused in this implementation)

    Returns:
    predictions: np.array, per-model class probabilities concatenated along the
        first axis, shape (n_models * batch_size, n_classes)
    """
    predict_collects = []
    for _weights in model_weights:
        # model.load_state_dict(torch.load(_weights, map_location=device)['state_dict'])
        model = BirdSongClassifier.load_from_checkpoint(_weights, map_location=device, class_weights=None)
        if device != 'cpu':
            model.to(device)
        model.eval()

        with torch.no_grad():
            if device != 'cpu':
                spectrogram = spectrogram.to(device)
            output = model(spectrogram)
            logits = output['logits']
            probs = F.softmax(logits, dim=-1)
            predict_collects.append(probs)

    if device != 'cpu':
        predict_collects = [pred.cpu() for pred in predict_collects]

    # Concatenate the per-model probability tensors along the batch axis.
    predict_collects = torch.cat(predict_collects, dim=0).numpy()
    return predict_collects
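
# Note (assumption about downstream use, not part of this file): the ensemble
# output is typically reduced by averaging the per-model probabilities, e.g.
#
#   probs = birdvec_inference(ckpt_paths, spec)       # (n_models * batch, n_classes)
#   avg = probs.reshape(len(ckpt_paths), -1, DEFAULT_N_CLASSES).mean(axis=0)
#   top_class = avg.argmax(axis=-1)                   # predicted label per clip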

class BirdSongClassifier(pl.LightningModule):
    """Thin Lightning wrapper around a Transformers audio-classification model."""

    def __init__(self, class_weights):
        # class_weights is accepted so training checkpoints load cleanly, but it
        # is not used at inference time.
        super().__init__()
        config = AutoConfig.from_pretrained(MODEL_STR)
        config.num_labels = DEFAULT_N_CLASSES
        self.model = AutoModelForAudioClassification.from_config(config)

    def forward(self, x):
        return self.model(x)
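

if __name__ == "__main__":
    # End-to-end sketch for local testing only; the checkpoint paths below are
    # hypothetical placeholders, not files shipped with this repository.
    import numpy as np

    # Five seconds of a synthetic 440 Hz tone, normalized to [-1, 1], standing
    # in for a real field recording.
    t = np.arange(RATE_HZ * 5) / RATE_HZ
    dummy_clip = np.sin(2 * np.pi * 440.0 * t).astype(np.float32)

    spectrogram = birdvec_preprocess(dummy_clip, sr=RATE_HZ)

    checkpoint_paths = ["checkpoints/fold_0.ckpt", "checkpoints/fold_1.ckpt"]  # hypothetical
    predictions = birdvec_inference(checkpoint_paths, spectrogram, device="cpu")
    print(predictions.shape)  # (n_models * batch_size, DEFAULT_N_CLASSES)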