"""Streamlit app: draw a digit on a canvas and classify it with a Keras CNN.

The canvas drawing is converted to a 28x28 grayscale image, scaled to
[0, 1] and passed to the pre-trained model stored in ``mnist.h5``. The
model's output vector is shown as a bar chart (presumably softmax
probabilities -- the chart's 0..1 y-axis assumes so; confirm against the
training notebook).
"""
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
from tensorflow.keras.utils import img_to_array
from streamlit_drawable_canvas import st_canvas
import streamlit as st

# Side length, in pixels, of the square images the model accepts.
IMG_SIZE = 28
# Upscaling factor so the drawing canvas is large enough to draw on.
CANVAS_SCALE = 9
# Device used for model loading and inference. tf.device() is a context
# manager; calling it without ``with`` (as this script used to) has no effect.
DEVICE = '/cpu:0'

st.write('# MNIST Digit Recognition')
st.write('## Using trained CNN `Keras` model')
st.write('To view how this model was trained go to the `Files and Versions` tab and download the `Week1.ipynb` notebook')

# Streamlit re-executes this whole script on every interaction (each stroke
# when realtime updates are on). st.cache_resource (Streamlit >= 1.18) keeps a
# single model instance across reruns instead of reloading the .h5 file each
# time; on older Streamlit releases fall back to a no-op decorator.
_cache_resource = getattr(st, 'cache_resource', lambda func: func)


@_cache_resource
def _load_model():
    """Load the pre-trained Keras model from ``mnist.h5`` on DEVICE."""
    with tf.device(DEVICE):
        return tf.keras.models.load_model('mnist.h5')


# Import Pre-trained Model
model = _load_model()

plt.rcParams.update({'font.size': 18})

# Create a sidebar to hold the settings
stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 9)
realtime_update = st.sidebar.checkbox("Update in realtime", True)

canvas_result = st_canvas(
    fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
    stroke_width=stroke_width,
    stroke_color='#FFFFFF',
    background_color='#000000',
    update_streamlit=realtime_update,
    height=IMG_SIZE * CANVAS_SCALE,
    width=IMG_SIZE * CANVAS_SCALE,
    drawing_mode='freedraw',
    key="canvas",
)

if canvas_result.image_data is not None:
    # Get image data from canvas. Pillow infers the RGBA mode from the
    # 4-channel uint8 array, so the deprecated ``mode=`` argument is not passed.
    im = ImageOps.grayscale(
        Image.fromarray(canvas_result.image_data.astype('uint8'))
    ).resize((IMG_SIZE, IMG_SIZE))

    # getbbox() is None when every pixel is 0, i.e. nothing has been drawn
    # yet; classifying an empty canvas would only show meaningless scores.
    if im.getbbox() is None:
        st.info('Draw a digit on the canvas to get a prediction.')
    else:
        # Convert image to a (1, H, W, 1) float32 batch scaled to [0, 1]
        data = img_to_array(im) / 255
        data = data.reshape(1, IMG_SIZE, IMG_SIZE, 1).astype('float32')

        # Predict digit (verbose=0 avoids a progress bar on every rerun)
        st.write('### Predicted Digit')
        with tf.device(DEVICE):
            prediction = model.predict(data, verbose=0)
        scores = prediction[0]
        digit = int(scores.argmax())
        st.write(f'**{digit}** ({float(scores[digit]):.1%})')

        # Plot prediction on an explicit figure and close it afterwards:
        # figures left open accumulate in memory across Streamlit reruns.
        fig, ax = plt.subplots(figsize=(12, 3))
        ax.bar(range(len(scores)), scores)
        ax.set_xticks(range(len(scores)))
        ax.set_xlabel('Digit')
        ax.set_ylabel('Probability')
        ax.set_title('Drawing Prediction')
        ax.set_ylim(0, 1)
        st.pyplot(fig)
        plt.close(fig)

    # Show resized image
    with st.expander('Show Resized Image'):
        st.write(
            "The image needs to be resized, because it can only input 28x28 images")
        st.image(im, caption='Resized Image', width=IMG_SIZE * CANVAS_SCALE)