import gradio as gr
import numpy as np
import tensorflow as tf
from huggingface_hub import from_pretrained_keras
from PIL import Image

MODEL_CKPT = "chansung/segmentation-training-pipeline@v1667722548"
MODEL = from_pretrained_keras(MODEL_CKPT)

RESOLUTION = 128

# Class-to-RGB palette used to colorize the predicted segmentation masks.
# palette.txt contains one "[R, G, B]" entry per line; lines with "#" are comments.
PETS_PALETTE = []

with open(r"./palette.txt", "r") as fp:
    for line in fp:
        if "#" not in line:
            tmp_list = list(map(int, line[:-1].strip("][").split(", ")))
            PETS_PALETTE.append(tmp_list)


def preprocess_input(image: Image.Image) -> tf.Tensor:
    image = np.array(image)
    image = tf.convert_to_tensor(image)

    image = tf.image.resize(image, (RESOLUTION, RESOLUTION))
    image = image / 255

    return tf.expand_dims(image, 0)


# The utility get_seg_overlay() below is from:
# https://github.com/deep-diver/semantic-segmentation-ml-pipeline/blob/main/notebooks/inference_from_SavedModel.ipynb
def get_seg_overlay(image, seg):
    color_seg = np.zeros(
        (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
    )  # height, width, 3
    palette = np.array(PETS_PALETTE)
    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color

    # Show image + mask.
    img = np.array(image) * 0.5 + color_seg * 0.5
    img *= 255
    img = np.clip(img, 0, 255)
    img = img.astype(np.uint8)

    return img


def run_model(image: Image.Image) -> tf.Tensor:
    preprocessed_image = preprocess_input(image)
    prediction = MODEL.predict(preprocessed_image)

    # Per-pixel class prediction, then drop the batch dimension.
    seg_mask = tf.math.argmax(prediction, -1)
    seg_mask = tf.squeeze(seg_mask)
    return seg_mask


def get_predictions(image: Image.Image):
    predicted_segmentation_mask = run_model(image)

    preprocessed_image = preprocess_input(image)
    preprocessed_image = tf.squeeze(preprocessed_image, 0)

    pred_img = get_seg_overlay(
        preprocessed_image.numpy(), predicted_segmentation_mask.numpy()
    )
    return Image.fromarray(pred_img)


title = "Simple demo for a semantic segmentation model trained on the PETS dataset."

description = """
Note that the outputs obtained in this demo won't be state-of-the-art. The underlying
project has a different objective, focusing more on the ops side of deploying a
semantic segmentation model. For more details, check out the repository:
https://github.com/deep-diver/semantic-segmentation-ml-pipeline/.
"""

demo = gr.Interface(
    get_predictions,
    gr.Image(type="pil"),
    gr.Image(),
    allow_flagging="never",
    title=title,
    description=description,
    examples=[
        ["test-image1.png"],
        ["test-image2.png"],
        ["test-image3.png"],
        ["test-image4.png"],
        ["test-image5.png"],
    ],
)

demo.launch()