# Spaces:
# Runtime error
# Runtime error
import tensorflow as tf | |
import matplotlib.pyplot as plt | |
from PIL import Image, ImageOps | |
from tensorflow.keras.utils import img_to_array | |
from streamlit_drawable_canvas import st_canvas | |
import streamlit as st | |
# st.set_page_config(layout="wide")
st.write('# MNIST Digit Recognition')
st.write('## Using trained CNN `Keras` model')
st.write('To view how this model was trained go to the `Files and Versions` tab and download the `Week1.ipynb` notebook')


# Import Pre-trained Model
@st.cache_resource
def _load_model(path='mnist.h5'):
    """Load the pre-trained Keras model stored at *path*.

    Streamlit re-executes this whole script on every widget interaction,
    so the model is cached as a shared resource: the file is read from disk
    once per server process instead of on every rerun.
    """
    return tf.keras.models.load_model(path)


model = _load_model()
plt.rcParams.update({'font.size': 18})
# Sidebar settings: pen thickness, and whether every stroke is sent back to
# the app immediately (triggering a rerun) or only on demand.
stroke_width = st.sidebar.slider(
    "Stroke width: ", min_value=1, max_value=25, value=9
)
realtime_update = st.sidebar.checkbox("Update in realtime", value=True)

# Square drawing surface, 9x the 28x28 model input so it is comfortable to
# draw on. White strokes on a black background, presumably to resemble the
# MNIST training images. fill_color only affects filled shapes, not freedraw.
_canvas_side = 28 * 9
canvas_result = st_canvas(
    key="canvas",
    drawing_mode='freedraw',
    width=_canvas_side,
    height=_canvas_side,
    background_color='#000000',
    stroke_color='#FFFFFF',
    stroke_width=stroke_width,
    update_streamlit=realtime_update,
    fill_color="rgba(255, 165, 0, 0.3)",
)
if canvas_result.image_data is not None:
    # Decode the canvas pixels as RGBA, collapse them to a single grayscale
    # channel and shrink to 28x28, the size the reshape below feeds to the
    # model.
    im = ImageOps.grayscale(
        Image.fromarray(canvas_result.image_data.astype('uint8'), mode="RGBA")
    ).resize((28, 28))
    # Scale pixel values from [0, 255] to [0, 1] and add the batch and
    # channel axes: a single-sample float32 batch of shape (1, 28, 28, 1).
    data = (img_to_array(im) / 255).reshape(1, 28, 28, 1).astype('float32')

    st.write('### Predicted Digit')
    # tf.device() only pins ops when used as a context manager, so the CPU
    # placement has to wrap the call itself. verbose=0 keeps Keras from
    # printing a progress bar to the server log on every Streamlit rerun.
    with tf.device('/cpu:0'):
        prediction = model.predict(data, verbose=0)
    # Report the most likely class under the heading; the chart below shows
    # the full per-class output.
    st.write(f'**{int(prediction[0].argmax())}**')

    # Plot the per-class model output (assumed to be softmax probabilities,
    # hence the 0-1 y-axis -- TODO confirm against the model's last layer).
    result = plt.figure(figsize=(12, 3))
    plt.bar(range(10), prediction[0])
    plt.xticks(range(10))
    plt.xlabel('Digit')
    plt.ylabel('Probability')
    plt.title('Drawing Prediction')
    plt.ylim(0, 1)
    st.pyplot(result)
    # Streamlit re-executes the script on every interaction; close the figure
    # so pyplot does not keep one open figure per rerun (unbounded memory
    # growth and matplotlib's "More than 20 figures" warning).
    plt.close(result)

    # Show resized image
    with st.expander('Show Resized Image'):
        st.write(
            "The image needs to be resized, because it can only input 28x28 images")
        st.image(im, caption='Resized Image', width=28*9)