andrewzamp commited on
Commit
1364165
1 Parent(s): b503865

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -111
app.py CHANGED
@@ -1,116 +1,16 @@
1
- # Import the libraries
2
- import numpy as np
3
- import pandas as pd
4
- from tensorflow.keras.models import load_model
5
- from tensorflow.keras.preprocessing.image import load_img, img_to_array
6
- from tensorflow.keras.applications.convnext import preprocess_input
7
- import gradio as gr
8
-
9
- # Load the model
10
- model = load_model('models/ConvNeXtBase_80_tresh_spp.tf')
11
-
12
- # Load the taxonomy .csv
13
- taxo_df = pd.read_csv('taxonomy/taxonomy_mapping.csv')
14
- taxo_df['species'] = taxo_df['species'].str.replace('_', ' ')
15
-
16
- # Available taxonomic levels
17
- taxonomic_levels = ['species', 'genus', 'family', 'order', 'class']
18
-
19
- # Function to map predicted class index to class name at the selected taxonomic level
20
- def get_class_name(predicted_class, taxonomic_level):
21
- unique_labels = sorted(taxo_df[taxonomic_level].unique())
22
- return unique_labels[predicted_class]
23
-
24
- # Function to aggregate predictions to a higher taxonomic level
25
- def aggregate_predictions(predicted_probs, taxonomic_level, class_names):
26
- unique_labels = sorted(taxo_df[taxonomic_level].unique())
27
- aggregated_predictions = np.zeros((predicted_probs.shape[0], len(unique_labels)))
28
-
29
- for idx, row in taxo_df.iterrows():
30
- species = row['species']
31
- higher_level = row[taxonomic_level]
32
-
33
- species_index = class_names.index(species) # Index of the species in the prediction array
34
- higher_level_index = unique_labels.index(higher_level)
35
-
36
- aggregated_predictions[:, higher_level_index] += predicted_probs[:, species_index]
37
-
38
- return aggregated_predictions, unique_labels
39
-
40
- # Function to load and preprocess the image
41
- def load_and_preprocess_image(image, target_size=(224, 224)):
42
- # Resize the image
43
- img_array = img_to_array(image.resize(target_size))
44
- # Expand the dimensions to match model input
45
- img_array = np.expand_dims(img_array, axis=0)
46
- # Preprocess the image
47
- img_array = preprocess_input(img_array)
48
- return img_array
49
-
50
- # Function to make predictions
51
- def make_prediction(image, taxonomic_level):
52
- # Preprocess the image
53
- img_array = load_and_preprocess_image(image)
54
-
55
- # Get the class names from the 'species' column
56
- class_names = sorted(taxo_df['species'].unique()) # Add this line to define class_names
57
-
58
- # Make a prediction
59
- prediction = model.predict(img_array)
60
-
61
- # Aggregate predictions based on the selected taxonomic level
62
- aggregated_predictions, aggregated_class_labels = aggregate_predictions(prediction, taxonomic_level, class_names)
63
-
64
- # Get the top 5 predictions
65
- top_indices = np.argsort(aggregated_predictions[0])[-5:][::-1]
66
-
67
- # Get predicted class for the top prediction
68
- predicted_class_index = np.argmax(aggregated_predictions)
69
- predicted_class_name = aggregated_class_labels[predicted_class_index]
70
-
71
- # Check if common name should be displayed (only at species level)
72
- if taxonomic_level == "species":
73
- predicted_common_name = taxo_df[taxo_df[taxonomic_level] == predicted_class_name]['common_name'].values[0]
74
- output_text = f"<h1 style='font-weight: bold;'><span style='font-style: italic;'>{predicted_class_name}</span> ({predicted_common_name})</h1>"
75
- else:
76
- output_text = f"<h1 style='font-weight: bold;'>{predicted_class_name}</h1>"
77
-
78
- # Add the top 5 predictions
79
- output_text += "<h4 style='font-weight: bold; font-size: 1.2em;'>Top 5 Predictions:</h4>"
80
-
81
- for i in top_indices:
82
- class_name = aggregated_class_labels[i]
83
-
84
- if taxonomic_level == "species":
85
- # Display common names only at species level and make it italic
86
- common_name = taxo_df[taxo_df[taxonomic_level] == class_name]['common_name'].values[0]
87
- confidence_percentage = aggregated_predictions[0][i] * 100
88
- output_text += f"<div style='display: flex; justify-content: space-between;'>" \
89
- f"<span style='font-style: italic;'>{class_name}</span>&nbsp;(<span>{common_name}</span>)" \
90
- f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>"
91
- else:
92
- # No common names at higher taxonomic levels
93
- confidence_percentage = aggregated_predictions[0][i] * 100
94
- output_text += f"<div style='display: flex; justify-content: space-between;'>" \
95
- f"<span>{class_name}</span>" \
96
- f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>"
97
-
98
- return output_text
99
-
100
- with gr.Blocks() as demo:
101
- # Define the input and output components for predictions
102
- image_input = gr.Image(type="pil", label="Upload Image") # Input type: Image (PIL format)
103
- taxonomic_level_input = gr.Dropdown(choices=taxonomic_levels, label="Taxonomic level", value="species") # Dropdown for taxonomic level
104
- output_html = gr.HTML(label="Prediction Result") # Output type: HTML for formatting
105
-
106
- # Create the prediction button
107
- predict_button = gr.Button("Make Prediction")
108
-
109
- # Define what happens when the button is clicked
110
- predict_button.click(make_prediction, inputs=[image_input, taxonomic_level_input], outputs=output_html)
111
 
112
  # Launch the Gradio interface with authentication for the specified users
113
- demo.launch(auth=[
 
114
  ("Luca Santini", "lucasantini"),
115
  ("Ana Ben铆tez L贸pez", "anaben铆tezl贸pez")
116
  ])
 
1
+ # Define the Gradio interface
2
+ interface = gr.Interface(
3
+ fn=make_prediction, # Function to be called for predictions
4
+ inputs=[gr.Image(type="pil", label="Upload Image"), # Input type: Image (PIL format)
5
+ gr.Dropdown(choices=taxonomic_levels, label="Taxonomic level", value="species")], # Dropdown for taxonomic level
6
+ outputs="html", # Output type: HTML for formatting
7
+ title="Amazon arboreal species classification",
8
+ description="Upload an image and select the taxonomic level to classify the species."
9
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  # Launch the Gradio interface with authentication for the specified users
12
+ interface.launch(auth=[
13
+ ("Andrea Zampetti", "andreazampetti"),
14
  ("Luca Santini", "lucasantini"),
15
  ("Ana Ben铆tez L贸pez", "anaben铆tezl贸pez")
16
  ])