Spaces:
Runtime error
Runtime error
File size: 2,644 Bytes
514ba19 e98c94e 514ba19 6a33499 514ba19 e98c94e 514ba19 6a33499 514ba19 39340ac 514ba19 e98c94e 514ba19 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import gradio as gr
import numpy as np
import tensorflow as tf
from huggingface_hub import from_pretrained_keras
from PIL import Image
# Side length (in pixels) that preprocess_input() resizes images to before
# inference. The misspelled name is referenced elsewhere in this file, so it
# is kept as-is.
RESOLTUION = 128

# Keras segmentation model fetched from the Hugging Face Hub at import time.
# NOTE(review): the "@v1667722548" suffix presumably pins a repository
# revision — confirm against the installed huggingface_hub version.
MODEL_CKPT = "chansung/segmentation-training-pipeline@v1667722548"
MODEL = from_pretrained_keras(MODEL_CKPT)
def _load_palette(path):
    """Read a colour palette file and return it as a list of integer lists.

    Each non-comment line is expected to look like ``[r, g, b]``; the list
    index of an entry is the class label it colours. Any line containing
    ``#`` is skipped as a comment, and blank lines are ignored.

    Args:
        path: Path of the palette text file.

    Returns:
        A list with one ``[int, int, ...]`` entry per palette line.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If an entry contains a non-integer value.
    """
    palette = []
    with open(path, "r", encoding="utf-8") as fp:
        for raw_line in fp:
            if "#" in raw_line:
                continue
            # strip() removes "\n" as well as "\r\n" line endings and stray
            # spaces; strip("][") then removes the surrounding brackets.
            entry = raw_line.strip().strip("][")
            if not entry:
                continue
            # int() tolerates surrounding whitespace, so "1,2,3" and
            # "1, 2, 3" both parse.
            palette.append([int(value) for value in entry.split(",")])
    return palette


# Class-index -> colour table used by get_seg_overlay() to paint the mask.
PETS_PALETTE = _load_palette(r"./palette.txt")
def preprocess_input(image: Image.Image) -> tf.Tensor:
    """Turn an input image into a normalised, batched model input.

    The image is resized to ``RESOLTUION`` x ``RESOLTUION``, scaled from
    [0, 255] to [0, 1] and given a leading batch axis of size 1.

    Args:
        image: Input image. PIL images in a non-RGB mode (e.g. grayscale
            "L", "RGBA", palette "P") are converted to RGB first. A 2-D
            grayscale array cannot be resized by ``tf.image.resize``, and a
            4-channel array cannot be blended with the 3-channel mask in
            ``get_seg_overlay``.

    Returns:
        A ``float32`` tensor of shape (1, RESOLTUION, RESOLTUION, C), where
        C is 3 for PIL input.
    """
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    pixels = tf.convert_to_tensor(np.array(image))
    # tf.image.resize returns float32 for the default (bilinear) method.
    resized = tf.image.resize(pixels, (RESOLTUION, RESOLTUION))
    return tf.expand_dims(resized / 255, 0)
# The get_seg_overlay() utility below is adapted from:
# https://github.com/deep-diver/semantic-segmentation-ml-pipeline/blob/main/notebooks/inference_from_SavedModel.ipynb
def get_seg_overlay(image, seg, palette=None, alpha=0.5):
    """Blend a colour-coded segmentation mask over an image.

    Args:
        image: Array-like of shape (H, W, 3) with values in [0, 1], e.g. the
            output of ``preprocess_input`` with its batch axis removed.
        seg: Integer array of shape (H, W) holding one class index per pixel.
        palette: Sequence of colours with channel values in [0, 255],
            indexed by class label. Defaults to ``PETS_PALETTE``. Pixels
            whose label has no palette entry stay black in the mask layer.
        alpha: Weight of the mask layer; the image gets ``1 - alpha``.
            The default 0.5 is an even 50/50 blend.

    Returns:
        A ``uint8`` array of shape (H, W, 3) with values in [0, 255].
    """
    if palette is None:
        palette = PETS_PALETTE
    seg = np.asarray(seg)

    # Paint each pixel with the colour of its predicted class.
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    for label, color in enumerate(np.array(palette)):
        color_seg[seg == label, :] = color

    # ``image`` is in [0, 1] while ``color_seg`` is in [0, 255], so both are
    # brought to [0, 1] before blending. Multiplying by 255 only after a raw
    # blend would push every non-zero colour channel past 255 (saturating it)
    # and drown out the underlying image.
    blended = np.asarray(image, dtype=np.float64) * (1.0 - alpha) + (
        color_seg / 255.0
    ) * alpha
    return np.clip(blended * 255.0, 0, 255).astype(np.uint8)
def run_model(image: Image) -> tf.Tensor:
    """Return the predicted class index for every pixel of ``image``.

    The image goes through ``preprocess_input`` and ``MODEL.predict``; the
    class axis (the last one) is collapsed with an arg-max, and every size-1
    dimension, including the batch axis, is squeezed away.
    """
    batch = preprocess_input(image)
    scores = MODEL.predict(batch)
    return tf.squeeze(tf.math.argmax(scores, axis=-1))
def get_predictions(image: Image):
    """Segment ``image`` and return the mask blended over it as a PIL image.

    The overlay is drawn on the resized, normalised model input produced by
    ``preprocess_input`` rather than on the original-resolution upload, so
    the returned picture has the model's input resolution.
    """
    mask = run_model(image)
    # Drop the batch axis so the image lines up with the per-pixel mask.
    model_input = tf.squeeze(preprocess_input(image), 0)
    overlay = get_seg_overlay(model_input.numpy(), mask.numpy())
    return Image.fromarray(overlay)
# Heading and explanatory text shown on the Gradio page.
title = "Simple demo for a semantic segmentation model trained on the PETS dataset."
description = (
    "\n"
    "Note that the outputs obtained in this demo won't be state-of-the-art. The underlying project has a different objective focusing more on the ops side of\n"
    "deploying a semantic segmentation model. For more details, check out the repository: https://github.com/deep-diver/semantic-segmentation-ml-pipeline/.\n"
)
# Image-in / image-out Gradio app around get_predictions().
# The ``gr.inputs`` namespace was deprecated in Gradio 3.x and removed in
# Gradio 4, so both ends use the ``gr.Image`` component directly. type="pil"
# passes/returns PIL images, and image_mode="RGB" makes uploads arrive as
# 3-channel RGB.
demo = gr.Interface(
    fn=get_predictions,
    inputs=gr.Image(type="pil", image_mode="RGB"),
    outputs=gr.Image(type="pil"),
    allow_flagging="never",
    title=title,
    description=description,
    examples=[[f"test-image{i}.png"] for i in range(1, 6)],
)
demo.launch()
|