alibayram commited on
Commit
c40b85e
·
1 Parent(s): e1b26df

Refactor sketch recognition app: improve prediction function, enhance image processing, and update app title and description

Browse files
Files changed (1) hide show
  1. app.py +36 -26
app.py CHANGED
@@ -1,17 +1,17 @@
1
- import numpy as np
2
- import cv2
3
  import gradio as gr
4
  import tensorflow as tf
 
 
5
 
6
  # app title
7
- title = "Welcome on your first sketch recognition app!"
8
 
9
  # app description
10
  head = (
11
- "<center>"
12
- "<img src='./mnist-classes.png' width=400>"
13
- "The robot was trained to classify numbers (from 0 to 9). To test it, write your number in the space provided."
14
- "</center>"
15
  )
16
 
17
  # GitHub repository link
@@ -28,28 +28,38 @@ model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
28
 
29
  # prediction function for sketch recognition
30
  def predict(img):
31
- # Convert from PIL to NumPy
32
- img = np.array(img)
33
-
34
- # If the image is in RGB format, convert it to grayscale
35
- if len(img.shape) == 3:
36
- img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
37
-
38
- # Resize the image to 28x28
39
- img = cv2.resize(img, (img_size, img_size))
40
-
41
- # Reshape to the model's input shape (1,28,28,1)
42
- img = img.reshape(1, img_size, img_size, 1)
43
-
44
- # model predictions
45
- preds = model.predict(img)[0]
46
-
47
- # return the probability for each class
48
- return {label: float(pred) for label, pred in zip(labels, preds)}
 
 
 
49
 
50
  # top 3 of classes
51
  label = gr.Label(num_top_classes=3)
52
 
53
  # open Gradio interface for sketch recognition
54
- interface = gr.Interface(fn=predict, inputs="sketchpad", outputs=label, title=title, description=head, article=ref)
 
 
 
 
 
 
 
55
  interface.launch()
 
1
+ # import dependencies
 
2
  import gradio as gr
3
  import tensorflow as tf
4
+ import cv2
5
+ import numpy as np
6
 
7
  # app title
8
+ title = "Welcome to your first sketch recognition app!"
9
 
10
  # app description
11
  head = (
12
+ "<center>"
13
+ "The robot was trained to classify numbers (from 0 to 9). To test it, write your number in the space provided."
14
+ "</center>"
 
15
  )
16
 
17
  # GitHub repository link
 
28
 
29
  # prediction function for sketch recognition
30
  def predict(img):
31
+ if img is not None:
32
+ # Convert to numpy array if not already
33
+ img = np.array(img)
34
+
35
+ # Ensure grayscale
36
+ if len(img.shape) == 3:
37
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
38
+
39
+ # Resize to required dimensions
40
+ img = cv2.resize(img, (img_size, img_size))
41
+
42
+ # Normalize and reshape
43
+ img = img.astype('float32') / 255.0
44
+ img = img.reshape(1, img_size, img_size, 1)
45
+
46
+ # model predictions
47
+ preds = model.predict(img)[0]
48
+
49
+ # return the probability for each class
50
+ return {label: float(pred) for label, pred in zip(labels, preds)}
51
+ return None
52
 
53
  # top 3 of classes
54
  label = gr.Label(num_top_classes=3)
55
 
56
  # open Gradio interface for sketch recognition
57
+ interface = gr.Interface(
58
+ fn=predict,
59
+ inputs=gr.Sketchpad(shape=(280, 280)),
60
+ outputs=label,
61
+ title=title,
62
+ description=head,
63
+ article=ref
64
+ )
65
  interface.launch()