|
import timm |
|
import json |
|
import torch |
|
from torchaudio.functional import resample |
|
import numpy as np |
|
from torchaudio.compliance import kaldi |
|
import torch.nn.functional as F |
|
import requests |
|
|
|
|
|
# Hugging Face Hub identifier of the checkpoint loaded through timm.
TAG = "gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k"
# Loaded once at import time and switched to eval mode for inference.
MODEL = timm.create_model(f"hf_hub:{TAG}", pretrained=True).eval()

# JSON mapping whose values are the class names, listed in the order in which
# they appear in the file.
LABEL_URL = "https://huggingface.co/datasets/huggingface/label-files/raw/main/audioset-id2label.json"
# A timeout keeps a stalled connection from hanging the import indefinitely,
# and raise_for_status() surfaces HTTP errors directly instead of letting an
# error page be parsed as if it were the label file.
_label_response = requests.get(LABEL_URL, timeout=30)
_label_response.raise_for_status()
AUDIOSET_LABELS = list(json.loads(_label_response.content).values())

# Not referenced in the visible part of this file; presumably the sample rate
# (in Hz) that callers resample to -- TODO confirm.
SAMPLING_RATE = 16_000
# Offset and scale applied to the filterbank features in preprocess().
MEAN = -4.2677393
STD = 4.5689974
|
|
|
def preprocess(x: torch.Tensor):
    """Turn a 1-D waveform into a normalized 1024-frame filterbank matrix.

    The waveform is mean-centred, converted to a 128-bin Kaldi-style
    filterbank, zero-padded or cropped along the frame axis to exactly
    1024 rows, and finally normalized with the module-level MEAN and STD.

    Args:
        x: Waveform tensor; a leading dimension is added before it is
            passed to ``kaldi.fbank``.

    Returns:
        The normalized filterbank tensor with 1024 rows.
    """
    target_frames = 1024

    # Remove the DC offset before feature extraction.
    centered = x - x.mean()
    features = kaldi.fbank(
        centered.unsqueeze(0),
        htk_compat=True,
        window_type="hanning",
        num_mel_bins=128,
    )

    # Bring the frame axis to a fixed length: pad short clips with zeros at
    # the end, crop long clips to their first `target_frames` frames.
    missing = target_frames - features.shape[0]
    if missing > 0:
        features = F.pad(features, (0, 0, 0, missing))
    else:
        features = features[:target_frames]

    # Padding happens before normalization, so padded rows are normalized too.
    return (features - MEAN) / (STD * 2)
|
|
|
def predict_class(x: np.ndarray):
    """Classify a waveform and return the ten highest-scoring labels.

    Args:
        x: Audio samples as a NumPy array. A 1-D array is used as is; an
            array with more dimensions is averaged over its last axis
            (assumed to be the channel axis -- TODO confirm with callers).
            Floating-point input is cast to float32. Integer input is
            cast to float32 and divided by the maximum value of its dtype.
            The waveform is not resampled here.

    Returns:
        A list of ten ``[label, score]`` pairs, where ``score`` is the
        sigmoid of the class logit multiplied by 100.

    Raises:
        ValueError: If the waveform is not one-dimensional after the
            last-axis average.
    """
    waveform = torch.from_numpy(x)
    if waveform.is_floating_point():
        # torch.from_numpy keeps the NumPy dtype; float64 (NumPy's default)
        # would not match the model's float32 parameters.
        waveform = waveform.float()
    else:
        # mean() is undefined for integer tensors, so convert first.
        # NOTE(review): scaling integer PCM into [-1, 1] assumes the model
        # expects waveforms in that range -- confirm against the training setup.
        waveform = waveform.float() / torch.iinfo(waveform.dtype).max

    if waveform.ndim > 1:
        waveform = waveform.mean(-1)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if waveform.ndim != 1:
        raise ValueError(
            f"expected a 1-D waveform after channel averaging, got {waveform.ndim} dimensions"
        )

    features = preprocess(waveform)

    with torch.inference_mode():
        # Shape the features as a batch of one single-channel 1024x128 input.
        logits = MODEL(features.view(1, 1, 1024, 128)).squeeze(0)

    topk_probs, topk_classes = logits.sigmoid().topk(10)
    # Convert to plain Python numbers once rather than indexing with tensors.
    return [
        [AUDIOSET_LABELS[index], prob * 100]
        for index, prob in zip(topk_classes.tolist(), topk_probs.tolist())
    ]