from PIL import Image, ImageFilter
import numpy as np
from transformers import pipeline
import gradio as gr
import os
import functools

# Hugging Face checkpoints offered in the model dropdown. The dropdown uses
# type="index", so callbacks receive a position in this list.
# NOTE(review): "CIDAS/clipseg-rd64-refined" is a text-prompted CLIPSeg model;
# confirm the generic "image-segmentation" pipeline actually supports it.
models = [
    "facebook/detr-resnet-50-panoptic",
    "CIDAS/clipseg-rd64-refined",
    "facebook/maskformer-swin-large-ade",
    "nvidia/segformer-b1-finetuned-cityscapes-1024-1024",
]
current_model = models[0]

# Raw pipeline output from the most recent image_objects() call.
pred = []


@functools.lru_cache(maxsize=len(models))
def _get_segmenter(model_id):
    """Build an image-segmentation pipeline for *model_id*.

    Cached so that each model's weights are loaded at most once, instead of
    on every image change. The cache size is bounded by the number of
    models offered.
    """
    return pipeline("image-segmentation", model=model_id)


def _object_labels(segments):
    """Return "<index>_<label>" strings for each segment in *segments*."""
    return [f"{i}_{seg['label']}" for i, seg in enumerate(segments)]


def _apply_mask(image_array, mask):
    """Return *image_array* (H x W x 3, uint8) with pixels outside *mask* zeroed.

    *mask* is a PIL image. Pixels brighter than mid-grey count as "inside".
    It is resized with nearest-neighbour sampling if its size differs from
    the image.
    """
    size = (image_array.shape[1], image_array.shape[0])
    if mask.size != size:
        mask = mask.resize(size, Image.NEAREST)
    keep = (np.asarray(mask.convert("L")) > 127).astype(np.uint8)
    # Broadcast the (H, W) 0/1 mask across the colour channels.
    return image_array * keep[:, :, None]


def img_resize(image, width=1280):
    """Return *image* scaled to *width* pixels wide, preserving aspect ratio.

    Narrower images are scaled up and wider ones are scaled down.
    """
    scale = width / float(image.size[0])
    height = max(1, int(image.size[1] * scale))
    return image.resize((width, height))


def image_objects(image):
    """Segment *image* with the default model and list the detected objects.

    Stores the raw pipeline output in the module-level ``pred``. Returns a
    dropdown update whose choices are "<index>_<label>" strings.
    """
    global pred
    image = img_resize(image)
    # Use the cached pipeline for the default model (current_model).
    pred = _get_segmenter(current_model)(image)
    return gr.Dropdown.update(choices=_object_labels(pred), interactive=True)


def get_seg(image, model_choice):
    """Segment *image* with the model at index *model_choice* in ``models``.

    Returns a pair:
    - a list of uint8 RGB arrays, one per detected segment, with all pixels
      outside that segment blacked out;
    - a dropdown update listing the segments as "<index>_<label>".
    """
    if image is None:
        # Fired when the input image is cleared, so reset both outputs.
        return [], gr.Dropdown.update(choices=[], value=None, interactive=True)
    if model_choice is None:
        model_choice = 0
    # Convert to RGB so grayscale/RGBA inputs yield exactly three channels.
    image = img_resize(image).convert("RGB")
    segments = _get_segmenter(models[model_choice])(image)
    image_array = np.asarray(image)
    seg_box = [_apply_mask(image_array, seg["mask"]) for seg in segments]
    return seg_box, gr.Dropdown.update(
        choices=_object_labels(segments), interactive=True
    )


app = gr.Blocks()
with app:
    gr.Markdown("## Image Dissector")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Input Image", type="pil")
            model_name = gr.Dropdown(
                show_label=False,
                choices=list(models),
                type="index",
                value=current_model,
                interactive=True,
            )
        with gr.Column():
            gal1 = gr.Gallery(type="filepath").style(grid=6)
    with gr.Row():
        with gr.Column():
            object_output = gr.Dropdown(label="Objects")
    image_input.change(
        get_seg, inputs=[image_input, model_name], outputs=[gal1, object_output]
    )
    # Re-run segmentation when a different model is picked.
    model_name.change(
        get_seg, inputs=[image_input, model_name], outputs=[gal1, object_output]
    )

if __name__ == "__main__":
    app.launch()