import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import numpy as np
import gradio as gr
import torch
from torch import nn
from PIL import Image
from transformers import SegformerForSemanticSegmentation

model = SegformerForSemanticSegmentation.from_pretrained('s3nh/SegFormer-b0-person-segmentation')
model.eval()


def inference(image):
    # Resize to the model's 512x512 input and convert HWC uint8 -> CHW tensor.
    # Note: no normalization is applied, matching the original pipeline.
    _transform = A.Compose([
        A.Resize(height=512, width=512),
        ToTensorV2(),
    ])
    trans_image = _transform(image=np.array(image))

    with torch.no_grad():
        outputs = model(trans_image['image'].float().unsqueeze(0))
    logits = outputs.logits  # shape (1, num_labels, height/4, width/4)

    # SegFormer emits logits at 1/4 resolution; upsample to the original image
    # size, then take the per-pixel argmax over the two classes.
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],  # PIL size is (width, height); interpolate wants (height, width)
        mode='bilinear',
        align_corners=False,
    )
    seg = upsampled_logits.argmax(dim=1)[0].numpy()

    # Paint each class with its palette color: background black, person white.
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)  # (height, width, 3)
    palette = np.array([[0, 0, 0], [255, 255, 255]])
    for label, color in enumerate(palette):
        color_seg[seg == label] = color

    # Blend the mask over the original image at 50% opacity.
    img = np.array(image) * 0.5 + color_seg * 0.5
    return Image.fromarray(img.astype(np.uint8))


demo = gr.Interface(
    inference,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title='Segformer B0 - People segmentation',
    description='Segformer',
)
demo.launch()
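
# A minimal local smoke test, as a sketch: it bypasses the Gradio UI and calls
# inference() directly. 'sample.jpg' and 'overlay.png' are hypothetical paths,
# not part of the original app; uncomment and adjust to try it.
# if __name__ == '__main__':
#     overlay = inference(Image.open('sample.jpg').convert('RGB'))
#     overlay.save('overlay.png')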