|
import gradio as gr |
|
import numpy as np |
|
from PIL import Image |
|
import tensorflow as tf |
|
from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation |
|
|
|
|
|
# Hub checkpoint shared by the feature extractor and the model; judging by its
# name, a SegFormer-B4 fine-tuned on the segments/sidewalk dataset.
SEGFORMER_CHECKPOINT = "nickmuchi/segformer-b4-finetuned-segments-sidewalk"

# Converts PIL images into the TensorFlow input tensors the model consumes.
feature_extractor = SegformerFeatureExtractor.from_pretrained(SEGFORMER_CHECKPOINT)

# from_pt=True loads PyTorch weights into the TF model class (presumably the
# checkpoint is published with PyTorch weights only — confirm on the Hub).
model = TFSegformerForSemanticSegmentation.from_pretrained(
    SEGFORMER_CHECKPOINT,
    from_pt=True,
)
|
|
|
def perform_semantic_segmentation(input_img):
    """Run SegFormer on an image and return a per-pixel class-index map.

    Args:
        input_img: The image to segment, either as a NumPy array (what
            ``gr.Image`` delivers by default) or as a PIL image.

    Returns:
        A ``(pil_image, seg)`` tuple: the input as an RGB PIL image, and a 2-D
        NumPy array of predicted class indices whose height and width match
        the input image.
    """
    if not isinstance(input_img, Image.Image):
        input_img = Image.fromarray(input_img)
    # Force 3 channels: grayscale (2-D) or RGBA arrays would otherwise reach
    # the feature extractor/model with a channel count it was not built for.
    input_img = input_img.convert("RGB")

    inputs = feature_extractor(images=input_img, return_tensors="tf")
    logits = model(**inputs).logits

    # Permute [0, 2, 3, 1] moves the class axis last (NCHW -> NHWC), the
    # layout tf.image.resize operates on.
    logits = tf.transpose(logits, [0, 2, 3, 1])
    # PIL's size is (width, height); tf.image.resize wants (height, width).
    logits = tf.image.resize(logits, input_img.size[::-1])
    # Highest-scoring class per pixel; [0] drops the batch dimension.
    seg = tf.math.argmax(logits, axis=-1)[0]

    return input_img, seg.numpy()
|
|
|
def _colorize_segmentation(segmentation_map):
    """Map a 2-D array of class indices to an ``(H, W, 3)`` uint8 RGB image.

    Class ``i`` is drawn as ``(37*i, 91*i, 151*i) mod 256``. Because 37 is odd,
    the red channel alone is distinct for every index in 0..255, and class 0
    renders black.
    """
    labels = np.asarray(segmentation_map, dtype=np.int64)
    multipliers = np.array([37, 91, 151], dtype=np.int64)
    return ((labels[..., np.newaxis] * multipliers) % 256).astype(np.uint8)


def segformer_interface(input_image):
    """Gradio callback: segment ``input_image`` and return displayable images.

    Args:
        input_image: The image supplied by the Gradio input component.

    Returns:
        ``(original_image, colored_map)`` where ``colored_map`` is an
        ``(H, W, 3)`` uint8 image with one colour per predicted class.
    """
    original_image, segmentation_map = perform_semantic_segmentation(input_image)
    # The raw map holds small int64 class indices. gr.Image converts
    # non-uint8 arrays to uint8 by rescaling from the dtype's full range, so
    # the raw map would render as (near-)black; colour it per class instead.
    return original_image, _colorize_segmentation(segmentation_map)
|
|
|
|
|
demo = gr.Interface(
    fn=segformer_interface,
    inputs=gr.Image(shape=(400, 600)),
    # gr.Image's `type` must be one of "numpy", "pil" or "filepath"; the
    # callback returns a PIL image and a NumPy array respectively.
    outputs=[
        gr.Image(type="pil", label="Input image"),
        gr.Image(type="numpy", label="Segmentation"),
    ],
    examples=["side-1.jpg", "side-2.jpg", "side-3.jpg"],
    allow_flagging="never",
)

demo.launch()
|
|