import collections
import heapq
import json
import os
import logging
import faiss
import requests
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from open_clip import create_model, get_tokenizer
from torchvision import transforms
from PIL import Image
import io
from pathlib import Path
from huggingface_hub import hf_hub_download
# Logging is configured once on the root logger; every record carries a
# timestamp, level and logger name.
log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=log_format)
logger = logging.getLogger()
# Hugging Face token taken from the environment; None when HF_TOKEN is unset.
# NOTE(review): not referenced anywhere else in the visible code.
hf_token = os.getenv("HF_TOKEN")
# Identifiers handed to open_clip's create_model / get_tokenizer at startup.
model_str = "hf-hub:imageomics/bioclip"
tokenizer_str = "ViT-B-16"
# Fetched from the Hugging Face Hub at import time; the returned value is later
# passed to np.load as a file path.
txt_emb_npy = hf_hub_download(repo_id="pyesonekyaw/biome_lfs", filename='txt_emb_species.npy', repo_type="dataset")
# Names file, opened relative to the current working directory at startup.
txt_names_json = "txt_emb_species.json"
# NOTE(review): same value as MIN_PROB; the visible code only reads the
# upper-case name.
min_prob = 1e-9
# Number of predictions kept and shown in the output label.
k = 5
# Taxonomic ranks; the rank dropdown returns an index into this tuple and the
# last entry selects the species-level code path.
ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Preprocessing applied to every input image before encoding: tensor
# conversion, resize to 224x224, then per-channel normalization.
# NOTE(review): the mean/std values look like the usual CLIP statistics - confirm.
preprocess_img = transforms.Compose(
[
transforms.ToTensor(),
transforms.Resize((224, 224), antialias=True),
transforms.Normalize(
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
]
)
# Classes whose probability is not strictly above this value are ignored when
# probabilities are aggregated for ranks above species.
MIN_PROB = 1e-9
# NOTE(review): not referenced in the visible code, which reads `k` instead.
TOP_K_PREDICTIONS = 5
# Number of zero-shot species candidates taken before vote augmentation.
TOP_K_CANDIDATES = 250
# Maximum number of FAISS neighbours kept after near-duplicate filtering
# (twice this many are requested from the index).
TOP_N_SIMILAR = 22
# Score added to a prediction for each vote from a similar image.
SIMILARITY_BOOST = 0.2
# Minimum number of votes for a voted species to be added as its own candidate.
VOTE_THRESHOLD = 3
# Neighbours with a similarity above this value are treated as near-exact
# matches of the query and skipped.
SIMILARITY_THRESHOLD = 0.99
# Lookup tables used by the retrieval step, resolved relative to the current
# working directory. Plain strings: there is nothing to interpolate.
PHOTO_LOOKUP_PATH = "./photo_lookup.json"
SPECIES_LOOKUP_PATH = "./species_lookup.json"
# Gradio theme: teal / blue / gray palette with large text. The primary button
# uses the same fill colour for its normal and hover states.
theme = gr.themes.Base(
primary_hue=gr.themes.colors.teal,
secondary_hue=gr.themes.colors.blue,
neutral_hue=gr.themes.colors.gray,
text_size=gr.themes.sizes.text_lg,
).set(
button_primary_background_fill="#114A56",
button_primary_background_fill_hover="#114A56",
block_title_text_weight="600",
block_label_text_weight="600",
block_label_text_size="*text_md",
)
# Example images offered in the UI: every *.jpg directly under ./examples,
# sorted by path string.
EXAMPLES_DIR = Path("examples")
example_images = sorted(str(p) for p in EXAMPLES_DIR.glob("*.jpg"))
def indexed(lst, indices):
    """Return the items of ``lst`` found at ``indices``, in the order given."""
    picked = []
    for position in indices:
        picked.append(lst[position])
    return picked
def format_name(taxon, common):
    """Build a display name from taxonomy parts and an optional common name.

    The parts of ``taxon`` are joined with single spaces; when ``common`` is
    truthy it is appended in parentheses.
    """
    scientific = " ".join(taxon)
    return f"{scientific} ({common})" if common else scientific
def combine_duplicate_predictions(predictions):
    """Merge predictions whose names contain one another, then renormalize.

    Names are compared case-insensitively. Candidates are visited longest name
    first (ties broken by higher score), so the most specific label absorbs the
    score of every not-yet-merged name that is a substring of it, or that it is
    a substring of.

    Args:
        predictions: Mapping of prediction name to score.

    Returns:
        A new dict of surviving names to scores, scaled to sum to 1. When the
        scores sum to zero (including empty input) the merged scores are
        returned unscaled instead of dividing by zero.
    """
    # Lower-case every name once instead of on every comparison.
    lowered = {name: name.lower() for name in predictions}
    ranked = sorted(predictions.items(), key=lambda item: (-len(item[0]), -item[1]))

    combined = {}
    used = set()
    for name, prob in ranked:
        if name in used:
            continue
        used.add(name)
        name_lower = lowered[name]
        total_prob = prob
        # Absorb every remaining prediction that overlaps with this name.
        for other, other_prob in predictions.items():
            if other in used:
                continue
            other_lower = lowered[other]
            if name_lower in other_lower or other_lower in name_lower:
                total_prob += other_prob
                used.add(other)
        combined[name] = total_prob

    total = sum(combined.values())
    if not total:
        # Nothing to normalize by; avoid ZeroDivisionError.
        return combined
    return {name: prob / total for name, prob in combined.items()}
@torch.no_grad()
def open_domain_classification(img, rank: int, return_all=False):
    """Classify an image at the requested taxonomic rank.

    Zero-shot probabilities over the text embeddings are combined with votes
    from the nearest neighbours returned by ``get_similar_images_metadata``.

    Args:
        img: Input image accepted by ``preprocess_img``.
        rank: Index into ``ranks``; the last index selects species level.
        return_all: Unused; kept so existing callers keep working.

    Returns:
        A ``(predictions, similar_images)`` tuple. ``predictions`` maps a name
        to a plain-float score, normalized and de-duplicated by
        ``combine_duplicate_predictions``; ``similar_images`` is the neighbour
        metadata list.
    """
    logger.info(f"Starting open domain classification for rank: {rank}")
    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    # Zero-shot predictions.
    logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
    probs = F.softmax(logits, dim=0)

    # Votes and metadata from similar images.
    species_votes, similar_images = get_similar_images_metadata(
        img_features, faiss_index, id_mapping, name_mapping
    )

    if rank + 1 == len(ranks):
        # Species level. Clamp so topk cannot fail on a small label set.
        topk = probs.topk(min(TOP_K_CANDIDATES, probs.numel()))
        predictions = {
            format_name(*txt_names[i]): prob
            for i, prob in zip(topk.indices.tolist(), topk.values.tolist())
        }

        # Augment predictions with votes.
        # NOTE(review): the `elif` runs once per non-matching prediction, so a
        # species with enough votes is added as its own entry even when it also
        # boosted a matching prediction; the two are merged again by
        # combine_duplicate_predictions. Kept as-is because changing it alters
        # the ranking - confirm whether this double counting is intended.
        augmented_predictions = predictions.copy()
        for pred_name in predictions:
            pred_name_lower = pred_name.lower()
            for voted_species, vote_count in species_votes.items():
                if voted_species in pred_name_lower or pred_name_lower in voted_species:
                    augmented_predictions[pred_name] += SIMILARITY_BOOST * vote_count
                elif vote_count >= VOTE_THRESHOLD:
                    augmented_predictions[voted_species] = vote_count * SIMILARITY_BOOST

        # Keep the k best candidates, normalize, then merge duplicates.
        top_items = sorted(
            augmented_predictions.items(),
            key=lambda item: item[1],
            reverse=True,
        )[:k]
        total = sum(score for _, score in top_items)
        sorted_predictions = {name: score / total for name, score in top_items}
        sorted_predictions = combine_duplicate_predictions(sorted_predictions)
        logger.info(f"Top K predictions after combining duplicates: {sorted_predictions}")
        return sorted_predictions, similar_images

    # Higher rank: sum species probabilities per truncated taxonomy.
    output = collections.defaultdict(float)
    # flatten() (not squeeze()) keeps a 1-D index tensor even when exactly one
    # class passes the threshold; a 0-d tensor cannot be iterated.
    kept = torch.nonzero(probs > MIN_PROB).flatten()
    # One bulk transfer to Python floats, so the result holds plain numbers
    # rather than 0-d tensors.
    for i, prob in zip(kept.tolist(), probs[kept].tolist()):
        output[" ".join(txt_names[i][0][: rank + 1])] += prob

    # Incorporate votes: each voted species boosts the first taxonomy whose
    # lower-cased full name contains it.
    for species, vote_count in species_votes.items():
        try:
            for taxonomy, _ in txt_names:
                if species in " ".join(taxonomy).lower():
                    higher_rank = " ".join(taxonomy[: rank + 1])
                    output[higher_rank] += SIMILARITY_BOOST * vote_count
                    break
        except Exception as e:
            # Best effort: a malformed entry must not abort the prediction.
            logger.error(f"Error processing vote for species {species}: {e}")

    # Top-k, normalized to sum to 1, then de-duplicated.
    topk_names = heapq.nlargest(k, output, key=output.get)
    prediction_dict = {name: output[name] for name in topk_names}
    total = sum(prediction_dict.values())
    prediction_dict = {name: score / total for name, score in prediction_dict.items()}
    prediction_dict = combine_duplicate_predictions(prediction_dict)
    logger.info(f"Prediction dictionary after combining duplicates: {prediction_dict}")
    return prediction_dict, similar_images
def change_output(choice):
    """Return an empty prediction Label titled with the rank at index ``choice``."""
    rank_label = ranks[choice]
    return gr.Label(
        value=None,
        label=rank_label,
        show_label=True,
        num_top_classes=k,
    )
def get_cache_paths(name="demo"):
    """Download (or reuse from the Hub cache) the FAISS index files.

    Args:
        name: Cache variant, substituted into the file names
            ``cache/faiss_cache_<name>.index`` and
            ``cache/faiss_cache_<name>_mapping.json``. The default resolves to
            the file names that were previously hard-coded.

    Returns:
        Dict with the values returned by ``hf_hub_download`` under the keys
        ``'index'`` and ``'mapping'``.
    """
    repo_id = "pyesonekyaw/biome_lfs"
    return {
        'index': hf_hub_download(
            repo_id=repo_id,
            filename=f'cache/faiss_cache_{name}.index',
            repo_type="dataset",
        ),
        'mapping': hf_hub_download(
            repo_id=repo_id,
            filename=f'cache/faiss_cache_{name}_mapping.json',
            repo_type="dataset",
        ),
    }
def build_name_mapping(txt_names):
    """Index ``(scientific, common)`` name pairs by both names.

    Each item of ``txt_names`` is a ``(taxonomy, common_name)`` pair. Items with
    a falsy common name or fewer than two taxonomy parts are skipped. The
    scientific name is the last two taxonomy parts joined by a space; both
    names are lower-cased, and both are used as keys for the same tuple. Later
    items overwrite earlier ones that share a key.
    """
    lookup = {}
    for taxonomy, common_name in txt_names:
        if not common_name or len(taxonomy) < 2:
            continue
        penultimate, last = taxonomy[-2], taxonomy[-1]
        entry = (f"{penultimate} {last}".lower(), common_name.lower())
        for key in entry:
            lookup[key] = entry
    return lookup
def load_faiss_index():
    """Read the cached FAISS index and its id mapping.

    Returns:
        ``(index, id_mapping)``: the index read by ``faiss.read_index`` and the
        parsed JSON content of the mapping file.
    """
    paths = get_cache_paths()
    logger.info("Loading FAISS index from cache...")
    index_path, mapping_path = paths['index'], paths['mapping']
    index = faiss.read_index(index_path)
    with open(mapping_path, 'r') as handle:
        id_mapping = json.load(handle)
    return index, id_mapping
def get_similar_images_metadata(img_embedding, faiss_index, id_mapping, name_mapping):
    """Collect species votes and display metadata for the nearest neighbours.

    Args:
        img_embedding: Tensor holding the query embedding; 1-D input is
            reshaped to a single row before searching.
        faiss_index: Index exposing ``search(queries, k)``.
        id_mapping: Maps a FAISS result position to an image id.
        name_mapping: Maps a species name to a ``(scientific, common)`` tuple,
            as built by ``build_name_mapping``.

    Returns:
        ``(species_votes, similar_images)``. ``species_votes`` maps a species
        name to the number of processed neighbours carrying it;
        ``similar_images`` is a list of dicts with the keys ``id``,
        ``species``, ``common_name`` and ``similarity``. Neighbours with no
        species record are skipped.
    """
    img_embedding_np = img_embedding.cpu().numpy()
    if img_embedding_np.ndim == 1:
        img_embedding_np = img_embedding_np.reshape(1, -1)

    # Ask for twice as many neighbours as needed to leave room for filtering.
    distances, indices = faiss_index.search(img_embedding_np, TOP_N_SIMILAR * 2)

    valid_indices = []
    valid_distances = []
    # The score returned by the index is used directly as the similarity
    # (the original code treats the index as inner-product based).
    for similarity, idx in zip(distances[0], indices[0]):
        if idx < 0:
            # FAISS pads missing results with -1; indexing id_mapping with it
            # would silently pick the last entry.
            continue
        if similarity > SIMILARITY_THRESHOLD:
            # Near-exact match of the query image.
            continue
        valid_indices.append(idx)
        valid_distances.append(similarity)
        if len(valid_indices) >= TOP_N_SIMILAR:
            break

    species_votes = {}
    similar_images = []
    # NOTE(review): only the five best neighbours are processed, so votes are
    # drawn from the same five images that are displayed, not from all
    # TOP_N_SIMILAR retained above - confirm this is intended.
    for idx, similarity in zip(valid_indices[:5], valid_distances[:5]):
        similar_img_id = id_mapping[idx]
        try:
            species_names = [
                name for name in (id_to_species_info.get(similar_img_id) or []) if name
            ]
            if not species_names:
                logger.warning(f"No species information for image {similar_img_id}")
                continue

            # Canonicalize through name_mapping where possible. A dict keeps
            # first-seen order, so the displayed species is deterministic.
            processed_names = {}
            for species in species_names:
                name_tuple = name_mapping.get(species)
                processed_names[name_tuple[0] if name_tuple else species] = None

            for species in processed_names:
                species_votes[species] = species_votes.get(species, 0) + 1

            similar_images.append({
                'id': similar_img_id,
                'species': next(iter(processed_names)),
                # NOTE(review): assumes the last entry is the common name -
                # verify against the species lookup file.
                'common_name': species_names[-1],
                'similarity': float(similarity),
            })
        except Exception as e:
            # Best effort: one malformed record must not abort the request.
            logger.error(f"Error processing JSON for image {similar_img_id}: {e}")
            continue
    return species_votes, similar_images
if __name__ == "__main__":
    # Startup: load the model, lookup tables, text embeddings and FAISS index
    # into module-level names read by the functions above, then build and
    # launch the Gradio UI.
    logger.info("Starting.")
    model = create_model(model_str, output_dict=True, require_pretrained=True)
    model = model.to(device)
    logger.info("Created model.")
    model = torch.compile(model)
    logger.info("Compiled model.")
    tokenizer = get_tokenizer(tokenizer_str)

    # Context managers so the lookup files are closed once parsed.
    with open(PHOTO_LOOKUP_PATH) as fd:
        id_to_photo_url = json.load(fd)
    with open(SPECIES_LOOKUP_PATH) as fd:
        id_to_species_info = json.load(fd)
    logger.info(f"Loaded {len(id_to_photo_url)} photo mappings")
    logger.info(f"Loaded {len(id_to_species_info)} species mappings")

    # Load text embeddings and build name mapping
    txt_emb = torch.from_numpy(np.load(txt_emb_npy, mmap_mode="r")).to(device)
    with open(txt_names_json) as fd:
        txt_names = json.load(fd)
    name_mapping = build_name_mapping(txt_names)

    # Load the cached FAISS index and its id mapping
    faiss_index, id_mapping = load_faiss_index()

    # Defined before the UI so the click handler can reference it.
    def process_output(img, rank):
        """Run classification and build the values for the output components.

        Returns a flat list: the prediction dict, then five image slots, then
        five markdown label slots. A slot whose image cannot be fetched holds
        ``None``; unused slots hold ``None`` / ``""``.
        """
        num_slots = 5  # the UI below declares five image and five label outputs
        predictions, similar_imgs = open_domain_classification(img, rank)
        logger.info(f"Number of similar images found: {len(similar_imgs)}")
        images = []
        labels = []
        for img_info in similar_imgs[:num_slots]:
            img_id = img_info['id']
            image = None
            label = ""
            try:
                img_url = id_to_photo_url.get(img_id)
                if img_url is None:
                    # A missing URL must not fail the whole request.
                    logger.info(f"No photo URL found for image {img_id}")
                else:
                    img_url = img_url.replace("square", "small")
                    logger.info(f"Processing image URL: {img_url}")
                    # Timeout so a stalled remote host cannot hang the worker.
                    response = requests.get(img_url, timeout=10)
                    if response.status_code == 200:
                        try:
                            image = Image.open(io.BytesIO(response.content))
                        except Exception as e:
                            logger.info(f"Failed to load image from URL: {e}")
                    else:
                        logger.info(f"Failed to fetch image from URL: {response}")
                # Add label regardless of image load success
                label = f"**{img_info['species']}**"
                if img_info['common_name']:
                    label += f" ({img_info['common_name']})"
                label += f"\nSimilarity: {img_info['similarity']:.3f}"
                label += f"\n[View on iNaturalist](https://www.inaturalist.org/observations/{img_id})"
            except Exception as e:
                logger.error(f"Error processing image {img_id}: {e}")
                image, label = None, ""
            # Exactly one append per neighbour keeps images and labels aligned.
            images.append(image)
            labels.append(label)
        # Pad arrays if needed
        images += [None] * (num_slots - len(images))
        labels += [""] * (num_slots - len(labels))
        logger.info(f"Final number of images: {len(images)}")
        logger.info(f"Final number of labels: {len(labels)}")
        return [predictions] + images + labels

    with gr.Blocks(theme=theme) as app:
        # Header
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                gr.Image("image.jpg", elem_id="logo-img",
                         show_label=False )
            with gr.Column(scale=30):
                gr.Markdown("""Biome is a vision foundation model-powered tool customized to identify Singapore's local biodiversity.
**Developed by**: Pye Sone Kyaw - AI Engineer @ Multimodal AI Team - AI Practice - GovTech SG
Under the hood, Biome is using [BioCLIP](https://github.com/Imageomics/BioCLIP) augmented with multimodal search and retrieval to enhance its Singapore-specific biodiversity classification capabilities.
""")
        with gr.Row(variant="panel", elem_id="images_panel"):
            img_input = gr.Image(
                height=400,
                sources=["upload"],
                type="pil"
            )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    gr.Examples(
                        examples=example_images,
                        inputs=img_input,
                        label="Example Images"
                    )
                rank_dropdown = gr.Dropdown(
                    label="Taxonomic Rank",
                    info="Which taxonomic rank to predict. Fine-grained ranks (genus, species) are more challenging.",
                    choices=ranks,
                    value="Species",
                    type="index",
                )
                open_domain_btn = gr.Button("Submit", variant="primary")
            with gr.Column():
                open_domain_output = gr.Label(
                    num_top_classes=k,
                    label="Prediction",
                    show_label=True,
                    value=None,
                )
        # Similar images section
        with gr.Row(variant="panel"):
            with gr.Column():
                gr.Markdown("### Most Similar Images from Database")
                with gr.Row():
                    similar_images = [
                        gr.Image(label="Similar Image 1", height=200, show_label=True),
                        gr.Image(label="Similar Image 2", height=200, show_label=True),
                        gr.Image(label="Similar Image 3", height=200, show_label=True),
                        gr.Image(label="Similar Image 4", height=200, show_label=True),
                        gr.Image(label="Similar Image 5", height=200, show_label=True),
                    ]
                with gr.Row():
                    similar_labels = [
                        gr.Markdown("Species 1"),
                        gr.Markdown("Species 2"),
                        gr.Markdown("Species 3"),
                        gr.Markdown("Species 4"),
                        gr.Markdown("Species 5"),
                    ]
        # Changing the rank resets the prediction label to the new rank's title.
        rank_dropdown.change(
            fn=change_output,
            inputs=rank_dropdown,
            outputs=[open_domain_output]
        )
        open_domain_btn.click(
            fn=process_output,
            inputs=[img_input, rank_dropdown],
            outputs=[open_domain_output] + similar_images + similar_labels,
        )
        with gr.Row(variant="panel"):
            gr.Markdown("""
**Disclaimer**: This is a proof-of-concept demo for non-commercial purposes. No data is stored or used for any form of training, and all data used for retrieval are from [iNaturalist](https://inaturalist.org/).
The adage of garbage in, garbage out applies here - uploading images not biodiversity-related will yield unpredictable results.
""")

    app.queue(max_size=20)
    app.launch(share=False, enable_monitoring=False, allowed_paths=["/app/"])