Kyle Dampier commited on
Commit
5db5524
1 Parent(s): f0030f9

MNIST GUI Example

Browse files
Files changed (4) hide show
  1. app.py +63 -0
  2. mnist.h5 +3 -0
  3. requirements.txt +3 -0
  4. user_input.png +0 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import matplotlib.pyplot as plt
3
+ from PIL import Image, ImageOps
4
+ from keras.preprocessing.image import img_to_array
5
+
6
+ from streamlit_drawable_canvas import st_canvas
7
+ import streamlit as st
8
+
9
+ # st.set_page_config(layout="wide")
10
+
11
+ st.write('# MNIST Digit Recognition')
12
+ st.write('## Using a CNN `Keras` model')
13
+
14
+ # Import Pre-trained Model
15
+ model = tf.keras.models.load_model('mnist.h5')
16
+ plt.rcParams.update({'font.size': 18})
17
+
18
+ # Create a sidebar to hold the settings
19
+ stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 9)
20
+ realtime_update = st.sidebar.checkbox("Update in realtime", True)
21
+
22
+
23
+ canvas_result = st_canvas(
24
+ fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
25
+ stroke_width=stroke_width,
26
+ stroke_color='#FFFFFF',
27
+ background_color='#000000',
28
+ #background_image=Image.open(bg_image) if bg_image else None,
29
+ update_streamlit=realtime_update,
30
+ height=224,
31
+ width=224,
32
+ drawing_mode='freedraw',
33
+ key="canvas",
34
+ )
35
+
36
+ if canvas_result.image_data is not None:
37
+ st.write('### Resized Image')
38
+ st.write("The image needs to be resized, because it can only input 28x28 images")
39
+ # st.image(canvas_result.image_data)
40
+ # st.write(type(canvas_result.image_data))
41
+ # st.write(canvas_result.image_data.shape)
42
+ # st.write(canvas_result.image_data)
43
+ im = ImageOps.grayscale(Image.fromarray(canvas_result.image_data.astype(
44
+ 'uint8'), mode="RGBA")).resize((28, 28))
45
+ # img_data = im.
46
+ st.image(im, width=224)
47
+
48
+ data = img_to_array(im)
49
+ data = data / 255
50
+ data = data.reshape(1, 28, 28, 1)
51
+ data = data.astype('float32')
52
+
53
+ st.write('### Predicted Digit')
54
+ prediction = model.predict(data)
55
+
56
+ result = plt.figure(figsize=(12, 3))
57
+ plt.bar(range(10), prediction[0])
58
+ plt.xticks(range(10))
59
+ plt.xlabel('Digit')
60
+ plt.ylabel('Probability')
61
+ plt.title('Drawing Prediction')
62
+ plt.ylim(0, 1)
63
+ st.write(result)
mnist.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f6fb1eb48a18fd769f17f093224aa1246b41d132a56a4cbe28e0b73382bd7e28
3
+ size 455304
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ pandas
2
+ tensorflow
3
+ streamlit-drawable-canvas
user_input.png ADDED