# File size: 2,278 Bytes
# 3646319 a49ba8d d4b4b25 cf0b1f5 a49ba8d 3646319 a49ba8d 3646319 d4b4b25 3646319 e1b26df 3646319 d4b4b25 3646319 c40b85e 3646319 676005c 3646319 4af9a59 3646319 4af9a59 3646319 69bd373 3646319 cf0b1f5 3646319 4af9a59 cf0b1f5 3646319 cf0b1f5 3646319 4af9a59 cf0b1f5 3646319 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import numpy as np
import gradio as gr
import tensorflow as tf
import cv2
# Restore the trained MNIST digit classifier from its saved model file
model = tf.keras.models.load_model(
    "./sketch_recognition_numbers_model.h5"
)

# Class names: the entry at index i is the name of digit i (0 to 9)
labels = [
    "zero",
    "one",
    "two",
    "three",
    "four",
    "five",
    "six",
    "seven",
    "eight",
    "nine",
]
def predict(data):
    """Classify a sketched digit and return the three most likely classes.

    Args:
        data: Mapping emitted by the sketchpad component; the flattened
            drawing is read from its ``"composite"`` entry. May be ``None``,
            or carry a ``None`` composite, when there is no drawing.

    Returns:
        A dict mapping the most probable class names (up to three) to their
        scores as Python floats, ordered from most to least likely, or
        ``None`` when there is no image to classify.
    """
    # Nothing to classify: clear the output instead of crashing on a
    # missing image.
    if data is None:
        return None
    composite = data["composite"]
    if composite is None:
        return None

    img = np.asarray(composite)

    # Only look at the channel axis of 3-D arrays. On a 2-D (already
    # grayscale) image the last axis is the width, not a channel count.
    if img.ndim == 3:
        # Convert RGBA to RGB if needed
        if img.shape[-1] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
        # Convert RGB to grayscale
        if img.shape[-1] == 3:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    # Resize to 28x28 and scale pixel values by 1/255
    img = cv2.resize(img, (28, 28))
    img = img / 255.0

    # Reshape to match the model input: (1, 28, 28, 1)
    img = img.reshape(1, 28, 28, 1)

    # Scores for the single image in the batch
    preds = model.predict(img)[0]
    print(preds)

    # Indices of the three highest scores, best first
    top_3_classes = np.argsort(preds)[-3:][::-1]
    top_3_probs = preds[top_3_classes]
    class_names = [labels[i] for i in top_3_classes]
    print(class_names, top_3_probs, top_3_classes)

    # Cast to plain floats so the result holds no NumPy scalar types
    return {
        name: float(prob)
        for name, prob in zip(class_names, top_3_probs)
    }
# Page title handed to gr.Blocks(title=...)
title = "Welcome to your first sketch recognition app!"

# HTML intro rendered above the canvas: sample image plus usage instructions
head = "".join(
    (
        "<center>",
        "<img src='./mnist-classes.png' width=400>",
        "<p>The model is trained to classify numbers (from 0 to 9). ",
        "To test it, draw your number in the space provided (use the editing tools in the image editor).</p>",
        "</center>",
    )
)

# Markdown link to the full example source
ref = "Find the complete code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
# Build the page: intro text on top, then the drawing canvas and the
# prediction output in one row.
with gr.Blocks(title=title) as demo:
    # Display the description and the link to the full code
    gr.Markdown(head)
    gr.Markdown(ref)

    with gr.Row():
        # Drawing canvas; type="numpy" so predict receives array data
        im = gr.Sketchpad(type="numpy", label="Draw your digit here (use brush and eraser)")
        # Output label (top 3 predictions)
        label = gr.Label(num_top_classes=3, label="Predictions")

    # Trigger prediction whenever the image changes
    im.change(predict, inputs=im, outputs=label)

# Start the web server; share=True additionally requests a public share link
demo.launch(share=True)