# mnist / app.py
# alibayram's picture
# Refactor app.py: enhance prediction function, update title and description, and improve image handling
# 3646319
# raw
# history blame
# 2.26 kB
import numpy as np
import gradio as gr
import tensorflow as tf
import cv2
# Restore the trained Keras digit classifier once, at import time, so that
# every prediction request reuses the same in-memory model.
model = tf.keras.models.load_model(
    "./sketch_recognition_numbers_model.h5"
)

# Human-readable class names; the list position is the class index (0 to 9).
labels = [
    "zero",
    "one",
    "two",
    "three",
    "four",
    "five",
    "six",
    "seven",
    "eight",
    "nine",
]
def predict(data):
    """Classify the digit drawn in the image editor.

    Args:
        data: Value emitted by the image editor: a mapping whose
            "composite" entry holds the drawing as an array-like image
            (grayscale, RGB or RGBA). Either ``data`` itself or the
            composite may be ``None`` (e.g. after the canvas is cleared
            -- presumably; confirm against the Gradio version in use).

    Returns:
        A dict mapping the most probable class names (at most three, in
        descending order of probability) to their probabilities, or
        ``None`` when there is no image to classify.
    """
    # Nothing drawn yet / canvas cleared: report "no prediction" instead of
    # crashing inside the event handler.
    if data is None:
        return None
    composite = data["composite"]
    if composite is None:
        return None

    img = np.array(composite)
    if img.size == 0:
        return None

    # Only inspect the channel axis when there is one; a 2-D array is
    # already grayscale and its last axis is the image width.
    if img.ndim == 3:
        if img.shape[-1] == 4:  # RGBA -> RGB
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
        if img.shape[-1] == 3:  # RGB -> grayscale
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    # Bring the image to the 28x28 size used below and scale to [0, 1].
    img = cv2.resize(img, (28, 28))
    img = img / 255.0

    # Add the batch and channel axes: (1, 28, 28, 1).
    img = img.reshape(1, 28, 28, 1)

    # Scores of the single image in the batch.
    preds = model.predict(img)[0]

    # Indices of the best classes, most probable first (never more than the
    # number of outputs the model actually produced).
    top_k = min(3, len(preds))
    best = np.argsort(preds)[::-1][:top_k]

    return {labels[i]: float(preds[i]) for i in best}
# Page texts: the title handed to gr.Blocks, the HTML intro banner and a
# Markdown link to the reference implementation.
title = "Welcome to your first sketch recognition app!"

head = "".join(
    [
        "<center>",
        "<img src='./mnist-classes.png' width=400>",
        "<p>",
        "The model is trained to classify numbers (from 0 to 9). ",
        "To test it, draw your number in the space provided ",
        "(use the editing tools in the image editor).",
        "</p>",
        "</center>",
    ]
)

ref = (
    "Find the complete code "
    "[here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
)
# Assemble the interface: intro texts on top, then the drawing canvas next
# to the prediction output.
with gr.Blocks(title=title) as demo:
    # Intro banner followed by the link to the reference code.
    gr.Markdown(head)
    gr.Markdown(ref)

    with gr.Row():
        # Drawing canvas; type="numpy" asks for the image as a numpy array.
        im = gr.ImageEditor(
            label="Draw your digit here (use brush and eraser)",
            type="numpy",
        )

        # Shows the three highest-scoring classes returned by predict().
        label = gr.Label(label="Predictions", num_top_classes=3)

    # Re-run the classifier on every change of the canvas, without showing a
    # progress indicator.
    im.change(
        fn=predict,
        inputs=im,
        outputs=label,
        show_progress="hidden",
    )

# Start the server only when run as a script; share=True requests a public
# share link in addition to the local URL.
if __name__ == "__main__":
    demo.launch(share=True)