|
import warnings |
|
# Suppress every Python warning process-wide (any category, any module). It is
# installed before the third-party imports below, so import-time warnings are
# silenced as well.
# NOTE(review): this also hides deprecation/runtime warnings from gradio, torch
# and librosa — consider narrowing to specific categories or modules.
warnings.filterwarnings("ignore")
|
|
|
import os |
|
import numpy as np |
|
import pandas as pd |
|
from typing import Iterable |
|
from styling import js, seafoam, css, DESCRIPTION |
|
|
|
import gradio as gr |
|
from gradio.themes.base import Base |
|
from gradio.themes.utils import colors, fonts, sizes |
|
import requests |
|
import torch |
|
import shutil |
|
import librosa |
|
import torch.nn.functional as F |
|
|
|
|
|
from fetch_img import download_images, scientific_to_species_code |
|
|
|
|
|
from audio_class_predictor import predict_class |
|
from bird_ast_model import birdast_preprocess, birdast_inference |
|
from bird_ast_seq_model import birdast_seq_preprocess, birdast_seq_inference |
|
from birdvec import birdvec_preprocess, birdvec_inference |
|
from utils import plot_wave, plot_mel, download_model, bandpass_filter |
|
|
|
|
|
# Local cache directory for downloaded model weights and label-map CSVs.
ASSET_DIR = "./assets"

# Audio pipeline defaults. NOTE: the "DEFUALT" misspelling is intentional to keep
# existing references to these names working.
DEFUALT_SR = 16_000  # sample rate audio is resampled to before inference
DEFUALT_HIGH_CUT = 8_000  # upper band-pass cutoff (presumably Hz)
DEFUALT_LOW_CUT = 1_000  # lower band-pass cutoff (presumably Hz)

# Run on the GPU when one is visible to torch, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Use Device: {DEVICE}")

# exist_ok=True replaces the racy "check, then create" pattern: it does not fail
# if the directory already exists (or appears between a check and the create).
os.makedirs(ASSET_DIR, exist_ok=True)
|
|
|
|
|
|
|
# Model registry. Each entry lists the per-fold checkpoint URLs (the inference
# code averages their predictions), the CSV mapping species_id -> scientific_name,
# and the model-specific preprocessing and inference callables.
birdast_assets = dict(
    model_weights=[
        "https://huggingface.co/shiyi-li/BirdAST/resolve/main/BirdAST_Baseline_GroupKFold_fold_{}.pth".format(fold)
        for fold in range(5)
    ],
    label_mapping="https://huggingface.co/shiyi-li/BirdAST/resolve/main/BirdAST_Baseline_GroupKFold_label_map.csv",
    preprocess_fn=birdast_preprocess,
    inference_fn=birdast_inference,
)

birdast_seq_assets = dict(
    model_weights=[
        "https://huggingface.co/shiyi-li/BirdAST_Seq/resolve/main/BirdAST_SeqPool_GroupKFold_fold_{}.pth".format(fold)
        for fold in range(5)
    ],
    label_mapping="https://huggingface.co/shiyi-li/BirdAST_Seq/resolve/main/BirdAST_SeqPool_GroupKFold_label_map.csv",
    preprocess_fn=birdast_seq_preprocess,
    inference_fn=birdast_seq_inference,
)

birdvec_assets = dict(
    model_weights=[
        "https://huggingface.co/amroa/BirdVec/resolve/main/fold{0}/best-model{0}.ckpt".format(fold)
        for fold in range(3)
    ],
    label_mapping="https://huggingface.co/amroa/BirdVec/resolve/main/new_label_map.csv",
    preprocess_fn=birdvec_preprocess,
    inference_fn=birdvec_inference,
)

# Display name (as shown in the UI dropdown) -> asset description.
ASSET_DICT = dict(
    BirdAST=birdast_assets,
    BirdAST_Seq=birdast_seq_assets,
    BirdWav2Vec=birdvec_assets,
)
|
|
|
|
|
def _ensure_asset(url):
    """Return the local path of ``url`` inside ASSET_DIR, downloading it if absent.

    The local filename is the last path component of the URL.

    Raises:
        gr.Error: if the file still does not exist after the download attempt.
    """
    local_path = os.path.join(ASSET_DIR, url.split("/")[-1])
    if not os.path.exists(local_path):
        download_model(url, local_path)
    if not os.path.exists(local_path):
        raise gr.Error(f"Could not download model asset: {url}")
    return local_path


def run_inference_with_model(audio_clip, sr, model_name, top_k=10):
    """Predict the most likely bird species for an audio clip.

    Downloads any missing assets of the chosen model (per-fold weights and the
    label-map CSV), preprocesses the clip, runs the model's inference function,
    averages the predictions over axis 0 (presumably one row per fold — TODO
    confirm against the inference functions) and returns the top-k species.

    Args:
        audio_clip: audio samples at ``sr`` (the UI passes mono float32).
        sr: sample rate of ``audio_clip``.
        model_name: key of ``ASSET_DICT``.
        top_k: number of species to return; clamped to the number of classes.

    Returns:
        list of ``[scientific_name, score * 100]`` pairs, highest score first.

    Raises:
        gr.Error: unknown ``model_name`` or an asset that could not be downloaded.
    """
    if model_name not in ASSET_DICT:
        raise gr.Error(f"Unknown model <{model_name}>. Choose one of: {', '.join(ASSET_DICT)}")
    assets = ASSET_DICT[model_name]

    model_weights = [_ensure_asset(url) for url in assets["model_weights"]]
    label_map_csv = _ensure_asset(assets["label_mapping"])

    label_mapping = pd.read_csv(label_map_csv)
    species_id_to_name = dict(zip(label_mapping["species_id"], label_mapping["scientific_name"]))

    spectrogram = assets["preprocess_fn"](audio_clip, sr=sr)
    predictions = assets["inference_fn"](model_weights, spectrogram, device=DEVICE)

    final_predicts = predictions.mean(axis=0)
    # torch.topk raises if k exceeds the number of classes, so clamp it.
    k = min(top_k, final_predicts.shape[-1])
    topk_values, topk_indices = torch.topk(torch.from_numpy(final_predicts), k)

    return [
        [species_id_to_name[idx.item()], score.item() * 100]
        for idx, score in zip(topk_indices, topk_values)
    ]
|
|
|
def predict(audio, start, end, model_name="BirdAST_Seq"):
    """Classify an audio excerpt and build everything the UI displays.

    Args:
        audio: ``(sample_rate, samples)`` tuple from ``gr.Audio``, or ``None``
            when no clip was provided. Multi-channel samples (channels on
            axis 1) are averaged to mono.
        start: excerpt start, in seconds.
        end: excerpt end, in seconds; clipped to the clip duration.
        model_name: key of ``ASSET_DICT``. ``None`` (the model dropdown's initial
            state) falls back to "BirdAST_Seq".

    Returns:
        tuple ``(audio_class, species_class, fig_waveform, fig_spectrogram,
        images)`` matching the UI outputs.

    Raises:
        gr.Error: missing audio or an invalid start/end window.
    """
    if audio is None:
        raise gr.Error("Please upload or record an audio clip first.")
    if start is None or end is None:
        raise gr.Error("Please set both `start` and `end` (in seconds).")
    if model_name is None:
        # The model dropdown has no initial value, so Gradio passes None until a
        # model is picked; use this function's declared default instead.
        model_name = "BirdAST_Seq"

    raw_sr, audio_array = audio
    audio_array = np.asarray(audio_array)

    # Integer PCM is scaled to [-1, 1) by its full-scale value (32768 for int16);
    # float input is taken as already normalized. Determined before the mono
    # down-mix below, which converts the samples to floating point.
    if np.issubdtype(audio_array.dtype, np.integer):
        full_scale = float(np.iinfo(audio_array.dtype).max) + 1.0
    else:
        full_scale = 1.0

    if audio_array.ndim > 1:
        audio_array = audio_array.mean(axis=1)

    print(f"Audio shape raw: {audio_array.shape}, sr: {raw_sr}")

    n_samples = audio_array.shape[0]
    len_audio = n_samples / raw_sr
    if start < 0:
        raise gr.Error(f"`start` ({start}) must not be negative")
    if start >= end:
        raise gr.Error(f"`start` ({start}) must be smaller than end ({end}s)")
    # `<=`: a start exactly at the end of the clip would leave an empty excerpt.
    if n_samples <= start * raw_sr:
        raise gr.Error(f"`start` ({start}) must be smaller than audio duration ({len_audio:.0f}s)")
    if n_samples < end * raw_sr:
        end = n_samples / (1.0 * raw_sr)

    audio_array = np.array(audio_array, dtype=np.float32) / full_scale
    audio_array = audio_array[int(start * raw_sr): int(end * raw_sr)]

    if raw_sr != DEFUALT_SR:
        # NOTE(review): the band-pass runs at the original rate; when
        # raw_sr <= 2 * DEFUALT_HIGH_CUT the upper cutoff is at/above Nyquist —
        # confirm bandpass_filter handles that.
        audio_array = bandpass_filter(audio_array, DEFUALT_LOW_CUT, DEFUALT_HIGH_CUT, raw_sr)
        audio_array = librosa.resample(audio_array, orig_sr=raw_sr, target_sr=DEFUALT_SR)
        print(f"Resampled Audio shape: {audio_array.shape}")

    audio_array = audio_array.astype(np.float32)

    audio_class = predict_class(audio_array)
    fig_spectrogram = plot_mel(DEFUALT_SR, audio_array)
    fig_waveform = plot_wave(DEFUALT_SR, audio_array)

    print(f"Running inference with model: {model_name}")
    species_class = run_inference_with_model(audio_array, DEFUALT_SR, model_name)

    top_species = species_class[0][0].strip().replace("_", " ")
    print("Species is ", top_species)

    # The gallery is a network-backed extra; never let it sink the prediction.
    try:
        images = prepare_images(top_species)
    except Exception as err:
        print(f"Could not fetch images for {top_species}: {err!r}")
        images = []
    if not images:
        images = [("noimg.png", "No image")]

    return audio_class, species_class, fig_waveform, fig_spectrogram, images
|
|
|
|
|
# Markdown rendered at the bottom of the app: appendix, references, acknowledgements.
# Blank lines separate paragraphs so each reference renders on its own.
REFERENCES = """
# Appendix

We have applied the AudioMAE model to pre-classify the 23000+ unlabelled audio clips collected from the Greater Manaus region in the Amazon rainforest. The results of the audio type classification can be found in the following [link](https://drive.google.com/file/d/1uOT88LDnBD-Z3YcFz1e9XjvW2ugCo6EI/view?usp=drive_link). We hope that the pre-classification results can help researchers better explore the vast collection of audio recordings and facilitate the study of biodiversity in the Amazon rainforest.

# References

[1] Torkington, S. (2023, February 7). 50% of the global economy is under threat from biodiversity loss. World Economic Forum. Retrieved from https://www.weforum.org/agenda/2023/02/biodiversity-nature-loss-cop15/.

[2] Huang, P.-Y., Xu, H., Li, J., Baevski, A., Auli, M., Galuba, W., Metze, F., & Feichtenhofer, C. (2022). Masked Autoencoders that Listen. In NeurIPS.

[3] https://www.kaggle.com/code/dima806/bird-species-by-sound-detection

# Acknowledgements

We would like to thank all organizers, mentors and participants of the AI+Environment EcoHackathon 2024 event for their unwavering support and collaboration. We extend our gratitude to ETH BiodivX, GainForest and ETH AI Center for providing data, facilities and resources that enabled us to analyse the rich data in different ways. Our special thanks to David Dao, Sarah Tariq, and Alessandro Amodio for always being there to help us!
"""
|
|
|
|
|
def handle_model_selection(model_name, download_status):
    """Download any missing assets of the selected model and report its status.

    Callback for the model dropdown's ``change`` event. Both the per-fold
    weight files and the label-map CSV are fetched into ASSET_DIR, so "ready"
    means inference will not need to download anything.

    Args:
        model_name: selected dropdown value; ``None`` falls back to "BirdAST".
        download_status: current status text. Unused, but kept because the UI
            wires the status textbox in as an input.

    Returns:
        str: message for the "Model Status" textbox.
    """
    if model_name is None:
        model_name = "BirdAST"
    print(f"Downloading model weights for {model_name}...")

    try:
        assets = ASSET_DICT[model_name]
        urls = [*assets["model_weights"], assets["label_mapping"]]
        for url in urls:
            local_path = os.path.join(ASSET_DIR, url.split("/")[-1])
            if not os.path.exists(local_path):
                download_model(url, local_path)
            if not os.path.exists(local_path):
                print(f"Download of {url} did not produce {local_path}")
                return "An error occurred while downloading model weights."
    except Exception as err:
        # UI boundary: report the failure in the status box instead of raising.
        print(f"Failed to prepare model <{model_name}>: {err!r}")
        return f"An error occurred while downloading model weights: {err}"

    return f"Model <{model_name}> is ready!\nUsing Device: {DEVICE.upper()}"
|
|
|
|
|
|
|
def _abbreviate_scientific_name(scientific_name: str) -> str:
    """Shorten a binomial name to "G. species" (e.g. "Turdus merula" -> "T. merula").

    Names with fewer than two words are returned stripped but otherwise unchanged.
    """
    parts = scientific_name.split()
    if len(parts) < 2:
        return scientific_name.strip()
    return f"{parts[0][0]}. {parts[1]}"


def prepare_images(scientific_name: str):
    """Build gallery items for a species from its eBird species page.

    Args:
        scientific_name: species name with words separated by spaces.

    Returns:
        list of ``(image_url, caption)`` tuples, where the caption is the
        abbreviated scientific name; empty when no eBird species code or no
        image URLs are found.
    """
    scode = scientific_to_species_code(scientific_name)
    if not scode:
        return []

    urls = download_images(f"https://ebird.org/species/{scode}")
    if not urls:
        return []

    caption = _abbreviate_scientific_name(scientific_name)
    return [(url, caption) for url in urls]
|
|
|
def _section_header(title):
    """Return a centered ``<h2>`` HTML snippet used as a section title in the UI."""
    return "\n".join(
        (
            '<div align="center">',
            f"<b> <h2> {title} </h2> </b>",
            "</div>",
        )
    )


sp_and_cl = _section_header("Class and Species Prediction")
sig_prop = _section_header("Signal Visualization")
imgs = _section_header("Bird Gallery")
|
|
|
with gr.Blocks(theme=seafoam, css=css, js=js) as demo:
    # Header: logo, animated title container and project description.
    gr.Markdown('<div class="logo-container"><img src="https://i.ibb.co/pQLcLwf/vojlogo.png" width="50px" alt="vojlogo"></div>')
    gr.Markdown('<div id="gradio-animation"></div>')
    gr.Markdown(DESCRIPTION)

    # Model picker. Choices come from ASSET_DICT so the UI only offers models
    # that have registered assets; changing the selection pre-downloads them.
    model_names = list(ASSET_DICT)
    model_dropdown = gr.Dropdown(label="Choose a model", choices=model_names)
    download_status = gr.Textbox(label="Model Status", lines=3, value="", interactive=False)
    model_dropdown.change(
        handle_model_selection,
        inputs=[model_dropdown, download_status],
        outputs=download_status,
    )

    # Inputs: excerpt window (seconds) and the audio clip.
    with gr.Row():
        with gr.Column(elem_classes="column-container"):
            start_time_input = gr.Number(label="Start Time", value=0, elem_classes="number-input full-height")
            end_time_input = gr.Number(label="End Time", value=10, elem_classes="number-input full-height")
        with gr.Column():
            audio_input = gr.Audio(label="Input Audio", elem_classes="full-height")

    # Prediction tables: audio-type classes and bird species.
    gr.Markdown(sp_and_cl)
    with gr.Column():
        with gr.Row():
            raw_class_output = gr.Dataframe(headers=["Class", "Score [%]"], row_count=10, label="Class Prediction")
            species_output = gr.Dataframe(headers=["Species", "Score [%]"], row_count=10, label="Species Prediction")

    # Signal plots.
    gr.Markdown(sig_prop)
    with gr.Column():
        with gr.Row():
            waveform_output = gr.Plot(label="Waveform")
            spectrogram_output = gr.Plot(label="Spectrogram")

    # Images of the top predicted species.
    gr.Markdown(imgs)
    gallery = gr.Gallery(
        label="Species Images",
        show_label=False,
        elem_id="gallery",
        columns=[3],
        rows=[1],
        object_fit="contain",
        height="auto",
    )

    gr.Button("Predict").click(
        predict,
        [audio_input, start_time_input, end_time_input, model_dropdown],
        [raw_class_output, species_output, waveform_output, spectrogram_output, gallery],
    )

    gr.Examples(
        examples=[
            ["XC226833-Chestnut-belted_20Chat-Tyrant_20A_2010989.mp3", 0, 10],
            ["XC812290-Many-striped-Canastero_Teaben_Pe_1jul2022_FSchmitt_1.mp3", 0, 10],
            ["XC763511-Synallaxis-maronica_Bagua-grande_MixPre-1746.mp3", 0, 10],
        ],
        inputs=[audio_input, start_time_input, end_time_input],
    )

    gr.Markdown(REFERENCES)

demo.launch(share=True)
|
|
|
|
|
|