"""Biome: a BioCLIP-powered biodiversity classifier for Singapore.

Zero-shot CLIP predictions over the tree of life are augmented with a
retrieval step: a FAISS index of reference images "votes" for species,
and those votes boost (or inject) candidates before display in Gradio.
"""

import collections
import heapq
import io
import json
import logging
import os
from pathlib import Path

import faiss
import gradio as gr
import numpy as np
import requests
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from open_clip import create_model, get_tokenizer
from PIL import Image
from torchvision import transforms

log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=log_format)
logger = logging.getLogger()

hf_token = os.getenv("HF_TOKEN")

model_str = "hf-hub:imageomics/bioclip"
tokenizer_str = "ViT-B-16"
# Pre-computed text embeddings for every species name in the label set.
txt_emb_npy = hf_hub_download(
    repo_id="pyesonekyaw/biome_lfs",
    filename="txt_emb_species.npy",
    repo_type="dataset",
)
txt_names_json = "txt_emb_species.json"

min_prob = 1e-9
k = 5  # number of predictions shown in the UI

ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# CLIP-style preprocessing; mean/std are the standard OpenAI CLIP constants.
preprocess_img = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)

MIN_PROB = 1e-9
TOP_K_PREDICTIONS = 5
TOP_K_CANDIDATES = 250  # zero-shot candidates considered before re-ranking
TOP_N_SIMILAR = 22  # similar images retrieved from the FAISS index
SIMILARITY_BOOST = 0.2  # score added per retrieval vote
VOTE_THRESHOLD = 3  # votes needed to inject a species absent from the top-k
SIMILARITY_THRESHOLD = 0.99  # drop near-exact duplicates of the query image

# Paths for the RAG lookup tables (image id -> photo URL / species names).
PHOTO_LOOKUP_PATH = "./photo_lookup.json"
SPECIES_LOOKUP_PATH = "./species_lookup.json"

theme = gr.themes.Base(
    primary_hue=gr.themes.colors.teal,
    secondary_hue=gr.themes.colors.blue,
    neutral_hue=gr.themes.colors.gray,
    text_size=gr.themes.sizes.text_lg,
).set(
    button_primary_background_fill="#114A56",
    button_primary_background_fill_hover="#114A56",
    block_title_text_weight="600",
    block_label_text_weight="600",
    block_label_text_size="*text_md",
)

EXAMPLES_DIR = Path("examples")
example_images = sorted(str(p) for p in EXAMPLES_DIR.glob("*.jpg"))


def indexed(lst, indices):
    """Return the elements of *lst* at the given *indices*."""
    return [lst[i] for i in indices]


def format_name(taxon, common):
    """Format a taxonomy tuple and optional common name for display."""
    taxon = " ".join(taxon)
    if not common:
        return taxon
    return f"{taxon} ({common})"


def combine_duplicate_predictions(predictions):
    """Combine predictions where one name is contained within another.

    Longer names absorb the probability of shorter names they contain
    (case-insensitive, either direction). The result is re-normalized
    to sum to 1.
    """
    combined = {}
    used = set()

    # Sort by length of name (longer names first), then by probability,
    # so the most specific name becomes the surviving key.
    items = sorted(predictions.items(), key=lambda x: (-len(x[0]), -x[1]))

    for name1, prob1 in items:
        if name1 in used:
            continue

        total_prob = prob1
        used.add(name1)

        # Absorb any remaining prediction whose name contains / is
        # contained by this one.
        for name2, prob2 in predictions.items():
            if name2 in used:
                continue
            name1_lower = name1.lower()
            name2_lower = name2.lower()
            if name1_lower in name2_lower or name2_lower in name1_lower:
                total_prob += prob2
                used.add(name2)

        combined[name1] = total_prob

    # Normalize probabilities (empty input yields an empty dict).
    total = sum(combined.values())
    return {name: prob / total for name, prob in combined.items()}


@torch.no_grad()
def open_domain_classification(img, rank: int, return_all=False):
    """Predict from the entire tree of life using the RAG approach.

    Args:
        img: PIL image to classify.
        rank: index into ``ranks`` (6 == Species-level prediction).
        return_all: unused; kept for interface compatibility.

    Returns:
        (predictions, similar_images): a name -> probability dict and the
        metadata list for the retrieved similar images.
    """
    logger.info(f"Starting open domain classification for rank: {rank}")
    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    # Zero-shot logits against the pre-computed species text embeddings.
    logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
    probs = F.softmax(logits, dim=0)

    # Retrieval votes and metadata from the FAISS index.
    species_votes, similar_images = get_similar_images_metadata(
        img_features, faiss_index, id_mapping, name_mapping
    )

    if rank + 1 == len(ranks):  # Species-level prediction
        topk = probs.topk(TOP_K_CANDIDATES)
        predictions = {
            format_name(*txt_names[i]): prob.item()
            for i, prob in zip(topk.indices, topk.values)
        }

        # Boost candidates that retrieval agrees with; inject strongly
        # voted species that zero-shot missed entirely.
        augmented_predictions = predictions.copy()
        for pred_name in predictions:
            pred_name_lower = pred_name.lower()
            for voted_species, vote_count in species_votes.items():
                if voted_species in pred_name_lower or pred_name_lower in voted_species:
                    augmented_predictions[pred_name] += SIMILARITY_BOOST * vote_count
                elif vote_count >= VOTE_THRESHOLD:
                    augmented_predictions[voted_species] = vote_count * SIMILARITY_BOOST

        sorted_predictions = dict(
            sorted(augmented_predictions.items(), key=lambda x: x[1], reverse=True)[:k]
        )

        # Normalize, then merge near-duplicate names.
        total = sum(sorted_predictions.values())
        sorted_predictions = {
            name: prob / total for name, prob in sorted_predictions.items()
        }
        sorted_predictions = combine_duplicate_predictions(sorted_predictions)
        logger.info(
            f"Top K predictions after combining duplicates: {sorted_predictions}"
        )
        return sorted_predictions, similar_images

    # Higher-rank prediction: sum species probabilities up to the rank.
    # flatten() (not squeeze()) keeps iteration valid when only a single
    # probability exceeds MIN_PROB; .item() keeps the dict JSON-friendly.
    output = collections.defaultdict(float)
    for i in torch.nonzero(probs > MIN_PROB).flatten():
        output[" ".join(txt_names[i][0][: rank + 1])] += probs[i].item()

    # Incorporate retrieval votes at the chosen rank.
    for species, vote_count in species_votes.items():
        try:
            for taxonomy, _ in txt_names:
                if species in " ".join(taxonomy).lower():
                    higher_rank = " ".join(taxonomy[: rank + 1])
                    output[higher_rank] += SIMILARITY_BOOST * vote_count
                    break
        except Exception as e:
            logger.error(f"Error processing vote for species {species}: {e}")

    topk_names = heapq.nlargest(k, output, key=output.get)
    prediction_dict = {name: output[name] for name in topk_names}

    # Normalize probabilities to sum to 1.
    total = sum(prediction_dict.values())
    prediction_dict = {name: prob / total for name, prob in prediction_dict.items()}

    prediction_dict = combine_duplicate_predictions(prediction_dict)
    logger.info(
        f"Prediction dictionary after combining duplicates: {prediction_dict}"
    )
    return prediction_dict, similar_images


def change_output(choice):
    """Relabel the output widget when the taxonomic rank changes."""
    return gr.Label(
        num_top_classes=k, label=ranks[choice], show_label=True, value=None
    )


def get_cache_paths(name="demo"):
    """Get paths for the cached FAISS index and ID mapping."""
    return {
        "index": hf_hub_download(
            repo_id="pyesonekyaw/biome_lfs",
            filename="cache/faiss_cache_demo.index",
            repo_type="dataset",
        ),
        "mapping": hf_hub_download(
            repo_id="pyesonekyaw/biome_lfs",
            filename="cache/faiss_cache_demo_mapping.json",
            repo_type="dataset",
        ),
    }


def build_name_mapping(txt_names):
    """Build a mapping from scientific AND common names to a
    (scientific_name, common_name) tuple, all lowercase."""
    name_mapping = {}
    for taxonomy, common_name in txt_names:
        if not common_name:
            continue
        if len(taxonomy) >= 2:
            # Genus + species form the binomial scientific name.
            scientific_name = f"{taxonomy[-2]} {taxonomy[-1]}".lower()
            common_name = common_name.lower()
            name_mapping[scientific_name] = (scientific_name, common_name)
            name_mapping[common_name] = (scientific_name, common_name)
    return name_mapping


def load_faiss_index():
    """Load the FAISS index and its ID mapping from the HF cache."""
    cache_paths = get_cache_paths()
    logger.info("Loading FAISS index from cache...")
    index = faiss.read_index(cache_paths["index"])
    with open(cache_paths["mapping"], "r") as f:
        id_mapping = json.load(f)
    return index, id_mapping


def get_similar_images_metadata(img_embedding, faiss_index, id_mapping, name_mapping):
    """Search the FAISS index and tally species votes.

    Returns a (species_votes, similar_images) pair: vote counts per
    canonical species name, and display metadata for the top 5 matches.
    """
    img_embedding_np = img_embedding.cpu().numpy()
    if img_embedding_np.ndim == 1:
        img_embedding_np = img_embedding_np.reshape(1, -1)

    # Over-fetch to compensate for matches filtered out below.
    distances, indices = faiss_index.search(img_embedding_np, TOP_N_SIMILAR * 2)

    # Filter out near-exact matches (likely the query image itself).
    valid_indices = []
    valid_distances = []
    valid_count = 0
    for dist, idx in zip(distances[0], indices[0]):
        # For inner-product similarity the distance IS the similarity.
        similarity = dist
        if similarity > SIMILARITY_THRESHOLD:
            continue
        valid_indices.append(idx)
        valid_distances.append(similarity)
        valid_count += 1
        if valid_count >= TOP_N_SIMILAR:
            break

    species_votes = {}
    similar_images = []

    # Only the top 5 are processed for display.
    for idx, similarity in zip(valid_indices[:5], valid_distances[:5]):
        # NOTE(review): id_mapping comes from JSON; if it is a dict its keys
        # are strings and an int/np.int index would KeyError — confirm it is
        # a list (or switch to id_mapping[str(idx)]).
        similar_img_id = id_mapping[idx]
        try:
            # .get may return None for unknown ids; treat as "no names".
            species_names = id_to_species_info.get(similar_img_id) or []
            species_names = [name for name in species_names if name]
            processed_names = set()

            # Canonicalize each name via the scientific/common-name mapping.
            for species in species_names:
                if not species:
                    continue
                name_tuple = name_mapping.get(species)
                if name_tuple:
                    processed_names.add(name_tuple[0])
                else:
                    processed_names.add(species)

            for species in processed_names:
                species_votes[species] = species_votes.get(species, 0) + 1

            similar_images.append(
                {
                    "id": similar_img_id,
                    "species": next(iter(processed_names))
                    if processed_names
                    else "Unknown",
                    "common_name": species_names[-1] if species_names else "",
                    "similarity": similarity,
                }
            )
        except Exception as e:
            logger.error(f"Error processing JSON for image {similar_img_id}: {e}")
            continue

    return species_votes, similar_images


if __name__ == "__main__":
    logger.info("Starting.")
    model = create_model(model_str, output_dict=True, require_pretrained=True)
    model = model.to(device)
    logger.info("Created model.")
    model = torch.compile(model)
    logger.info("Compiled model.")

    tokenizer = get_tokenizer(tokenizer_str)

    with open(PHOTO_LOOKUP_PATH) as f:
        id_to_photo_url = json.load(f)
    with open(SPECIES_LOOKUP_PATH) as f:
        id_to_species_info = json.load(f)
    logger.info(f"Loaded {len(id_to_photo_url)} photo mappings")
    logger.info(f"Loaded {len(id_to_species_info)} species mappings")

    # Load text embeddings and names, then build the name mapping.
    txt_emb = torch.from_numpy(np.load(txt_emb_npy, mmap_mode="r")).to(device)
    with open(txt_names_json) as fd:
        txt_names = json.load(fd)

    name_mapping = build_name_mapping(txt_names)

    # Load the FAISS index and its ID mapping.
    faiss_index, id_mapping = load_faiss_index()

    def process_output(img, rank):
        """Classify *img* and fetch the 5 most-similar reference photos.

        Returns [predictions] + 5 images (or None) + 5 markdown labels,
        matching the Gradio output components.
        """
        predictions, similar_imgs = open_domain_classification(img, rank)
        logger.info(f"Number of similar images found: {len(similar_imgs)}")

        images = []
        labels = []
        for img_info in similar_imgs:
            img_id = img_info["id"]
            img_url = id_to_photo_url.get(img_id)
            if img_url is None:
                # No photo URL known for this id; keep slots aligned.
                images.append(None)
                labels.append("")
                continue
            # Request the small rendition instead of the square thumbnail.
            img_url = img_url.replace("square", "small")
            logger.info(f"Processing image URL: {img_url}")
            try:
                # Timeout keeps a dead host from hanging the request.
                response = requests.get(img_url, timeout=10)
                if response.status_code == 200:
                    try:
                        img = Image.open(io.BytesIO(response.content))
                        images.append(img)
                    except Exception as e:
                        logger.info(f"Failed to load image from URL: {e}")
                        images.append(None)
                else:
                    logger.info(f"Failed to fetch image from URL: {response}")
                    images.append(None)

                # Add label regardless of image load success.
                label = f"**{img_info['species']}**"
                if img_info["common_name"]:
                    label += f" ({img_info['common_name']})"
                label += f"\nSimilarity: {img_info['similarity']:.3f}"
                label += f"\n[View on iNaturalist](https://www.inaturalist.org/observations/{img_id})"
                labels.append(label)
            except Exception as e:
                logger.error(f"Error processing image {img_id}: {e}")
                images.append(None)
                labels.append("")

        # Pad to exactly 5 image and 5 label slots.
        images += [None] * (5 - len(images))
        labels += [""] * (5 - len(labels))

        logger.info(f"Final number of images: {len(images)}")
        logger.info(f"Final number of labels: {len(labels)}")
        return [predictions] + images + labels

    with gr.Blocks(theme=theme) as app:
        # Header: logo + description.
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                gr.Image("image.jpg", elem_id="logo-img", show_label=False)
            with gr.Column(scale=30):
                gr.Markdown(
                    """Biome is a vision foundation model-powered tool customized to identify Singapore's local biodiversity.

**Developed by**: Pye Sone Kyaw - AI Engineer @ Multimodal AI Team - AI Practice - GovTech SG

Under the hood, Biome is using [BioCLIP](https://github.com/Imageomics/BioCLIP) augmented with multimodal search and retrieval to enhance its Singapore-specific biodiversity classification capabilities. """
                )

        with gr.Row(variant="panel", elem_id="images_panel"):
            img_input = gr.Image(height=400, sources=["upload"], type="pil")

        with gr.Row():
            with gr.Column():
                with gr.Row():
                    gr.Examples(
                        examples=example_images,
                        inputs=img_input,
                        label="Example Images",
                    )
                rank_dropdown = gr.Dropdown(
                    label="Taxonomic Rank",
                    info="Which taxonomic rank to predict. Fine-grained ranks (genus, species) are more challenging.",
                    choices=ranks,
                    value="Species",
                    type="index",
                )
                open_domain_btn = gr.Button("Submit", variant="primary")
            with gr.Column():
                open_domain_output = gr.Label(
                    num_top_classes=k,
                    label="Prediction",
                    show_label=True,
                    value=None,
                )

        # Section showing the retrieved similar images.
        with gr.Row(variant="panel"):
            with gr.Column():
                gr.Markdown("### Most Similar Images from Database")
                with gr.Row():
                    similar_images = [
                        gr.Image(label="Similar Image 1", height=200, show_label=True),
                        gr.Image(label="Similar Image 2", height=200, show_label=True),
                        gr.Image(label="Similar Image 3", height=200, show_label=True),
                        gr.Image(label="Similar Image 4", height=200, show_label=True),
                        gr.Image(label="Similar Image 5", height=200, show_label=True),
                    ]
                with gr.Row():
                    similar_labels = [
                        gr.Markdown("Species 1"),
                        gr.Markdown("Species 2"),
                        gr.Markdown("Species 3"),
                        gr.Markdown("Species 4"),
                        gr.Markdown("Species 5"),
                    ]

        rank_dropdown.change(
            fn=change_output, inputs=rank_dropdown, outputs=[open_domain_output]
        )

        open_domain_btn.click(
            fn=process_output,
            inputs=[img_input, rank_dropdown],
            outputs=[open_domain_output] + similar_images + similar_labels,
        )

        with gr.Row(variant="panel"):
            gr.Markdown(
                """ **Disclaimer**: This is a proof-of-concept demo for non-commercial purposes. No data is stored or used for any form of training, and all data used for retrieval are from [iNaturalist](https://inaturalist.org/). 
The adage of garbage in, garbage out applies here - uploading images not biodiversity-related will yield unpredictable results. """
            )

    app.queue(max_size=20)
    app.launch(share=False, enable_monitoring=False, allowed_paths=["/app/"])