alibayram commited on
Commit
4c5bd22
·
1 Parent(s): 095137d

Refactor app.py: update app title and description, simplify label handling, and enhance prediction function

Browse files
Files changed (2) hide show
  1. app.py +28 -77
  2. requirements.txt +3 -3
app.py CHANGED
@@ -1,96 +1,47 @@
1
- import numpy as np
2
  import gradio as gr
3
  import tensorflow as tf
4
  import cv2
5
 
6
- # App title
7
- title = "Welcome to your first sketch recognition app!"
8
 
9
- # App description
10
  head = (
11
- "<center>"
12
- "<img src='./mnist-classes.png' width=400>"
13
- "<p>The model is trained to classify numbers (from 0 to 9). "
14
- "To test it, draw your number in the space provided.</p>"
15
- "</center>"
16
  )
17
 
18
  # GitHub repository link
19
- ref = "Find the complete code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
20
 
 
 
21
 
22
- # Class names (from 0 to 9)
23
- labels = {
24
- 0: "zero",
25
- 1: "one",
26
- 2: "two",
27
- 3: "three",
28
- 4: "four",
29
- 5: "five",
30
- 6: "six",
31
- 7: "seven",
32
- 8: "eight",
33
- 9: "nine"
34
- }
35
- # Load model (trained on MNIST dataset)
36
- model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
37
 
38
- def predict(data):
39
- # Convert to NumPy array
40
- img = np.array(data['composite'])
41
 
42
- # print non-zero values
43
- print("non-zero values", np.count_nonzero(img))
44
- for i in range(img.shape[0]):
45
- for j in range(img.shape[1]):
46
- if img[i][j] > 0:
47
- print(i, j, img[i][j])
48
 
49
- print("img.shape", img.shape)
 
 
50
 
51
- # Handle RGBA or RGB images
52
- if img.shape[-1] == 4: # RGBA
53
- img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
54
- if img.shape[-1] == 3: # RGB
55
- img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
56
 
57
- # Resize image to 28x28
58
- img = cv2.resize(img, (28, 28))
59
 
60
- # Normalize pixel values to [0, 1]
61
- img = img / 255.0
62
 
63
- # Reshape to match model input
64
- img = img.reshape(1, 28, 28, 1)
65
-
66
- print("img", img)
67
-
68
- # Model predictions
69
- preds = model.predict(img)[0]
70
-
71
- print("preds", preds)
72
- values_map = {preds[i]: i for i in range(len(preds))}
73
-
74
- sorted_values = sorted(preds, reverse=True)
75
-
76
- labels_map = dict()
77
- for i in range(3):
78
- print("sorted_values[i]", sorted_values[i], values_map[sorted_values[i]])
79
- labels_map[labels[values_map[sorted_values[i]]]] = sorted_values[i]
80
-
81
- print("labels_map", labels_map)
82
- return labels_map
83
-
84
- # Top 3 classes
85
- label = gr.Label(num_top_classes=3)
86
-
87
- # Open Gradio interface for sketch recognition
88
- interface = gr.Interface(
89
- fn=predict,
90
- inputs=gr.Sketchpad(type='numpy', image_mode='L', brush=gr.Brush()),
91
- outputs=label,
92
- title=title,
93
- description=head,
94
- article=ref
95
- )
96
  interface.launch(share=True)
 
1
+ # import dependencies
2
  import gradio as gr
3
  import tensorflow as tf
4
  import cv2
5
 
6
+ # app title
7
+ title = "Welcome on your first sketch recognition app!"
8
 
9
+ # app description
10
  head = (
11
+ "<center>"
12
+ "<img src='file/mnist-classes.png' width=400>"
13
+ "The robot was trained to classify numbers (from 0 to 9). To test it, write your number in the space provided."
14
+ "</center>"
 
15
  )
16
 
17
  # GitHub repository link
18
+ ref = "Find the whole code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
19
 
20
+ # image size: 28x28
21
+ img_size = 28
22
 
23
+ # classes name (from 0 to 9)
24
+ labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
+ # load model (trained on MNIST dataset)
27
+ model = tf.keras.models.load_model("model/sketch_recognition_numbers_model.h5")
 
28
 
29
+ # prediction function for sketch recognition
30
+ def predict(img):
 
 
 
 
31
 
32
+ # image shape: 28x28x1
33
+ img = cv2.resize(img, (img_size, img_size))
34
+ img = img.reshape(1, img_size, img_size, 1)
35
 
36
+ # model predictions
37
+ preds = model.predict(img)[0]
 
 
 
38
 
39
+ # return the probability for each classe
40
+ return {label: float(pred) for label, pred in zip(labels, preds)}
41
 
42
+ # top 3 of classes
43
+ label = gr.outputs.Label(num_top_classes=3)
44
 
45
+ # open Gradio interface for sketch recognition
46
+ interface = gr.Interface(fn=predict, inputs="sketchpad", outputs=label, title=title, description=head, article=ref)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  interface.launch(share=True)
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
- tensorflow
2
- opencv-python-headless
3
- numpy
 
1
+ gradio==3.0.10
2
+ tensorflow==2.9.1
3
+ opencv-python-headless==4.6.0.66