alibayram commited on
Commit
d4b4b25
1 Parent(s): a49ba8d

Add sketch recognition functionality with TensorFlow and OpenCV

Browse files
app.py CHANGED
@@ -1,7 +1,47 @@
 
1
  import gradio as gr
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import dependencies
2
  import gradio as gr
3
+ import tensorflow as tf
4
+ import cv2
5
 
6
+ # app title
7
+ title = "Welcome on your first sketch recognition app!"
8
 
9
+ # app description
10
+ head = (
11
+ "<center>"
12
+ "<img src='file/mnist-classes.png' width=400>"
13
+ "The robot was trained to classify numbers (from 0 to 9). To test it, write your number in the space provided."
14
+ "</center>"
15
+ )
16
+
17
+ # GitHub repository link
18
+ ref = "Find the whole code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
19
+
20
+ # image size: 28x28
21
+ img_size = 28
22
+
23
+ # classes name (from 0 to 9)
24
+ labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
25
+
26
+ # load model (trained on MNIST dataset)
27
+ model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
28
+
29
+ # prediction function for sketch recognition
30
+ def predict(img):
31
+
32
+ # image shape: 28x28x1
33
+ img = cv2.resize(img, (img_size, img_size))
34
+ img = img.reshape(1, img_size, img_size, 1)
35
+
36
+ # model predictions
37
+ preds = model.predict(img)[0]
38
+
39
+ # return the probability for each classe
40
+ return {label: float(pred) for label, pred in zip(labels, preds)}
41
+
42
+ # top 3 of classes
43
+ label = gr.outputs.Label(num_top_classes=3)
44
+
45
+ # open Gradio interface for sketch recognition
46
+ interface = gr.Interface(fn=predict, inputs="sketchpad", outputs=label, title=title, description=head, article=ref)
47
+ interface.launch(server_name="0.0.0.0", server_port=8080)
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ tensorflow
2
+ opencv-python
sketch_recognition_numbers_model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:044fd13c73edbf776ec5d7f9aa77c87e42f4c498ad02425628a96e425c6b51f2
3
+ size 1245272