chansung's picture
Update app.py
39340ac
import gradio as gr
import numpy as np
import tensorflow as tf
from huggingface_hub import from_pretrained_keras
from PIL import Image
# Hugging Face Hub id of the trained Keras segmentation model. The "@..."
# suffix presumably pins a specific revision of the repo — confirm.
MODEL_CKPT = "chansung/segmentation-training-pipeline@v1667722548"
# Loaded once at import time; every request below reuses this module-level model.
MODEL = from_pretrained_keras(MODEL_CKPT)
# Side length (in pixels) that input images are resized to before inference.
RESOLUTION = 128
# Backward-compatible alias for the original, misspelled constant name,
# which other code in this file still references.
RESOLTUION = RESOLUTION
def _load_palette(path):
    """Parse a palette file into a list of integer color lists.

    Each usable line is expected to look like ``[r, g, b]`` (presumably one
    RGB color per class label, in label order — confirm against palette.txt).
    Any line containing ``#`` is treated as a comment and skipped, as are
    blank lines.

    Args:
        path: Path of the palette text file.

    Returns:
        A list with one list of ints per non-comment, non-blank line.
    """
    palette = []
    with open(path, "r", encoding="utf-8") as fp:
        for raw_line in fp:
            # strip() removes the newline regardless of LF / CRLF endings and
            # tolerates a missing newline on the final line.
            line = raw_line.strip()
            if not line or "#" in line:
                continue
            tokens = line.strip("[]").split(",")
            # int() ignores surrounding whitespace, so ", " and "," both work;
            # empty tokens (e.g. from a trailing comma) are ignored.
            palette.append([int(token) for token in tokens if token.strip()])
    return palette


# Colors used to paint each predicted class label in the overlay.
PETS_PALETTE = _load_palette(r"./palette.txt")
def preprocess_input(image: Image.Image) -> tf.Tensor:
    """Turn an input image into a model-ready batch containing one image.

    Args:
        image: A PIL image, or any array-like accepted by ``np.array``.

    Returns:
        A float tensor of shape ``(1, RESOLTUION, RESOLTUION, C)`` whose pixel
        values are divided by 255. PIL inputs are converted to RGB first, so
        for them ``C == 3``.
    """
    # PIL images can arrive in RGBA, grayscale ("L") or palette ("P") mode,
    # which np.array turns into 4-channel or 2-D arrays. Normalize to RGB.
    # Presumably the model was trained on 3-channel RGB input — confirm.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    pixels = tf.convert_to_tensor(np.array(image))
    # tf.image.resize (default bilinear) returns float32 regardless of input dtype.
    pixels = tf.image.resize(pixels, (RESOLTUION, RESOLTUION))
    pixels = pixels / 255
    # Add the leading batch dimension.
    return tf.expand_dims(pixels, 0)
# The get_seg_overlay() utility below is adapted from:
# https://github.com/deep-diver/semantic-segmentation-ml-pipeline/blob/main/notebooks/inference_from_SavedModel.ipynb
def get_seg_overlay(image, seg, palette=None, alpha=0.5):
    """Blend a color-coded segmentation mask over an image.

    Args:
        image: Array broadcast-compatible with ``(H, W, 3)``, with pixel values
            in ``[0, 1]`` (the range produced by dividing by 255).
        seg: ``(H, W)`` array of integer class labels.
        palette: Sequence of colors indexed by label, with channel values in
            ``[0, 255]``. Defaults to ``PETS_PALETTE``.
        alpha: Weight of the mask colors in the blend. The image gets
            ``1 - alpha``. The default of 0.5 is an even mix.

    Returns:
        ``(H, W, 3)`` ``uint8`` overlay image.
    """
    if palette is None:
        palette = PETS_PALETTE
    palette = np.asarray(palette)

    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    # Pixels whose label has no palette entry stay black.
    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color

    # Bring the image into the palette's 0-255 range *before* blending.
    # Scaling the blended result by 255 instead would also multiply the
    # (already 0-255) mask colors, saturating them to 0/255.
    image_255 = np.asarray(image, dtype=np.float64) * 255
    blended = (1 - alpha) * image_255 + alpha * color_seg
    return np.clip(blended, 0, 255).astype(np.uint8)
def run_model(image: Image) -> tf.Tensor:
    """Predict a label map for *image*.

    The image is batched via ``preprocess_input`` and passed to ``MODEL``.
    The model's output is reduced to the highest-scoring class along its last
    axis, and every size-1 dimension (including the batch of one) is squeezed away.
    """
    batch = preprocess_input(image)
    scores = MODEL.predict(batch)
    return tf.squeeze(tf.math.argmax(scores, -1))
def get_predictions(image: Image):
    """Segment *image* and return the mask blended over the resized input as a PIL image."""
    mask = run_model(image)
    # Rebuild the resized, /255-scaled image the model saw, without the
    # batch axis, so it lines up with the squeezed mask.
    resized = tf.squeeze(preprocess_input(image), 0).numpy()
    overlay = get_seg_overlay(resized, mask.numpy())
    return Image.fromarray(overlay)
# Heading and intro text passed to the Gradio interface as `title` and `description`.
title = "Simple demo for a semantic segmentation model trained on the PETS dataset."
description = """
Note that the outputs obtained in this demo won't be state-of-the-art. The underlying project has a different objective focusing more on the ops side of
deploying a semantic segmentation model. For more details, check out the repository: https://github.com/deep-diver/semantic-segmentation-ml-pipeline/.
"""
# Gradio app: upload an image (received as PIL), get back the segmentation
# overlay produced by get_predictions (also a PIL image).
demo = gr.Interface(
    fn=get_predictions,
    # gr.Image replaces the deprecated gr.inputs.Image namespace.
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    allow_flagging="never",
    title=title,
    description=description,
    # Bundled example files test-image1.png ... test-image5.png.
    examples=[[f"test-image{i}.png"] for i in range(1, 6)],
)

# Launch only when run as a script (as Spaces does with `python app.py`), so the
# module can be imported, e.g. by `gradio app.py` reload mode, without starting a server.
if __name__ == "__main__":
    demo.launch()