amroa committed
Commit d73fb39
1 Parent(s): dcf7d14

update weights

Files changed (4)
  1. __pycache__/app.cpython-311.pyc +0 -0
  2. app.py +12 -2
  3. birdvec.py +95 -0
  4. fetch_img.py +0 -3
__pycache__/app.cpython-311.pyc CHANGED
Binary files a/__pycache__/app.cpython-311.pyc and b/__pycache__/app.cpython-311.pyc differ
 
app.py CHANGED
@@ -23,7 +23,7 @@ from fetch_img import download_images, scientific_to_species_code
from audio_class_predictor import predict_class
from bird_ast_model import birdast_preprocess, birdast_inference
from bird_ast_seq_model import birdast_seq_preprocess, birdast_seq_inference
-
+ from birdvec import birdvec_preprocess, birdvec_inference
from utils import plot_wave, plot_mel, download_model, bandpass_filter

# Define the default parameters
@@ -60,10 +60,20 @@ birdast_seq_assets = {
    "inference_fn": birdast_seq_inference,
}

+ birdvec_assets = {
+     "model_weights": [
+         f"https://huggingface.co/amroa/BirdVec/resolve/main/fold{i}/best-model{i}.ckpt" for i in range(3)
+     ],
+     "label_mapping": "https://huggingface.co/amroa/BirdVec/resolve/main/new_label_map.csv",
+     "preprocess_fn": birdvec_preprocess,
+     "inference_fn": birdvec_inference,
+ }
+
# maintain a dictionary of assets
ASSET_DICT = {
    "BirdAST": birdast_assets,
    "BirdAST_Seq": birdast_seq_assets,
+     "BirdWav2Vec": birdvec_assets,
}

@@ -251,7 +261,7 @@ with gr.Blocks(theme = seafoam, css = css, js = js) as demo:
    gr.Markdown(DESCRIPTION)

    # add dropdown for model selection
- model_names = ['BirdAST', 'BirdAST_Seq'] #, 'EfficientNet']
+ model_names = ['BirdAST', 'BirdAST_Seq', 'BirdWav2Vec'] #, 'EfficientNet']
    model_dropdown = gr.Dropdown(label="Choose a model", choices=model_names)
    download_status = gr.Textbox(label="Model Status", lines=3, value='', interactive=False)  # Non-interactive textbox for status
    model_dropdown.change(handle_model_selection, inputs=[model_dropdown, download_status], outputs=download_status)
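
Note: handle_model_selection and download_model are defined elsewhere in app.py and utils.py and are not shown in this diff. As a rough sketch of how an ASSET_DICT entry is consumed (the driver function and the download_model(url, local_path) signature are assumptions, not code from this repo):

import os

def run_model(model_name, audio_array, sr=16_000):
    # Look up the preprocess/inference pair registered for this model
    assets = ASSET_DICT[model_name]

    # Fetch each fold checkpoint unless a local copy already exists;
    # download_model(url, local_path) is an assumed signature
    local_weights = []
    for url in assets["model_weights"]:
        local_path = os.path.basename(url)
        if not os.path.exists(local_path):
            download_model(url, local_path)
        local_weights.append(local_path)

    # Every registered model exposes the same two-step interface
    features = assets["preprocess_fn"](audio_array, sr=sr)
    return assets["inference_fn"](local_weights, features)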
birdvec.py ADDED
@@ -0,0 +1,95 @@
+ import torch
+ import torch.nn.functional as F
+ import pytorch_lightning as pl
+ from transformers import AutoConfig, AutoFeatureExtractor, AutoModelForAudioClassification
+
+ DEFAULT_SR = 16_000
+ DEFAULT_BACKBONE = "MIT/ast-finetuned-audioset-10-10-0.4593"
+ DEFAULT_N_CLASSES = 728
+ MODEL_STR = "dima806/bird_sounds_classification"  # alternative backbone: "facebook/wav2vec2-base-960h"
+ RATE_HZ = 16_000
+ # Maximum audio interval length to consider, in seconds
+ MAX_SECONDS = 10
+ # Maximum audio interval length in samples (rate * seconds)
+ MAX_LENGTH = RATE_HZ * MAX_SECONDS
+
+ # Feature extractor that turns raw audio into model input values
+ FEATURE_EXTRACTOR = AutoFeatureExtractor.from_pretrained(MODEL_STR)
+
+
+ def birdvec_preprocess(audio_array, sr=DEFAULT_SR):
+     """
+     Preprocess an audio array for the BirdWav2Vec model.
+
+     audio_array: np.array, audio array of the recording, shape (n_samples,)
+     sr: int, sampling rate of the audio array (default: 16_000)
+
+     Note:
+     1. The audio array should be normalized to [-1, 1].
+     2. Audio longer than 10 seconds is truncated.
+     """
+     # Extract features (pass the caller's sampling rate through, not the default)
+     features = FEATURE_EXTRACTOR(audio_array, sampling_rate=sr, max_length=MAX_LENGTH, truncation=True, return_tensors="pt")
+     return features.input_values
+
+
+ def birdvec_inference(
+     model_weights,
+     input_values,
+     device='cpu',
+     backbone_name=None,
+     n_classes=DEFAULT_N_CLASSES,
+     activation=None,
+     n_mlp_layers=None,
+ ):
+     """
+     Run ensemble inference with the BirdWav2Vec model.
+
+     model_weights: list, paths of the per-fold checkpoint files
+     input_values: torch.Tensor, output of birdvec_preprocess, shape (batch_size, n_samples)
+     device: str, device to run inference on (default: 'cpu')
+     n_classes: int, number of classes (default: 728)
+     backbone_name, activation, n_mlp_layers: unused here; kept so the
+         signature matches the other models' inference functions
+
+     Returns:
+     predictions: np.array, array of predictions, shape (n_models, batch_size, n_classes)
+     """
+     predict_collects = []
+     for _weights in model_weights:
+         model = BirdSongClassifier.load_from_checkpoint(_weights, map_location=device, class_weights=None)
+         if device != 'cpu':
+             model.to(device)
+         model.eval()
+
+         with torch.no_grad():
+             if device != 'cpu':
+                 input_values = input_values.to(device)
+             output = model(input_values)
+             logits = output['logits']
+             probs = F.softmax(logits, dim=-1)
+             predict_collects.append(probs)
+
+     if device != 'cpu':
+         predict_collects = [pred.cpu() for pred in predict_collects]
+
+     # stack (not cat) so the result has the documented (n_models, batch_size, n_classes) shape
+     predict_collects = torch.stack(predict_collects, dim=0).numpy()
+
+     return predict_collects
+
+
+ class BirdSongClassifier(pl.LightningModule):
+     def __init__(self, class_weights):
+         super().__init__()
+         config = AutoConfig.from_pretrained(MODEL_STR)
+         config.num_labels = DEFAULT_N_CLASSES
+         self.model = AutoModelForAudioClassification.from_config(config)
+
+     def forward(self, x):
+         return self.model(x)
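
For reference, a minimal end-to-end sketch of the new module (the file name, local checkpoint paths, and librosa-based loading are illustrative assumptions, not part of this commit):

import librosa
from birdvec import birdvec_preprocess, birdvec_inference

# Load a recording at the 16 kHz rate the feature extractor expects;
# librosa returns a float waveform normalized to [-1, 1]
audio, sr = librosa.load("recording.wav", sr=16_000)

# Shape (1, n_samples), truncated to at most 10 seconds
input_values = birdvec_preprocess(audio, sr=sr)

# Hypothetical local copies of the three fold checkpoints
weights = [f"fold{i}/best-model{i}.ckpt" for i in range(3)]

# Class probabilities with shape (n_models, batch_size, n_classes) = (3, 1, 728)
probs = birdvec_inference(weights, input_values)

# Average across folds and take the most likely class index
print(probs.mean(axis=0).argmax(axis=-1))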
fetch_img.py CHANGED
@@ -13,9 +13,6 @@ REQ_FMT = {
    "url": 'https://api.ebird.org/v2/ref/taxonomy/ebird',
    "params" : {
        'species': 'CHANGE THIS TO SPECIES CODE'
-     },
-     "headers" : {
-         'X-eBirdApiToken': 'id1a0e3q2lt3'
    }
}
bird_df = pd.read_csv("ebird_taxonomy_v2023.csv")
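
This hunk removes a hard-coded eBird API token from REQ_FMT (the removed value remains recoverable from the repository history, so it should be treated as exposed). If authenticated requests are needed again, one option is to read the token from the environment; a minimal sketch, where EBIRD_API_TOKEN is an assumed variable name:

import os

# Read the eBird API token from the environment instead of hard-coding it
token = os.environ.get("EBIRD_API_TOKEN")
headers = {'X-eBirdApiToken': token} if token else {}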