import warnings
warnings.filterwarnings("ignore")

import os
import numpy as np
import pandas as pd
from typing import Iterable

import gradio as gr
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes

import torch
import librosa
import torch.nn.functional as F

# Import the necessary functions from the voj package
from audio_class_predictor import predict_class
from bird_ast_model import birdast_preprocess, birdast_inference
from bird_ast_seq_model import birdast_seq_preprocess, birdast_seq_inference

from utils import plot_wave, plot_mel, download_model, bandpass_filter

# Define the default parameters
ASSET_DIR = "./assets"
DEFAULT_SR = 16_000
DEFAULT_HIGH_CUT = 8_000
DEFAULT_LOW_CUT = 1_000
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Use Device: {DEVICE}")

if not os.path.exists(ASSET_DIR):
    os.makedirs(ASSET_DIR)


# define the assets for the models
birdast_assets = {
    "model_weights": [
        f"https://huggingface.co/shiyi-li/BirdAST/resolve/main/BirdAST_Baseline_GroupKFold_fold_{i}.pth"
        for i in range(5)
    ],
    "label_mapping": "https://huggingface.co/shiyi-li/BirdAST/resolve/main/BirdAST_Baseline_GroupKFold_label_map.csv",
    "preprocess_fn": birdast_preprocess,
    "inference_fn": birdast_inference,
}

birdast_seq_assets = {
    "model_weights": [
        f"https://huggingface.co/shiyi-li/BirdAST_Seq/resolve/main/BirdAST_SeqPool_GroupKFold_fold_{i}.pth"
        for i in range(5)
    ],
    "label_mapping": "https://huggingface.co/shiyi-li/BirdAST_Seq/resolve/main/BirdAST_SeqPool_GroupKFold_label_map.csv",
    "preprocess_fn": birdast_seq_preprocess,
    "inference_fn": birdast_seq_inference,
}

# maintain a dictionary of assets
ASSET_DICT = {
    "BirdAST": birdast_assets,
    "BirdAST_Seq": birdast_seq_assets,
}
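
# To add a model, register its assets here with the same four keys
# (model_weights, label_mapping, preprocess_fn, inference_fn) and add its
# name to `model_names` in the UI section below.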


def run_inference_with_model(audio_clip, sr, model_name):
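    """Run the selected model's 5-fold ensemble on a mono audio clip.

    Downloads any missing fold weights and the label map into ASSET_DIR,
    averages the per-fold class probabilities, and returns the top-10
    predictions as [scientific_name, probability_in_percent] rows.
    """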
    
    # download the model assets
    assets = ASSET_DICT[model_name]
    model_weights_url = assets["model_weights"]
    label_map_url = assets["label_mapping"]
    preprocess_fn = assets["preprocess_fn"]
    inference_fn = assets["inference_fn"]
    
    # download the model weights
    model_weights = []
    for model_weight in model_weights_url:
        weight_file = os.path.join(ASSET_DIR, model_weight.split("/")[-1])
        if not os.path.exists(weight_file):
            download_model(model_weight, weight_file)
        model_weights.append(weight_file)
    
    # download the label mapping
    label_map_csv = os.path.join(ASSET_DIR, label_map_url.split("/")[-1])
    if not os.path.exists(label_map_csv):
        download_model(label_map_url, label_map_csv)
    
    # load the label mapping
    label_mapping = pd.read_csv(label_map_csv)
    species_id_to_name = {row["species_id"]: row["scientific_name"] for _, row in label_mapping.iterrows()}
    
    # preprocess the audio clip
    spectrogram = preprocess_fn(audio_clip, sr=sr)
    
    # run inference
    predictions = inference_fn(model_weights, spectrogram, device=DEVICE)

    # aggregate the results: average class probabilities across the fold models
    final_predicts = predictions.mean(axis=0)
    topk_values, topk_indices = torch.topk(torch.from_numpy(final_predicts), 10)
    
    results = []
    for idx, scores in zip(topk_indices, topk_values):
        species_name = species_id_to_name[idx.item()]
        probability = scores.item() * 100
        results.append([species_name, probability])

    return results
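
# Usage sketch (hypothetical file path; assumes a mono clip at DEFAULT_SR):
#   clip, _ = librosa.load("sample.wav", sr=DEFAULT_SR, mono=True)
#   top10 = run_inference_with_model(clip, DEFAULT_SR, "BirdAST")
#   # top10 -> list of [scientific_name, probability_in_percent] rows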


def predict(audio, start, end, model_name="BirdAST_Seq"):
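    """Gradio callback: classify the [start, end] window of the uploaded clip.

    `audio` is the (sample_rate, waveform) tuple produced by gr.Audio.
    Returns the AudioMAE audio-type scores, the species predictions of the
    selected model, and the waveform/spectrogram figures.
    """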
    
    raw_sr, audio_array = audio
    
    if audio_array.ndim > 1:
        audio_array = audio_array.mean(axis=1) # convert to mono
    
    print(f"Audio shape raw: {audio_array.shape}, sr: {raw_sr}")
    
    # sanity checks
    len_audio = audio_array.shape[0] / raw_sr
    if start >= end:
        raise gr.Error(f"`start` ({start}s) must be smaller than `end` ({end}s)")
    
    if audio_array.shape[0] < start * raw_sr:
        raise gr.Error(f"`start` ({start}s) must be smaller than the audio duration ({len_audio:.0f}s)")
    
    if audio_array.shape[0] < end * raw_sr:
        end = len_audio  # clamp `end` to the actual clip duration
    
    audio_array = np.array(audio_array, dtype=np.float32) / 32768.0  # 16-bit PCM -> [-1, 1]
    audio_array = audio_array[int(start*raw_sr) : int(end*raw_sr)]
    
    if raw_sr != DEFAULT_SR:
        # run bandpass filter & resample to the model's expected rate
        audio_array = bandpass_filter(audio_array, DEFAULT_LOW_CUT, DEFAULT_HIGH_CUT, raw_sr)
        audio_array = librosa.resample(audio_array, orig_sr=raw_sr, target_sr=DEFAULT_SR)
        print(f"Resampled Audio shape: {audio_array.shape}")
        
        audio_array = audio_array.astype(np.float32)

    # predict audio class 
    audio_class = predict_class(audio_array)
    
    fig_spectrogram = plot_mel(DEFAULT_SR, audio_array)
    fig_waveform = plot_wave(DEFAULT_SR, audio_array)
    
    # run inference with model
    print(f"Running inference with model: {model_name}")
    species_class = run_inference_with_model(audio_array, DEFAULT_SR, model_name)

    return audio_class, species_class, fig_waveform, fig_spectrogram


DESCRIPTION = """

<div align="center">
<b>Team Members: </b>

Amro Abdrabo [amro.abdrabo@gmail.com | [LinkedIn](https://www.linkedin.com/in/amroabdrabo/)]

Shiyi Li [shiyi.li@ifu.baug.ethz.ch | [LinkedIn](https://www.linkedin.com/in/shiyili01)]

Thomas Radinger [thomasrad@protonmail.com | [LinkedIn](https://www.linkedin.com/in/thomas-radinger-743958142/)]
</div>

# Introduction 

Birds are key indicators of ecosystem health and play pivotal roles in maintaining biodiversity [1]. To monitor and protect bird species, automatic bird sound recognition systems are essential. These systems can help in identifying bird species, monitoring their populations, and understanding their behavior. However, building such systems is challenging due to the diversity of bird sounds, complex acoustic interference and limited labeled data. 

To tackle these challenges, we explored the potential of deep learning models for bird sound recognition. In our work, we developed two Audio Spectrogram Transformer (AST) based models, BirdAST and BirdAST_Seq, to predict bird species from audio recordings. We evaluated the models on a dataset of 728 bird species and achieved promising results. Details of the models and evaluation results are provided in the table below. As field recordings may contain various types of audio rather than only bird songs/calls, we also employed an Audio Masked AutoEncoder (AudioMAE) model to pre-classify audio clips into bird, insects, rain, environmental noise, and other types [2]. For a full report on our workflow and results, please refer to [link](https://docs.google.com/document/d/17uRGEVz4hxShK4fvWQzIKFJlVwEg9p1rAT9XXDYGE3w/edit?usp=sharing).

Our results demonstrate the potential of deep learning models for bird sound recognition. We hope this work contributes to the development of automatic bird sound recognition systems and helps in monitoring and protecting bird species.

<div align="center">
<b>Model Details</b>

| Model name       | Architecture                   | ROC-AUC Score |
| ---------------  |:------------------------------:|:-------------:|
| BirdAST          | AST* + MLP                     | 0.6825        |
| BirdAST_Seq      | AST* + Sequence Pooling + MLP  | 0.7335        |

</div>

# How to use the space:

1. Choose a model from the dropdown list. It will download the model weights automatically if not already downloaded (~30 seconds).
2. Upload an audio clip and specify the start and end time for prediction. 
3. Click on the "Predict" button to get the predictions.
4. In the output, you will get the audio type classification (e.g., bird, insects, rain, etc.) in the panel "Class Prediction" and the predicted bird species in the panel "Species Prediction".
    * The audio types are predicted as a multi-label classification based on the AudioMAE model. The predicted classes indicate the possible presence of different types of audio in the recording.
    * The bird species are predicted as a multi-class classification using the selected model. The predicted classes indicate the most probable bird species present in the recording.
5. The waveform and spectrogram of the audio clip are displayed in the respective panels.

**Notes:**
- For an unknown bird species, the model may predict the most similar bird species based on the training data.
- If an audio clip contains non-bird sounds (predicted by the AudioMAE), the bird species prediction may not be accurate.

**Disclaimer**: The model predictions are based on the training data and may not be accurate for all audio clips. The model is trained on a dataset of 728 bird species and may not generalize well to all bird species.

<div align="center">
  <b>Enjoy the Bird Songs! 🐦🎶</b>
</div>
"""


css = """
#gradio-animation {
    font-size: 2em;
    font-weight: bold;
    text-align: center;
    margin-bottom: 20px;
}

.logo-container img {
    width: 14%;  /* Adjust width as necessary */
    display: block;
    margin: auto;
}

.number-input {
    height: 100%;
    padding-bottom: 60px; /* Adjust the value as needed for more or less space */
}
.full-height {
    height: 100%;
}
.column-container {
    height: 100%; 
} 
"""



class Seafoam(Base):
    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.emerald,
        secondary_hue: colors.Color | str = colors.blue,
        neutral_hue: colors.Color | str = colors.gray,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font
        | str
        | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Quicksand"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono: fonts.Font
        | str
        | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )


seafoam = Seafoam()


js = """
function createGradioAnimation() {
    var container = document.getElementById('gradio-animation');
    var text = 'Voice of Jungle';
    for (var i = 0; i < text.length; i++) {
        (function(i){
            setTimeout(function(){
                var letter = document.createElement('span');
                letter.style.opacity = '0';
                letter.style.transition = 'opacity 0.5s';
                letter.innerText = text[i];
                container.appendChild(letter);
                setTimeout(function() {
                    letter.style.opacity = '1';
                }, 50);
            }, i * 250);
        })(i);
    }
}
"""

REFERENCES = """
# Appendix

We have applied the AudioMAE model to pre-classify the 23,000+ unlabelled audio clips collected from the Greater Manaus region in the Amazon rainforest. The results of the audio type classification can be found in the following [link](https://drive.google.com/file/d/1uOT88LDnBD-Z3YcFz1e9XjvW2ugCo6EI/view?usp=drive_link). We hope that these pre-classification results help researchers better explore the vast collection of audio recordings and facilitate the study of biodiversity in the Amazon rainforest.

# References

[1] Torkington, S. (2023, February 7). 50% of the global economy is under threat from biodiversity loss. World Economic Forum. Retrieved from https://www.weforum.org/agenda/2023/02/biodiversity-nature-loss-cop15/. 

[2] Huang, P.-Y., Xu, H., Li, J., Baevski, A., Auli, M., Galuba, W., Metze, F., & Feichtenhofer, C. (2022). Masked Autoencoders that Listen. In NeurIPS.

[3] https://www.kaggle.com/code/dima806/bird-species-by-sound-detection

# Acknowledgements

We would like to thank all organizers, mentors and participants of the AI+Environment EcoHackathon 2024 event for their unwavering support and collaboration. We extend our gratitude to ETH BiodivX, GainForest and ETH AI Center for providing data, facilities and resources that enabled us to analyse the rich data in different ways. Our special thanks to David Dao, Sarah Tariq, Alessandro Amodio for always being there to help us! 🙏🙏🙏
"""

# Function to handle model selection
def handle_model_selection(model_name, download_status):
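    """Download any missing weights for `model_name` and return a status string."""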
    # Inform user that download is starting
    # gr.Info(f"Downloading model weights for {model_name}...")
    print(f"Downloading model weights for {model_name}...")
    
    if model_name is None:
        model_name = "BirdAST"
        
    assets = ASSET_DICT[model_name]
    model_weights_url = assets["model_weights"]
    download_flag = True
    try:
        total_files = len(model_weights_url)
        for idx, model_weight in enumerate(model_weights_url):
            weight_file = os.path.join(ASSET_DIR, model_weight.split("/")[-1])
            print(weight_file)
            if not os.path.exists(weight_file):
                # note: this value is only returned at the end; Gradio will not
                # show intermediate statuses unless this function yields them
                download_status = f"Downloading {idx + 1} of {total_files}"
                download_model(model_weight, weight_file)
            
            if not os.path.exists(weight_file):
                download_flag = False
                break
            
        if download_flag:
            download_status =  f"Model <{model_name}> is ready! πŸŽ‰πŸŽ‰πŸŽ‰\nUsing Device: {DEVICE.upper()}"
        else:
            download_status = "An error occurred while downloading model weights."
            
    except Exception as e:
        download_status = f"An error occurred while downloading model weights: {e}"
        
    return download_status


with gr.Blocks(theme = seafoam, css = css, js = js) as demo:
    
    gr.Markdown('<div class="logo-container"><img src="https://i.ibb.co/vcG9kr0/vojlogo.jpg" width="50px" alt="vojlogo"></div>')
    gr.Markdown('<div id="gradio-animation"></div>')
    gr.Markdown(DESCRIPTION)
    
    # add dropdown for model selection
    model_names = ['BirdAST', 'BirdAST_Seq'] #, 'EfficientNet']
    model_dropdown = gr.Dropdown(label="Choose a model", choices=model_names)
    download_status = gr.Textbox(label="Model Status", lines=3, value='', interactive=False) # Non-interactive textbox for status

    model_dropdown.change(handle_model_selection, inputs=[model_dropdown, download_status], outputs=download_status)

    
    with gr.Row():
        with gr.Column(elem_classes="column-container"):
            start_time_input = gr.Number(label="Start Time", value=0, elem_classes="number-input full-height")
            end_time_input = gr.Number(label="End Time", value=10, elem_classes="number-input full-height")
        with gr.Column():
            audio_input = gr.Audio(label="Input Audio", elem_classes="full-height")
  
    with gr.Row():
        raw_class_output = gr.Dataframe(headers=["Class", "Score [%]"], row_count=10, label="Class Prediction")
        species_output = gr.Dataframe(headers=["Species", "Score [%]"], row_count=10, label="Species Prediction")
        
    with gr.Row():
        waveform_output = gr.Plot(label="Waveform")
        spectrogram_output = gr.Plot(label="Spectrogram")
    
    gr.Examples(
        examples=[
            ["XC226833-Chestnut-belted_20Chat-Tyrant_20A_2010989.mp3", 0, 10],
            ["XC812290-Many-striped-Canastero_Teaben_Pe_1jul2022_FSchmitt_1.mp3", 0, 10],
            ["XC763511-Synallaxis-maronica_Bagua-grande_MixPre-1746.mp3", 0, 10]
        ],
        inputs=[audio_input, start_time_input, end_time_input]
    )
    
    gr.Button("Predict").click(predict, [audio_input, start_time_input, end_time_input, model_dropdown], [raw_class_output, species_output, waveform_output, spectrogram_output])

    gr.Markdown(REFERENCES)

demo.launch(share = True)
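# Note: share=True requests a temporary public share link; this is not needed
# when the app is hosted on Hugging Face Spaces.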

## logo: <img src="https://i.ibb.co/vcG9kr0/vojlogo.jpg" alt="vojlogo" border="0">
## cactus: <img src="https://i.ibb.co/3sW2mJN/spur.jpg" alt="spur" border="0">