Spaces:
Runtime error
Runtime error
File size: 2,089 Bytes
5db5524 e3171cd 5db5524 60d926f 5db5524 a11556c 5db5524 1a84122 5db5524 60d926f f213495 5db5524 f213495 5db5524 f213495 5db5524 f213495 5db5524 60d926f f213495 a11556c f213495 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
from tensorflow.keras.utils import img_to_array
from streamlit_drawable_canvas import st_canvas
import streamlit as st
# Optional wide layout. Streamlit requires set_page_config to be the first
# Streamlit command in the script, so it must stay above the st.write calls.
# st.set_page_config(layout="wide")

# Page header.
st.write('# MNIST Digit Recognition')
st.write('## Using trained CNN `Keras` model')
st.write('To view how this model was trained go to the `Files and Versions` tab and download the `Week1.ipynb` notebook')

# Import Pre-trained Model.
# tf.device only takes effect when used as a context manager; the previous
# bare `tf.device('/cpu:0')` call created a scope object and discarded it.
# NOTE(review): Streamlit re-runs this script on every interaction, so the
# model is reloaded from disk each time; consider caching it (e.g.
# st.cache_resource) if the installed Streamlit version provides it.
with tf.device('/cpu:0'):
    model = tf.keras.models.load_model('mnist.h5')

# Larger default font for the prediction bar chart.
plt.rcParams.update({'font.size': 18})
# Sidebar settings that control how the drawing canvas behaves.
stroke_width = st.sidebar.slider("Stroke width: ", min_value=1, max_value=25, value=9)
realtime_update = st.sidebar.checkbox("Update in realtime", value=True)

# Square drawing area, 9x the 28x28 size the image is later resized to,
# with white strokes on a black background.
_canvas_side_px = 28 * 9
canvas_result = st_canvas(
    key="canvas",
    drawing_mode='freedraw',
    width=_canvas_side_px,
    height=_canvas_side_px,
    background_color='#000000',
    stroke_color='#FFFFFF',
    stroke_width=stroke_width,
    fill_color="rgba(255, 165, 0, 0.3)",  # fixed fill colour, partially transparent
    update_streamlit=realtime_update,
)
if canvas_result.image_data is not None:
    # Turn the canvas pixels into a 28x28 grayscale PIL image. The image mode
    # is inferred from the array (assumed (H, W, 4) uint8 -> "RGBA"), so the
    # `mode=` argument, deprecated in recent Pillow releases, is not passed.
    im = ImageOps.grayscale(
        Image.fromarray(canvas_result.image_data.astype('uint8'))
    ).resize((28, 28))

    # Scale pixels to [0, 1] and add batch and channel axes -> (1, 28, 28, 1).
    data = (img_to_array(im) / 255).reshape(1, 28, 28, 1).astype('float32')

    # Predict digit. tf.device only applies inside a `with` block.
    st.write('### Predicted Digit')
    with tf.device('/cpu:0'):
        prediction = model.predict(data)

    # Bar chart of the model's per-digit outputs on a 0-1 axis. An explicit
    # Figure is used and closed after rendering; otherwise pyplot keeps one
    # open figure per Streamlit rerun and memory grows without bound.
    fig, ax = plt.subplots(figsize=(12, 3))
    ax.bar(range(10), prediction[0])
    ax.set_xticks(range(10))
    ax.set_xlabel('Digit')
    ax.set_ylabel('Probability')
    ax.set_title('Drawing Prediction')
    ax.set_ylim(0, 1)
    st.write(fig)
    plt.close(fig)

    # Show resized image
    with st.expander('Show Resized Image'):
        st.write(
            "The image needs to be resized, because it can only input 28x28 images")
        st.image(im, caption='Resized Image', width=28*9)
|