fischjos's picture
Update app.py
e3efc57 verified
import gradio as gr
import tensorflow as tf
from PIL import Image
import numpy as np
# Load the pre-trained Pokémon model once at import time so that every
# classification request reuses the same in-memory model object.
# The path is relative, so it is resolved against the current working directory.
model_path = "pokemon_classifier_model.keras"
model = tf.keras.models.load_model(model_path)
# Define the Pokémon classes. classify_image indexes this list with the
# position of the highest model score, so the order here must match the order
# of the model's output units (presumably the training label order — TODO confirm).
classes = ['Doduo', 'Geodude', 'Zubat'] # Adjust these as per your model's classes
# Define the image classification function
def classify_image(image):
    """Classify an uploaded picture as one of the known Pokémon classes.

    Args:
        image: Pixel data as an array-like of shape (H, W), (H, W, 1),
            (H, W, 3) or (H, W, 4), or ``None`` when nothing was uploaded.
            Values are cast to ``uint8``, so they are assumed to already be
            in the 0-255 range (what Gradio's Image component is expected to
            deliver by default — confirm if the component type is changed).

    Returns:
        A human-readable string. On success it names the predicted class and
        its confidence in percent; when no image was supplied it is a short
        prompt; on any other failure it is the text of the raised exception,
        so the problem is shown in the UI instead of crashing the request.
    """
    # Gradio passes None when the user submits without selecting an image;
    # answer with a clear prompt rather than an AttributeError message.
    if image is None:
        return "Please upload an image to classify."

    try:
        pixels = np.asarray(image)

        # Bring every supported layout to three colour channels.
        if pixels.ndim == 2:
            # Grayscale: repeat the single channel three times.
            pixels = np.stack((pixels,) * 3, axis=-1)
        elif pixels.ndim == 3 and pixels.shape[2] == 1:
            # Grayscale with an explicit channel axis.
            pixels = np.repeat(pixels, 3, axis=2)
        elif pixels.ndim == 3 and pixels.shape[2] == 4:
            # RGBA: drop the alpha channel.
            pixels = pixels[:, :, :3]
        elif pixels.ndim != 3 or pixels.shape[2] != 3:
            raise ValueError(
                f"Unsupported image shape: {pixels.shape}; "
                "expected (H, W), (H, W, 1), (H, W, 3) or (H, W, 4)"
            )

        # Go through PIL for resizing. The mode is inferred from the
        # (H, W, 3) uint8 array, so no explicit mode argument is needed.
        resized = Image.fromarray(pixels.astype('uint8')).resize((150, 150))

        # 150x150 is the input size the model is assumed to expect; pixel
        # values are scaled to [0, 1] and a leading batch axis is added.
        image_array = np.expand_dims(np.array(resized) / 255.0, axis=0)

        # verbose=0 keeps Keras from printing a progress bar on every request.
        prediction = model.predict(image_array, verbose=0)

        # The batch holds exactly one image, so read that sample's scores.
        scores = prediction[0]
        predicted_class = classes[int(np.argmax(scores))]
        confidence = np.max(scores)
        return f"Predicted Pokémon: {predicted_class}, Confidence: {np.round(confidence * 100, 2)}%"
    except Exception as e:
        # Deliberate best-effort boundary: show the error text in the UI.
        return str(e)
# Build the Gradio UI: one image input wired to classify_image, whose
# returned text is shown in a label component.
input_image = gr.Image()  # image upload widget with default settings
output_label = gr.Label()  # displays the string returned by classify_image
interface = gr.Interface(
    fn=classify_image,
    inputs=input_image,
    outputs=output_label,
    # Sample images offered in the UI; paths are relative to the working directory.
    examples=[
        "pokemon/doduo.png",
        "pokemon/geodude.png",
        "pokemon/zubat.png",
    ],
    description="Upload an image of a Pokémon (Doduo, Geodude or Zubat) to classify!",
)
# Start the web server for the interface.
interface.launch()