# Kyle Dampier
# tried changing tensorflow device to cpu
# a11556c
# raw
# history blame
# No virus
# 2.09 kB
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
from tensorflow.keras.utils import img_to_array
from streamlit_drawable_canvas import st_canvas
import streamlit as st
# st.set_page_config(layout="wide")
st.write('# MNIST Digit Recognition')
st.write('## Using trained CNN `Keras` model')
st.write('To view how this model was trained go to the `Files and Versions` tab and download the `Week1.ipynb` notebook')

# Streamlit re-executes this whole script on every widget interaction, so the
# model load is cached with st.cache_resource where it exists (Streamlit
# >= 1.18). On older Streamlit versions this falls back to an identity
# decorator, i.e. the model is loaded on every run exactly as before.
_cache_model = getattr(st, 'cache_resource', lambda func: func)


@_cache_model
def _load_model(path='mnist.h5', device='/cpu:0'):
    """Load the pre-trained Keras model stored at ``path`` under ``device``.

    ``tf.device`` only takes effect when used as a context manager; a bare
    ``tf.device(...)`` call just builds and discards the context manager,
    so the load is wrapped in a ``with`` block here.
    """
    with tf.device(device):
        return tf.keras.models.load_model(path)


# Import Pre-trained Model
model = _load_model()
plt.rcParams.update({'font.size': 18})
# Sidebar widgets that configure the drawing canvas below.
stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 9)
realtime_update = st.sidebar.checkbox("Update in realtime", True)

# Square freehand canvas: white strokes on a black background, 9x the 28x28
# size that the drawing is later resized to before prediction.
canvas_side = 28 * 9
canvas_result = st_canvas(
    key="canvas",
    drawing_mode='freedraw',
    height=canvas_side,
    width=canvas_side,
    update_streamlit=realtime_update,
    background_color='#000000',
    stroke_color='#FFFFFF',
    stroke_width=stroke_width,
    fill_color="rgba(255, 165, 0, 0.3)",  # constant semi-transparent fill colour
)
if canvas_result.image_data is not None:
    # Get image data from canvas: the RGBA drawing is converted to a single
    # grayscale channel and downsampled to the 28x28 input size.
    im = ImageOps.grayscale(
        Image.fromarray(canvas_result.image_data.astype('uint8'), mode="RGBA")
    ).resize((28, 28))

    # Convert image to a (1, 28, 28, 1) float32 batch with pixel values
    # scaled from 0..255 down to 0..1.
    data = img_to_array(im)
    data = data / 255
    data = data.reshape(1, 28, 28, 1)
    data = data.astype('float32')

    # Predict digit. `tf.device` only takes effect as a context manager, so
    # the CPU placement is applied around the actual inference call.
    # verbose=0 keeps Keras from printing a progress bar on every rerun.
    st.write('### Predicted Digit')
    with tf.device('/cpu:0'):
        prediction = model.predict(data, verbose=0)

    # Plot prediction[0] as one bar per digit 0-9 (assumes the model emits
    # 10 class scores, as the original plot did). The figure is closed after
    # rendering: pyplot keeps every figure it creates until it is closed,
    # and this code runs again on each Streamlit rerun.
    result, ax = plt.subplots(figsize=(12, 3))
    ax.bar(range(10), prediction[0])
    ax.set_xticks(range(10))
    ax.set_xlabel('Digit')
    ax.set_ylabel('Probability')
    ax.set_title('Drawing Prediction')
    ax.set_ylim(0, 1)
    st.write(result)
    plt.close(result)

    # Show resized image
    with st.expander('Show Resized Image'):
        st.write(
            "The image needs to be resized, because it can only input 28x28 images")
        st.image(im, caption='Resized Image', width=28*9)