File size: 2,278 Bytes
3646319
a49ba8d
d4b4b25
cf0b1f5
a49ba8d
3646319
 
a49ba8d
3646319
 
d4b4b25
3646319
 
 
 
e1b26df
3646319
 
 
d4b4b25
3646319
 
 
c40b85e
3646319
 
 
 
 
676005c
3646319
 
 
 
 
 
4af9a59
 
3646319
 
 
 
4af9a59
3646319
 
 
 
 
 
 
 
 
 
 
 
 
 
69bd373
3646319
 
 
 
cf0b1f5
3646319
 
4af9a59
cf0b1f5
3646319
 
cf0b1f5
3646319
4af9a59
cf0b1f5
3646319
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import numpy as np
import gradio as gr
import tensorflow as tf
import cv2

# Keras classifier for hand-drawn digits, loaded once when the module is imported.
# NOTE(review): the path is resolved against the current working directory,
# not against this file's location.
_MODEL_PATH = "./sketch_recognition_numbers_model.h5"
model = tf.keras.models.load_model(_MODEL_PATH)

# Output index i of the model is reported to the user as labels[i] (digits 0-9).
labels = "zero one two three four five six seven eight nine".split()

def _to_white_on_black(img):
    """Collapse a sketch to one uint8 channel with bright strokes on a dark background.

    MNIST-trained models expect light digits on a black background, so the
    drawing is converted to that polarity whatever its channel layout.

    Args:
        img: array of shape (H, W), (H, W, 1), (H, W, 3) RGB or (H, W, 4) RGBA.

    Returns:
        A (H, W) uint8 array.
    """
    img = np.asarray(img).astype(np.uint8, copy=False)
    # Only a 3-D array has a channel axis; for a 2-D grayscale image
    # ``shape[-1]`` is the image width, not a channel count.
    channels = img.shape[2] if img.ndim == 3 else 1

    if channels == 4:
        alpha = img[..., 3]
        if alpha.min() < 255:
            # The canvas has transparent pixels, so read the drawing from the
            # alpha channel. Dropping alpha (a plain RGBA->RGB conversion)
            # would make strokes indistinguishable from transparent pixels
            # whenever both share the same RGB colour (e.g. black).
            return np.ascontiguousarray(alpha)
        gray = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    elif channels == 3:
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    elif img.ndim == 3:
        gray = img[..., 0]
    else:
        gray = img

    # A mostly-light image means dark ink on light paper: invert it.
    if gray.mean() > 127:
        gray = 255 - gray
    return gray


def _fit_to_mnist_frame(gray, size=28, box=20):
    """Crop *gray* to its strokes and centre them, aspect ratio kept, in a size x size frame.

    The digit is scaled to fit a ``box`` x ``box`` square placed in the middle
    of the frame, approximating the MNIST layout (20x20 digit in a 28x28
    image). Requires at least one non-zero pixel in *gray*.
    """
    rows = np.flatnonzero(gray.any(axis=1))
    cols = np.flatnonzero(gray.any(axis=0))
    digit = gray[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]

    # Pad the crop to a square so the resize below does not distort the digit.
    h, w = digit.shape
    side = max(h, w)
    square = np.zeros((side, side), dtype=gray.dtype)
    top, left = (side - h) // 2, (side - w) // 2
    square[top:top + h, left:left + w] = digit

    # INTER_AREA averages source pixels when shrinking, so thin strokes survive
    # the large downscale instead of aliasing away.
    inner = cv2.resize(square, (box, box), interpolation=cv2.INTER_AREA)
    frame = np.zeros((size, size), dtype=gray.dtype)
    offset = (size - box) // 2
    frame[offset:offset + box, offset:offset + box] = inner
    return frame


def predict(data):
    """Classify the digit drawn on the Sketchpad.

    Args:
        data: the Sketchpad value; a dict whose ``"composite"`` entry holds the
            flattened drawing as an image array. ``None`` (or a ``None``
            composite) is treated as "nothing to classify".

    Returns:
        ``{class_name: probability}`` for the three most likely digits, highest
        first, or ``None`` when the input is missing or the canvas is blank.
    """
    img = data.get("composite") if data is not None else None
    if img is None:
        return None

    gray = _to_white_on_black(img)
    if not gray.any():
        # Empty canvas (e.g. just cleared): don't report a meaningless guess.
        return None

    # Scale to [0, 1] and add batch and channel axes: (1, 28, 28, 1).
    x = (_fit_to_mnist_frame(gray).astype(np.float32) / 255.0).reshape(1, 28, 28, 1)

    # verbose=0 silences Keras's progress bar, which would otherwise print on
    # every stroke because this runs on each canvas change.
    probs = model.predict(x, verbose=0)[0]

    top3 = np.argsort(probs)[-3:][::-1]
    return {labels[i]: float(probs[i]) for i in top3}

# App title, passed to gr.Blocks(title=...).
title = "Welcome to your first sketch recognition app!"

# Intro HTML (rendered through gr.Markdown): an illustration plus usage hint.
# NOTE(review): a bare relative src is probably not served by Gradio; the
# image may need Gradio's file route / allowed_paths — confirm in the browser.
_intro = (
    "The model is trained to classify numbers (from 0 to 9). "
    "To test it, draw your number in the space provided "
    "(use the editing tools in the image editor)."
)
head = f"<center><img src='./mnist-classes.png' width=400><p>{_intro}</p></center>"

# Markdown link back to the full example source.
ref = "Find the complete code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."

with gr.Blocks(title=title) as demo:
    # `title` above only names the app; render it on the page as a heading too,
    # followed by the intro/illustration and the source link.
    gr.Markdown(f"# {title}")
    gr.Markdown(head)
    gr.Markdown(ref)

    with gr.Row():
        # Drawing canvas; type="numpy" means predict() receives the drawing as
        # numpy arrays (it reads data["composite"]).
        im = gr.Sketchpad(type="numpy", label="Draw your digit here (use brush and eraser)")

        # Shows the top-3 {class: probability} mapping returned by predict().
        label = gr.Label(num_top_classes=3, label="Predictions")

    # Re-classify whenever the drawing changes.
    im.change(predict, inputs=im, outputs=label)

# Launch after the Blocks context has closed, and only when run as a script,
# so importing this module does not start a server.
if __name__ == "__main__":
    # share=True additionally requests a public share link for the app.
    demo.launch(share=True)