# Kyle Dampier
# MNIST GUI Example
# 5db5524
# raw
# history blame
# No virus
# 1.94 kB
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
from keras.preprocessing.image import img_to_array
from streamlit_drawable_canvas import st_canvas
import streamlit as st
# st.set_page_config(layout="wide")
st.write('# MNIST Digit Recognition')
st.write('## Using a CNN `Keras` model')


def _load_model():
    """Load the pre-trained Keras model stored in ``mnist.h5``."""
    return tf.keras.models.load_model('mnist.h5')


# Streamlit re-executes this script on every widget interaction (with realtime
# canvas updates that means every stroke), so keep a single loaded model
# instead of re-reading the file on each rerun.  ``st.cache_resource`` is the
# current caching API; ``st.experimental_singleton`` is its older name.  If
# the installed Streamlit offers neither, load the model without caching.
_cache_resource = (getattr(st, 'cache_resource', None)
                   or getattr(st, 'experimental_singleton', None))
model = (_cache_resource(_load_model) if _cache_resource else _load_model)()

# Larger default font for the prediction chart.
plt.rcParams.update({'font.size': 18})
# Sidebar controls that configure the drawing canvas.
stroke_width = st.sidebar.slider(
    "Stroke width: ", min_value=1, max_value=25, value=9)
realtime_update = st.sidebar.checkbox("Update in realtime", value=True)

# Freehand canvas: white strokes on a black 224x224 background.
canvas_options = {
    'fill_color': "rgba(255, 165, 0, 0.3)",  # fixed fill color, some opacity
    'stroke_width': stroke_width,
    'stroke_color': '#FFFFFF',
    'background_color': '#000000',
    'update_streamlit': realtime_update,
    'height': 224,
    'width': 224,
    'drawing_mode': 'freedraw',
    'key': "canvas",
}
canvas_result = st_canvas(**canvas_options)
# Once the canvas has produced pixel data: shrink it to the model's input
# size, run the classifier and chart the per-digit probabilities.
if canvas_result.image_data is not None:
    st.write('### Resized Image')
    st.write("The image needs to be resized, because it can only input 28x28 images")

    # RGBA canvas pixels -> single-channel grayscale -> 28x28.
    im = ImageOps.grayscale(
        Image.fromarray(canvas_result.image_data.astype('uint8'), mode="RGBA")
    ).resize((28, 28))
    st.image(im, width=224)

    # Scale 0-255 pixel values to [0, 1] and shape them as a one-image batch
    # of (1, 28, 28, 1) float32 values.
    data = img_to_array(im)
    data = data / 255
    data = data.reshape(1, 28, 28, 1)
    data = data.astype('float32')

    st.write('### Predicted Digit')
    prediction = model.predict(data)

    # Draw on an explicit Figure/Axes rather than pyplot's implicit "current
    # figure", which is global state shared by every session of the app.
    result, ax = plt.subplots(figsize=(12, 3))
    ax.bar(range(10), prediction[0])
    ax.set_xticks(range(10))
    ax.set_xlabel('Digit')
    ax.set_ylabel('Probability')
    ax.set_title('Drawing Prediction')
    ax.set_ylim(0, 1)
    try:
        st.pyplot(result)
    finally:
        # pyplot keeps every figure it creates alive until it is closed; the
        # script reruns on each interaction, so close it to avoid piling up
        # one figure per rerun.
        plt.close(result)