# task3 / app.py
# Committed by EUNSEO56 — commit e94dae3: "Rename aapp.py to app.py" (file size 1.48 kB)
import gradio as gr
import numpy as np
from PIL import Image
import tensorflow as tf
from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
# Initialize the Segformer model and its matching preprocessing object once,
# at import time, so every request reuses them.
_CHECKPOINT = "nickmuchi/segformer-b4-finetuned-segments-sidewalk"
feature_extractor = SegformerFeatureExtractor.from_pretrained(_CHECKPOINT)
# from_pt=True loads the checkpoint's PyTorch weights into the TensorFlow model.
model = TFSegformerForSemanticSegmentation.from_pretrained(_CHECKPOINT, from_pt=True)
def perform_semantic_segmentation(input_img):
    """Run the Segformer model on an image and return a per-pixel label map.

    Args:
        input_img: The image to segment, either a NumPy image array (e.g. the
            array delivered by the gr.Image input) or a PIL image.

    Returns:
        A tuple ``(image, label_map)`` where ``image`` is the input as an RGB
        PIL image and ``label_map`` is a NumPy integer array of shape
        (height, width) holding the arg-max class index for every pixel.
    """
    # Accept both arrays and PIL images, and force 3-channel RGB so grayscale
    # or RGBA inputs don't reach the feature extractor with the wrong number
    # of channels.
    if isinstance(input_img, Image.Image):
        image = input_img
    else:
        image = Image.fromarray(input_img)
    image = image.convert("RGB")

    # Preprocess the image and run it through the model.
    inputs = feature_extractor(images=image, return_tensors="tf")
    logits = model(**inputs).logits

    # The model's logits are channels-first (batch, num_labels, h, w); move the
    # class axis last so tf.image.resize can upsample them to the input size.
    logits = tf.transpose(logits, [0, 2, 3, 1])
    # PIL's size is (width, height); tf.image.resize wants (height, width).
    logits = tf.image.resize(logits, image.size[::-1])
    # Pick the most likely class per pixel for the single image in the batch.
    seg = tf.math.argmax(logits, axis=-1)[0]
    return image, seg.numpy()
def segformer_interface(input_image):
    """Gradio callback that segments ``input_image``.

    Delegates to ``perform_semantic_segmentation`` and hands back its two
    results unchanged: the image and its per-pixel segmentation map.
    """
    shown_image, label_map = perform_semantic_segmentation(input_image)
    return (shown_image, label_map)
# Build the Gradio demo.


def _voc_colormap(num_colors):
    """Return a deterministic (num_colors, 3) uint8 RGB palette.

    Uses the PASCAL VOC bit-interleaving scheme, so neighbouring class ids get
    visually distinct colors and index 0 maps to black.
    """
    palette = np.zeros((num_colors, 3), dtype=np.uint8)
    for label in range(num_colors):
        r = g = b = 0
        c = label
        for bit in range(8):
            r |= ((c >> 0) & 1) << (7 - bit)
            g |= ((c >> 1) & 1) << (7 - bit)
            b |= ((c >> 2) & 1) << (7 - bit)
            c >>= 3
        palette[label] = (r, g, b)
    return palette


def _colorize_segmentation(segmentation_map):
    """Turn an integer label map of shape (H, W) into an RGB uint8 image (H, W, 3).

    A raw int64 label map with small class ids renders as an almost black
    image, so each class id is mapped to a palette color for display.
    """
    labels = np.asarray(segmentation_map, dtype=np.int64)
    if labels.size == 0:
        return np.zeros(labels.shape + (3,), dtype=np.uint8)
    palette = _voc_colormap(int(labels.max()) + 1)
    return palette[labels]


def _segformer_demo_fn(input_image):
    """Gradio callback: return the original image and a color-coded label map."""
    original_image, segmentation_map = segformer_interface(input_image)
    return original_image, _colorize_segmentation(segmentation_map)


demo = gr.Interface(
    fn=_segformer_demo_fn,
    inputs=gr.Image(shape=(400, 600)),
    # Outputs: the original image (PIL) and the colorized segmentation map
    # (uint8 RGB array). gr.Image only accepts type "numpy", "pil" or "filepath".
    outputs=[gr.Image(type="pil"), gr.Image(type="numpy")],
    examples=["side-1.jpg", "side-2.jpg", "side-3.jpg"],
    allow_flagging='never',
)
demo.launch()