# Streamlit app: draw a colour-coded sketch on a canvas and run it through the
# saved Keras model in `pix2pix.h5`, showing the input and the generated image.
#
# NOTE: the header is a `#` comment rather than a module docstring so that it
# can never be rendered into the page by Streamlit's "magic" output.

import numpy as np
import streamlit as st
import tensorflow as tf
import tensorflow_addons as tfa
from PIL import Image
from streamlit_drawable_canvas import st_canvas
from tensorflow.keras.utils import custom_object_scope

# Path of the saved generator passed to `model_out`.
MODEL_PATH = "pix2pix.h5"

# Canvas edge length in pixels (the canvas is square).
CANVAS_SIZE = 256

# Sidebar label -> stroke colour (hex) used when drawing on the canvas.
colors = {
    "Buildings": "#464646",
    "Tree": "#6A8E22",
    "Road": "#803F7F",
    "Car": "#00008E",
    "SideWalk": "#F71CEA",
    "SideWalk-Grass": "#9AFA98",
    "Sky": "#4184B1",
}

# Streamlit re-executes this script on every interaction, so the model must be
# cached by Streamlit itself to survive between reruns. Newer releases expose
# `st.cache_resource`; older ones expose `st.experimental_singleton`.
_cache_model = getattr(st, "cache_resource", None) or st.experimental_singleton


def create_in():
    """Return a new ``tfa.layers.InstanceNormalization`` layer."""
    return tfa.layers.InstanceNormalization()


@_cache_model
def _load_model(model_path):
    """Load the Keras model at *model_path* once and reuse it across reruns.

    The model is loaded inside a custom-object scope that maps the name
    ``InstanceNormalization`` to `create_in`.
    """
    # NOTE(review): `create_in` takes no arguments; confirm the saved layer
    # config is compatible with this factory for the model file in use.
    with custom_object_scope({"InstanceNormalization": create_in}):
        return tf.keras.models.load_model(model_path)


def model_out(model_path, img):
    """Run the model stored at *model_path* on a single image.

    Args:
        model_path: Path to the saved Keras model file.
        img: Image array with values on the 0-255 scale. It is rescaled to
            [-1, 1] and given a leading batch axis before prediction.

    Returns:
        The first (and only) element of the model's batched prediction, as a
        NumPy array.
    """
    model = _load_model(model_path)
    scaled = (img - 127.5) / 127.5  # 0..255 -> -1..1
    batch = np.expand_dims(scaled, 0)
    pred = np.asarray(model.predict(batch))
    return pred[0]


# --- Sidebar: canvas parameters -------------------------------------------
drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw", "line"))
stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 3)

# Radio button for the stroke colour; index=0 makes the first entry of
# `colors` ("Buildings") the default.
color_label = st.sidebar.radio(
    "Select Stroke Color", list(colors.keys()), index=0
)
stroke_color = colors[color_label]

bg_color = st.sidebar.color_picker("Background color hex: ", "#eee")
realtime_update = st.sidebar.checkbox("Update in realtime", True)

# --- Drawing canvas -------------------------------------------------------
canvas_result = st_canvas(
    stroke_width=stroke_width,
    stroke_color=stroke_color,
    background_color=bg_color,
    update_streamlit=realtime_update,
    height=CANVAS_SIZE,
    width=CANVAS_SIZE,
    drawing_mode=drawing_mode,
    key="canvas",
)

button = st.button("Predict", type="primary")

# --- Prediction -----------------------------------------------------------
# `st.button` returns a bool, so test its truth value: comparing against None
# is always true and would run inference on every rerun (every stroke when
# realtime update is enabled).
if button and canvas_result.image_data is not None:
    st.image(canvas_result.image_data)

    # Convert the canvas array to a plain RGB array (drops any alpha channel).
    img = np.array(canvas_result.image_data)
    img = np.array(Image.fromarray(img).convert("RGB"))

    pred = model_out(MODEL_PATH, img)
    pred = (pred + 1.0) / 2.0  # Undo tanh normalization: [-1, 1] -> [0, 1]
    # Guard the uint8 cast below: out-of-range values would otherwise wrap.
    pred = np.clip(pred, 0.0, 1.0)
    pred = Image.fromarray((pred * 255).astype(np.uint8)).convert("RGB")
    st.image(pred)