Spaces:
Running
on
A10G
Running
on
A10G
File size: 1,761 Bytes
dd0ab9f 0819ae2 dd0ab9f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import logging
from typing import List, Tuple, Dict
import streamlit as st
import torch
import gc
import numpy as np
from PIL import Image
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
from palette import ade_palette
# Module-level logger, named after this module via __name__.
LOGGING = logging.getLogger(__name__)
def flush():
    """Method to release memory that is no longer needed

    Runs a Python garbage-collection pass and then asks PyTorch to empty
    its CUDA cache. Nothing is returned.
    """
    # Python-side collection runs first, followed by the CUDA cache release.
    gc.collect()
    torch.cuda.empty_cache()
@st.cache_resource(max_entries=5)
def get_segmentation_pipeline() -> Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]:
    """Method to build the segmentation pipeline

    The image processor and the segmentation model are both loaded from the
    same pretrained checkpoint. The function is wrapped in Streamlit's
    ``st.cache_resource`` decorator with ``max_entries=5``.

    Returns:
        Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]: the image
            processor followed by the segmentation model
    """
    # Single source for the checkpoint id shared by processor and model.
    checkpoint = "openmmlab/upernet-convnext-small"
    processor = AutoImageProcessor.from_pretrained(checkpoint)
    segmentor = UperNetForSemanticSegmentation.from_pretrained(checkpoint)
    return processor, segmentor
@torch.inference_mode()
@torch.autocast('cuda')
def segment_image(image: Image.Image) -> Image.Image:
    """Method to segment image

    Args:
        image (Image.Image): input image; converted to RGB first when it is
            in any other mode

    Returns:
        Image.Image: RGB image of the same size as the input in which every
            pixel carries the palette color of its predicted label (pixels
            whose label has no palette entry stay black)
    """
    # NOTE(review): nothing in this function moves the model or its inputs to
    # a CUDA device, so the CUDA autocast decorator may have no effect here
    # - confirm against how the pipeline is loaded.
    image_processor, image_segmentor = get_segmentation_pipeline()

    # Modes such as RGBA or L are converted up front so the processor always
    # receives a three-channel image.
    if image.mode != "RGB":
        image = image.convert("RGB")

    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    # No extra no_grad() block: the inference_mode decorator already disables
    # gradient tracking for the whole call.
    outputs = image_segmentor(pixel_values)

    # PIL reports size as (width, height); reversed to (height, width) here.
    seg = image_processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]])[0]
    labels = np.asarray(seg)

    palette = np.array(ade_palette())
    color_seg = np.zeros((*labels.shape, 3), dtype=np.uint8)
    # One table lookup replaces a full-image comparison per palette entry.
    # Labels without a palette entry are masked out and therefore stay black.
    known = (labels >= 0) & (labels < len(palette))
    color_seg[known] = palette[labels[known]]

    # An (H, W, 3) uint8 array is already decoded by PIL as an RGB image.
    return Image.fromarray(color_seg)