update weights
- __pycache__/app.cpython-311.pyc +0 -0
- app.py +12 -2
- birdvec.py +95 -0
- fetch_img.py +0 -3

__pycache__/app.cpython-311.pyc
CHANGED
Binary files a/__pycache__/app.cpython-311.pyc and b/__pycache__/app.cpython-311.pyc differ
app.py
CHANGED
@@ -23,7 +23,7 @@ from fetch_img import download_images, scientific_to_species_code
 from audio_class_predictor import predict_class
 from bird_ast_model import birdast_preprocess, birdast_inference
 from bird_ast_seq_model import birdast_seq_preprocess, birdast_seq_inference
-
+from birdvec import birdvec_preprocess, birdvec_inference
 from utils import plot_wave, plot_mel, download_model, bandpass_filter
 
 # Define the default parameters
@@ -60,10 +60,20 @@ birdast_seq_assets = {
     "inference_fn": birdast_seq_inference,
 }
 
+birdvec_assets = {
+    "model_weights": [
+        f"https://huggingface.co/amroa/BirdVec/resolve/main/fold{i}/best-model{i}.ckpt" for i in range(3)
+    ],
+    "label_mapping": "https://huggingface.co/amroa/BirdVec/resolve/main/new_label_map.csv",
+    "preprocess_fn": birdvec_preprocess,
+    "inference_fn": birdvec_inference,
+}
+
 # maintain a dictionary of assets
 ASSET_DICT = {
     "BirdAST": birdast_assets,
     "BirdAST_Seq": birdast_seq_assets,
+    "BirdWav2Vec": birdvec_assets,
 }
 
 
@@ -251,7 +261,7 @@ with gr.Blocks(theme = seafoam, css = css, js = js) as demo:
     gr.Markdown(DESCRIPTION)
 
     # add dropdown for model selection
-    model_names = ['BirdAST', 'BirdAST_Seq'] #, 'EfficientNet']
+    model_names = ['BirdAST', 'BirdAST_Seq', 'BirdWav2Vec'] #, 'EfficientNet']
     model_dropdown = gr.Dropdown(label="Choose a model", choices=model_names)
     download_status = gr.Textbox(label="Model Status", lines=3, value='', interactive=False)  # Non-interactive textbox for status
     model_dropdown.change(handle_model_selection, inputs=[model_dropdown, download_status], outputs=download_status)
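The new birdvec_assets entry follows the same shape as the existing BirdAST asset dictionaries, so the app can treat every model uniformly: fetch the checkpoint URLs, run preprocess_fn on the audio, then pass the result to inference_fn. A minimal sketch of that flow, assuming download_model (imported from utils) returns a local path for each URL; the actual Gradio callback in app.py may differ in its details:

```python
import numpy as np

# Hedged sketch of how an ASSET_DICT entry could be consumed by the app.
assets = ASSET_DICT["BirdWav2Vec"]

# Assumption: download_model (from utils) fetches a URL and returns a local path.
local_ckpts = [download_model(url) for url in assets["model_weights"]]

# 10 s of placeholder audio at 16 kHz, already normalized to [-1, 1].
audio = np.random.uniform(-1.0, 1.0, 16_000 * 10).astype(np.float32)

features = assets["preprocess_fn"](audio)              # tensor of input_values
probs = assets["inference_fn"](local_ckpts, features)  # per-fold class probabilities
mean_probs = probs.mean(axis=0)                        # simple ensemble over the folds
```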
birdvec.py
ADDED
@@ -0,0 +1,95 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+import pytorch_lightning as pl
+from transformers import AutoConfig, AutoFeatureExtractor, AutoModelForAudioClassification
+
+DEFAULT_SR = 16_000
+DEFAULT_BACKBONE = "MIT/ast-finetuned-audioset-10-10-0.4593"
+DEFAULT_N_CLASSES = 728
+MODEL_STR = "dima806/bird_sounds_classification"  # "facebook/wav2vec2-base-960h"
+RATE_HZ = 16000
+# Maximum audio interval length to consider, in seconds
+MAX_SECONDS = 10
+# Maximum audio interval length in samples (rate * seconds)
+MAX_LENGTH = RATE_HZ * MAX_SECONDS
+
+# Create an instance of the feature extractor for audio.
+FEATURE_EXTRACTOR = AutoFeatureExtractor.from_pretrained(MODEL_STR)
+
+
+
+def birdvec_preprocess(audio_array, sr=DEFAULT_SR):
+    """
+    Preprocess an audio array for the BirdVec model.
+    audio_array: np.array, audio array of the recording, shape (n_samples,). Note: the audio array should be normalized to [-1, 1].
+    sr: int, sampling rate of the audio array (default: 16_000)
+
+    Note:
+    1. The audio array should be normalized to [-1, 1].
+    2. The audio length should be 10 seconds (or 10.24 seconds). Longer audio will be truncated.
+    """
+    # Extract features
+    features = FEATURE_EXTRACTOR(audio_array, sampling_rate=DEFAULT_SR, max_length=MAX_LENGTH, truncation=True, return_tensors="pt")
+
+    return features.input_values
+
+
+def birdvec_inference(
+    model_weights,
+    spectrogram,
+    device='cpu',
+    backbone_name=None,
+    n_classes=728,
+    activation=None,
+    n_mlp_layers=None
+):
+
+    """
+    Perform inference with the BirdVec model.
+    model_weights: list, list of model checkpoint paths
+    spectrogram: torch.Tensor, input feature tensor, shape (batch_size, n_frames, n_mels)
+    device: str, device to run inference on (default: 'cpu')
+    backbone_name: str, name of the backbone model (unused here; kept for interface compatibility)
+    n_classes: int, number of classes (default: 728)
+    activation: str, activation function (unused here; kept for interface compatibility)
+    n_mlp_layers: int, number of MLP layers (unused here; kept for interface compatibility)
+
+    Returns:
+    predictions: np.array, array of predictions, shape (n_models * batch_size, n_classes)
+    """
+
+
+
+    predict_collects = []
+    for _weights in model_weights:
+        # model.load_state_dict(torch.load(_weights, map_location=device)['state_dict'])
+        model = BirdSongClassifier.load_from_checkpoint(_weights, map_location=device, class_weights=None)
+        if device != 'cpu': model.to(device)
+        model.eval()
+
+        with torch.no_grad():
+            if device != 'cpu': spectrogram = spectrogram.to(device)
+
+            output = model(spectrogram)
+            logits = output['logits']
+            probs = F.softmax(logits, dim=-1)
+            predict_collects.append(probs)
+
+    if device != 'cpu':
+        predict_collects = [pred.cpu() for pred in predict_collects]
+
+    predict_collects = torch.cat(predict_collects, dim=0).numpy()
+
+    return predict_collects
+
+
+class BirdSongClassifier(pl.LightningModule):
+    def __init__(self, class_weights):
+        super().__init__()
+        config = AutoConfig.from_pretrained("dima806/bird_sounds_classification")
+        config.num_labels = 728
+        self.model = AutoModelForAudioClassification.from_config(config)
+
+    def forward(self, x):
+        return self.model(x)
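As a quick sanity check of the new module, birdvec_preprocess and birdvec_inference can be exercised end to end on a dummy waveform. A minimal sketch, assuming the three fold checkpoints have already been downloaded to local paths (the paths below are placeholders):

```python
import numpy as np
from birdvec import birdvec_preprocess, birdvec_inference

# Placeholder local paths; in the app these come from the
# huggingface.co/amroa/BirdVec fold URLs after download.
ckpts = [f"fold{i}/best-model{i}.ckpt" for i in range(3)]

# 10 seconds of silence at 16 kHz, already in [-1, 1] as the docstring requires.
audio = np.zeros(16_000 * 10, dtype=np.float32)

inputs = birdvec_preprocess(audio)                      # batched input_values tensor
probs = birdvec_inference(ckpts, inputs, device="cpu")  # softmax scores from each fold
print(probs.shape, probs.argmax(axis=-1))
```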
fetch_img.py
CHANGED
@@ -13,9 +13,6 @@ REQ_FMT = {
     "url": 'https://api.ebird.org/v2/ref/taxonomy/ebird',
     "params" : {
         'species': 'CHANGE THIS TO SPECIES CODE'
-    },
-    "headers" : {
-        'X-eBirdApiToken': 'id1a0e3q2lt3'
     }
 }
 bird_df = pd.read_csv("ebird_taxonomy_v2023.csv")
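This change removes the hard-coded X-eBirdApiToken header from REQ_FMT. If the eBird taxonomy endpoint still needs a key at call time, one option (not part of this commit) is to read it from an environment variable; a minimal sketch, assuming the requests library and a hypothetical EBIRD_API_TOKEN variable:

```python
import os
import requests

# Hypothetical helper (not in this commit): build the eBird headers from an
# environment variable instead of a token committed to source control.
def ebird_headers():
    token = os.environ.get("EBIRD_API_TOKEN", "")
    return {"X-eBirdApiToken": token} if token else {}

# Example request against the taxonomy endpoint configured in REQ_FMT.
resp = requests.get(
    "https://api.ebird.org/v2/ref/taxonomy/ebird",
    params={"species": "CHANGE THIS TO SPECIES CODE"},
    headers=ebird_headers(),
)
```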