alibayram commited on
Commit
e1b26df
·
1 Parent(s): cd17133

Refactor sketch recognition app: simplify image preprocessing, update app description, and enhance prediction function

Browse files
Files changed (1) hide show
  1. app.py +35 -57
app.py CHANGED
@@ -1,77 +1,55 @@
1
- import os
2
  import numpy as np
3
  import cv2
4
  import gradio as gr
5
  import tensorflow as tf
6
 
7
- # Disable oneDNN optimizations for consistent results
8
- os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
9
 
10
- # App configuration
11
- title = "Welcome to your first sketch recognition app!"
12
- description = (
13
- "<center>"
14
- "<img src='mnist-classes.png' width=400>"
15
- "<p>The robot was trained to classify numbers (from 0 to 9). "
16
- "To test it, write your number in the space provided!</p>"
17
- "</center>"
18
  )
19
- article = "Find the complete code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
20
 
21
- # Image size and labels
 
 
 
22
  img_size = 28
23
- labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
24
 
25
- # Load the trained MNIST model
26
- model_path = "./sketch_recognition_numbers_model.h5"
27
- try:
28
- model = tf.keras.models.load_model(model_path)
29
- except Exception as e:
30
- raise FileNotFoundError(f"Model file '{model_path}' not found or failed to load. {str(e)}")
31
 
 
 
32
 
33
- def preprocess_image(img):
34
- """
35
- Convert PIL image to grayscale NumPy array, resize, normalize, and reshape.
36
- """
37
- # Convert PIL to NumPy array
38
  img = np.array(img)
39
-
40
- # Ensure grayscale format
41
- if len(img.shape) == 3: # Check if it's RGB/RGBA
42
  img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
43
 
44
- # Resize to 28x28
45
  img = cv2.resize(img, (img_size, img_size))
 
 
 
 
 
 
46
 
47
- # Normalize pixel values to [0, 1]
48
- img = img / 255.0
49
-
50
- # Reshape for model input
51
- return img.reshape(1, img_size, img_size, 1)
52
 
 
 
53
 
54
- def predict(img):
55
- """
56
- Predict the digit class probabilities from the input sketch image.
57
- """
58
- try:
59
- processed_img = preprocess_image(img)
60
- predictions = model.predict(processed_img)[0]
61
- return {label: float(pred) for label, pred in zip(labels, predictions)}
62
- except Exception as e:
63
- return {"error": f"Prediction failed: {str(e)}"}
64
-
65
-
66
- # Gradio interface
67
- interface = gr.Interface(
68
- fn=predict,
69
- inputs="sketchpad",
70
- outputs=gr.Label(num_top_classes=3),
71
- title=title,
72
- description=description,
73
- article=article,
74
- )
75
-
76
- # Launch the app
77
  interface.launch()
 
 
1
  import numpy as np
2
  import cv2
3
  import gradio as gr
4
  import tensorflow as tf
5
 
6
+ # app title
7
+ title = "Welcome on your first sketch recognition app!"
8
 
9
+ # app description
10
+ head = (
11
+ "<center>"
12
+ "<img src='./mnist-classes.png' width=400>"
13
+ "The robot was trained to classify numbers (from 0 to 9). To test it, write your number in the space provided."
14
+ "</center>"
 
 
15
  )
 
16
 
17
+ # GitHub repository link
18
+ ref = "Find the whole code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
19
+
20
+ # image size: 28x28
21
  img_size = 28
 
22
 
23
+ # classes name (from 0 to 9)
24
+ labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
 
 
 
 
25
 
26
+ # load model (trained on MNIST dataset)
27
+ model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
28
 
29
+ # prediction function for sketch recognition
30
+ def predict(img):
31
+ # Convert from PIL to NumPy
 
 
32
  img = np.array(img)
33
+
34
+ # If the image is in RGB format, convert it to grayscale
35
+ if len(img.shape) == 3:
36
  img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
37
 
38
+ # Resize the image to 28x28
39
  img = cv2.resize(img, (img_size, img_size))
40
+
41
+ # Reshape to the model's input shape (1,28,28,1)
42
+ img = img.reshape(1, img_size, img_size, 1)
43
+
44
+ # model predictions
45
+ preds = model.predict(img)[0]
46
 
47
+ # return the probability for each class
48
+ return {label: float(pred) for label, pred in zip(labels, preds)}
 
 
 
49
 
50
+ # top 3 of classes
51
+ label = gr.Label(num_top_classes=3)
52
 
53
+ # open Gradio interface for sketch recognition
54
+ interface = gr.Interface(fn=predict, inputs="sketchpad", outputs=label, title=title, description=head, article=ref)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  interface.launch()