# Hugging Face Spaces page residue captured with this file ("Spaces: Sleeping").
# Import the libraries | |
import numpy as np | |
import pandas as pd | |
from tensorflow.keras.models import load_model | |
from tensorflow.keras.preprocessing.image import load_img, img_to_array | |
from tensorflow.keras.applications.convnext import preprocess_input | |
import gradio as gr | |
# Classifier loaded once at import time (presumably a ConvNeXt-Base network,
# going by the saved-model directory name).
model = load_model('models/TropiCam-AI_ConvNeXtBase')

# Taxonomy table. Species names are written with underscores in the .csv;
# swap them for spaces so they read as regular binomial names.
taxo_df = pd.read_csv('taxonomy/taxonomy_mapping.csv')
taxo_df = taxo_df.assign(species=taxo_df['species'].str.replace('_', ' '))

# Taxonomic levels available for prediction, finest first.
taxonomic_levels = ['species', 'genus', 'family', 'order', 'class']
def get_class_name(predicted_class, taxonomic_level):
    """Map a predicted class index to its label at ``taxonomic_level``.

    The labels are the unique values of the ``taxonomic_level`` column of
    ``taxo_df`` in sorted order; ``predicted_class`` is a position in that
    sorted list.
    """
    labels = list(taxo_df[taxonomic_level].unique())
    labels.sort()
    return labels[predicted_class]
def aggregate_predictions(predicted_probs, taxonomic_level, class_names, taxonomy=None):
    """Sum species-level probabilities into groups at ``taxonomic_level``.

    Args:
        predicted_probs: 2-D array indexed as ``[sample, species]``; the
            species axis follows the order of ``class_names``.
        taxonomic_level: Name of the taxonomy column to aggregate to
            (``'species'`` leaves the probabilities ungrouped).
        class_names: Species names in the column order of
            ``predicted_probs``.
        taxonomy: Optional DataFrame with a ``'species'`` column and a
            ``taxonomic_level`` column. Defaults to the module-level
            ``taxo_df``.

    Returns:
        A tuple ``(aggregated, labels)``: ``aggregated`` has shape
        ``(n_samples, len(labels))`` and ``labels`` is the sorted list of
        unique values found in the ``taxonomic_level`` column.

    Raises:
        ValueError: If a species listed in the taxonomy is missing from
            ``class_names``.
    """
    table = taxo_df if taxonomy is None else taxonomy
    unique_labels = sorted(table[taxonomic_level].unique())

    # Dictionary lookups replace a linear list.index() scan per taxonomy row.
    label_column = {label: column for column, label in enumerate(unique_labels)}
    species_column = {}
    for column, name in enumerate(class_names):
        # setdefault keeps the first position, matching list.index().
        species_column.setdefault(name, column)

    aggregated_predictions = np.zeros((predicted_probs.shape[0], len(unique_labels)))

    # Each (species, group) pair contributes once, so duplicated taxonomy rows
    # cannot add the same species' probability to its group several times.
    seen_pairs = set()
    for species, higher_level in zip(table['species'], table[taxonomic_level]):
        pair = (species, higher_level)
        if pair in seen_pairs:
            continue
        seen_pairs.add(pair)
        if species not in species_column:
            raise ValueError(f"{species!r} is not in class_names")
        aggregated_predictions[:, label_column[higher_level]] += (
            predicted_probs[:, species_column[species]]
        )

    return aggregated_predictions, unique_labels
def load_and_preprocess_image(image, target_size=(224, 224)):
    """Turn an image into a single-item batch ready for the model.

    The image is resized to ``target_size``, converted to an array, given a
    leading batch axis and passed through the ConvNeXt ``preprocess_input``.

    Args:
        image: Image object exposing ``resize`` (the Gradio input is
            configured to deliver PIL images).
        target_size: Size passed to ``image.resize``.

    Returns:
        The preprocessed array with a leading batch axis of length 1.
    """
    resized = image.resize(target_size)
    batch = np.expand_dims(img_to_array(resized), axis=0)
    return preprocess_input(batch)
# Minimum aggregated probability the top class must reach before the
# automatic mode accepts a taxonomic level.
_CONFIDENCE_THRESHOLD = 0.75

# Number of ranked predictions listed under the headline result.
_TOP_K = 5


def _common_name(species):
    """Return the first ``common_name`` listed in ``taxo_df`` for ``species``."""
    return taxo_df.loc[taxo_df['species'] == species, 'common_name'].values[0]


def _render_prediction(scores, labels, level, species_details_in_rows):
    """Build the HTML result: headline prediction plus the top-ranked list.

    Args:
        scores: 1-D array of aggregated probabilities for one image.
        labels: Class labels aligned with ``scores``.
        level: Taxonomic level that ``labels`` belong to.
        species_details_in_rows: When true and ``level`` is ``"species"``,
            each ranked row shows the italic species name followed by its
            common name; otherwise rows show the plain label only.

    Returns:
        The HTML string.
    """
    at_species = level == "species"
    best_label = labels[np.argmax(scores)]

    # Common names are only shown at species level.
    if at_species:
        parts = [
            f"<h1 style='font-weight: bold;'><span style='font-style: italic;'>{best_label}</span>"
            f" ({_common_name(best_label)})</h1>"
        ]
    else:
        parts = [f"<h1 style='font-weight: bold;'>{best_label}</h1>"]

    parts.append("<h4 style='font-weight: bold; font-size: 1.2em;'>Top-5 predictions:</h4>")

    # Highest-scoring classes first.
    for i in np.argsort(scores)[-_TOP_K:][::-1]:
        name = labels[i]
        if at_species and species_details_in_rows:
            label_html = (
                f"<span style='font-style: italic;'>{name}</span>"
                f" (<span>{_common_name(name)}</span>)"
            )
        else:
            label_html = f"<span>{name}</span>"
        parts.append(
            "<div style='display: flex; justify-content: space-between;'>"
            f"{label_html}"
            f"<span style='margin-left: auto;'>{scores[i] * 100:.2f}%</span></div>"
        )

    return "".join(parts)


def make_prediction(image, taxonomic_decision, taxonomic_level):
    """Classify ``image`` and return the result as an HTML snippet.

    Args:
        image: Image to classify (the Gradio input delivers a PIL image).
        taxonomic_decision: The radio-button choice. With
            ``"Yes, I want to specify the taxonomic level"`` the prediction is
            reported at ``taxonomic_level``. With
            ``"No, I will let the model decide"`` the levels are tried from
            species upwards until the top class reaches the confidence
            threshold.
        taxonomic_level: One of ``taxonomic_levels``; ignored when the model
            decides.

    Returns:
        HTML with the headline prediction and the top-ranked classes, or an
        "Unknown animal" headline when no level reaches the threshold in
        automatic mode.

    Raises:
        ValueError: If ``taxonomic_level`` is needed but is not one of
            ``taxonomic_levels``.
    """
    img_array = load_and_preprocess_image(image)

    # Species labels in sorted order, used as the column order of the model output.
    class_names = sorted(taxo_df['species'].unique())

    prediction = model.predict(img_array)

    # Species level is index 0; any choice other than "let the model decide"
    # starts from the level picked in the drop-down.
    if taxonomic_decision == "No, I will let the model decide":
        start_index = 0
    else:
        start_index = taxonomic_levels.index(taxonomic_level)

    # User-specified level: report the best class at that level unconditionally.
    if taxonomic_decision == "Yes, I want to specify the taxonomic level":
        level = taxonomic_levels[start_index]
        aggregated, labels = aggregate_predictions(prediction, level, class_names)
        return _render_prediction(aggregated[0], labels, level, species_details_in_rows=False)

    # Automatic mode: climb to coarser levels until the top class is confident enough.
    for level in taxonomic_levels[start_index:]:
        aggregated, labels = aggregate_predictions(prediction, level, class_names)
        scores = aggregated[0]
        if scores.max() >= _CONFIDENCE_THRESHOLD:
            return _render_prediction(scores, labels, level, species_details_in_rows=True)

    # No level met the confidence threshold.
    return "<h1 style='font-weight: bold;'>Unknown animal</h1>"
# Input widgets, listed in the positional order of make_prediction's parameters.
_image_input = gr.Image(type="pil", label="Upload Image")
_decision_input = gr.Radio(
    choices=[
        "Yes, I want to specify the taxonomic level",
        "No, I will let the model decide",
    ],
    label=(
        "Do you want to specify the taxonomic resolution for predictions? "
        "If you select 'No', the 'Taxonomic level' drop-down menu will be bypassed."
    ),
    value="No, I will let the model decide",
)
_level_input = gr.Dropdown(
    choices=taxonomic_levels,
    label="Taxonomic level:",
    value="species",
)

# Gradio interface: make_prediction returns HTML, so the output component is "html".
interface = gr.Interface(
    fn=make_prediction,
    inputs=[_image_input, _decision_input, _level_input],
    outputs="html",
    title="Neotropical arboreal species classification",
    description=(
        "Upload an image and our AI will classify the animal. "
        "NOTE: it's best not to feed the whole image but just the cropped animal "
        "(in the final model this will be done automatically)."
    ),
)

# Launch the Gradio interface (no authentication is passed to launch()).
interface.launch()