Spaces:
Running
Running
mask2former app setup
Browse files- app.py +42 -0
- color_palette.py +89 -0
- examples/armchair.jpg +0 -0
- examples/cat-dog.jpg +0 -0
- examples/person-bike.jpg +0 -0
- predict.py +101 -0
app.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from predict import predict_masks
|
3 |
+
import glob
|
4 |
+
|
5 |
+
##Create list of examples to be loaded
|
6 |
+
example_list = glob.glob("examples/*")
|
7 |
+
example_list = list(map(lambda el:[el], example_list))
|
8 |
+
|
9 |
+
demo = gr.Blocks()
|
10 |
+
|
11 |
+
with demo:
|
12 |
+
|
13 |
+
gr.Markdown("# **<p align='center'>Mask2Former: Masked Attention Transformer for Universal Segmentation</p>**")
|
14 |
+
gr.Markdown("This space demonstrates the use of Mask2Former. It was introduced in the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) and first released in [this repository](https://github.com/facebookresearch/Mask2Former/). \
|
15 |
+
Before Mask2Former, you'd have to resort to using a specialized architecture designed for solving a particular kind of image segmentation task (i.e. semantic, instance or panoptic segmentation). On the other hand, in the form of Mask2Former, for the first time, we have a single architecture that is capable of solving any segmentation task and performs on par or better than specialized architectures.")
|
16 |
+
|
17 |
+
with gr.Box():
|
18 |
+
|
19 |
+
with gr.Row():
|
20 |
+
segmentation_task = gr.Dropdown(["semantic", "panoptic"], value="panoptic", label="Segmentation Task", show_label=True)
|
21 |
+
with gr.Box():
|
22 |
+
with gr.Row():
|
23 |
+
input_image = gr.Image(type='filepath',label="Input Image", show_label=True)
|
24 |
+
output_mask = gr.Image(label="Predicted Masks", show_label=True)
|
25 |
+
|
26 |
+
gr.Markdown("**Predict**")
|
27 |
+
|
28 |
+
with gr.Box():
|
29 |
+
with gr.Row():
|
30 |
+
submit_button = gr.Button("Submit")
|
31 |
+
|
32 |
+
gr.Markdown("**Examples:**")
|
33 |
+
|
34 |
+
with gr.Column():
|
35 |
+
gr.Examples(example_list, [input_image, segmentation_task], output_mask, predict_masks)
|
36 |
+
|
37 |
+
|
38 |
+
submit_button.click(predict_masks, inputs=[input_image, segmentation_task], outputs=output_mask)
|
39 |
+
|
40 |
+
gr.Markdown('\n Demo created by: <a href=\"https://www.linkedin.com/in/shivalika-singh/\">Shivalika Singh</a>')
|
41 |
+
|
42 |
+
demo.launch(debug=True)
|
color_palette.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# color palattes for COCO, cityscapes and ADE datasets
|
2 |
+
|
3 |
+
def coco_panoptic_palette():
|
4 |
+
return [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230),
|
5 |
+
(106, 0, 228), (0, 60, 100), (0, 80, 100), (0, 0, 70),
|
6 |
+
(0, 0, 192), (250, 170, 30), (100, 170, 30), (220, 220, 0),
|
7 |
+
(175, 116, 175), (250, 0, 30), (165, 42, 42), (255, 77, 255),
|
8 |
+
(0, 226, 252), (182, 182, 255), (0, 82, 0), (120, 166, 157),
|
9 |
+
(110, 76, 0), (174, 57, 255), (199, 100, 0), (72, 0, 118),
|
10 |
+
(255, 179, 240), (0, 125, 92), (209, 0, 151), (188, 208, 182),
|
11 |
+
(0, 220, 176), (255, 99, 164), (92, 0, 73), (133, 129, 255),
|
12 |
+
(78, 180, 255), (0, 228, 0), (174, 255, 243), (45, 89, 255),
|
13 |
+
(134, 134, 103), (145, 148, 174), (255, 208, 186),
|
14 |
+
(197, 226, 255), (171, 134, 1), (109, 63, 54), (207, 138, 255),
|
15 |
+
(151, 0, 95), (9, 80, 61), (84, 105, 51), (74, 65, 105),
|
16 |
+
(166, 196, 102), (208, 195, 210), (255, 109, 65), (0, 143, 149),
|
17 |
+
(179, 0, 194), (209, 99, 106), (5, 121, 0), (227, 255, 205),
|
18 |
+
(147, 186, 208), (153, 69, 1), (3, 95, 161), (163, 255, 0),
|
19 |
+
(119, 0, 170), (0, 182, 199), (0, 165, 120), (183, 130, 88),
|
20 |
+
(95, 32, 0), (130, 114, 135), (110, 129, 133), (166, 74, 118),
|
21 |
+
(219, 142, 185), (79, 210, 114), (178, 90, 62), (65, 70, 15),
|
22 |
+
(127, 167, 115), (59, 105, 106), (142, 108, 45), (196, 172, 0),
|
23 |
+
(95, 54, 80), (128, 76, 255), (201, 57, 1), (246, 0, 122),
|
24 |
+
(191, 162, 208), (255, 255, 128), (147, 211, 203),
|
25 |
+
(150, 100, 100), (168, 171, 172), (146, 112, 198),
|
26 |
+
(210, 170, 100), (92, 136, 89), (218, 88, 184), (241, 129, 0),
|
27 |
+
(217, 17, 255), (124, 74, 181), (70, 70, 70), (255, 228, 255),
|
28 |
+
(154, 208, 0), (193, 0, 92), (76, 91, 113), (255, 180, 195),
|
29 |
+
(106, 154, 176),
|
30 |
+
(230, 150, 140), (60, 143, 255), (128, 64, 128), (92, 82, 55),
|
31 |
+
(254, 212, 124), (73, 77, 174), (255, 160, 98), (255, 255, 255),
|
32 |
+
(104, 84, 109), (169, 164, 131), (225, 199, 255), (137, 54, 74),
|
33 |
+
(135, 158, 223), (7, 246, 231), (107, 255, 200), (58, 41, 149),
|
34 |
+
(183, 121, 142), (255, 73, 97), (107, 142, 35), (190, 153, 153),
|
35 |
+
(146, 139, 141),
|
36 |
+
(70, 130, 180), (134, 199, 156), (209, 226, 140), (96, 36, 108),
|
37 |
+
(96, 96, 96), (64, 170, 64), (152, 251, 152), (208, 229, 228),
|
38 |
+
(206, 186, 171), (152, 161, 64), (116, 112, 0), (0, 114, 143),
|
39 |
+
(102, 102, 156), (250, 141, 255)]
|
40 |
+
|
41 |
+
def cityscapes_palette():
|
42 |
+
return [[128, 64, 128],[244, 35, 232],[70, 70, 70],[102, 102, 156],[190, 153, 153],
|
43 |
+
[153, 153, 153],[250, 170, 30],[220, 220, 0],[107, 142, 35],[152, 251, 152],
|
44 |
+
[70, 130, 180], [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70],
|
45 |
+
[0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32]]
|
46 |
+
|
47 |
+
def ade_palette():
|
48 |
+
"""Color palette that maps each class to RGB values.
|
49 |
+
|
50 |
+
This one is actually taken from ADE20k.
|
51 |
+
"""
|
52 |
+
return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
|
53 |
+
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
|
54 |
+
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
|
55 |
+
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
|
56 |
+
[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
|
57 |
+
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
|
58 |
+
[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
|
59 |
+
[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
|
60 |
+
[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
|
61 |
+
[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
|
62 |
+
[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
|
63 |
+
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
|
64 |
+
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
|
65 |
+
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
|
66 |
+
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
|
67 |
+
[11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
|
68 |
+
[0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
|
69 |
+
[255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
|
70 |
+
[0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
|
71 |
+
[173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
|
72 |
+
[255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
|
73 |
+
[255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
|
74 |
+
[255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
|
75 |
+
[0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
|
76 |
+
[0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
|
77 |
+
[143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
|
78 |
+
[8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
|
79 |
+
[255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
|
80 |
+
[92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
|
81 |
+
[163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
|
82 |
+
[255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
|
83 |
+
[255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
|
84 |
+
[10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
|
85 |
+
[255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
|
86 |
+
[41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
|
87 |
+
[71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
|
88 |
+
[184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
|
89 |
+
[102, 255, 0], [92, 0, 255]]
|
examples/armchair.jpg
ADDED
examples/cat-dog.jpg
ADDED
examples/person-bike.jpg
ADDED
predict.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import random
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
from collections import defaultdict
|
6 |
+
from detectron2.data import MetadataCatalog
|
7 |
+
from detectron2.utils.visualizer import ColorMode, Visualizer
|
8 |
+
from color_palette import ade_palette
|
9 |
+
from transformers import MaskFormerImageProcessor, Mask2FormerForUniversalSegmentation
|
10 |
+
|
11 |
+
def load_model_and_processor(model_ckpt: str):
|
12 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
13 |
+
model = Mask2FormerForUniversalSegmentation.from_pretrained(model_ckpt).to(torch.device(device))
|
14 |
+
model.eval()
|
15 |
+
image_preprocessor = MaskFormerImageProcessor.from_pretrained(model_ckpt)
|
16 |
+
return model, image_preprocessor
|
17 |
+
|
18 |
+
def load_default_ckpt(segmentation_task: str):
|
19 |
+
if segmentation_task == "semantic":
|
20 |
+
default_pretrained_ckpt = "facebook/mask2former-swin-tiny-ade-semantic"
|
21 |
+
elif segmentation_task == "instance":
|
22 |
+
default_pretrained_ckpt = "facebook/mask2former-swin-small-coco-instance"
|
23 |
+
else:
|
24 |
+
default_pretrained_ckpt = "facebook/mask2former-swin-tiny-coco-panoptic"
|
25 |
+
return default_pretrained_ckpt
|
26 |
+
|
27 |
+
def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
|
28 |
+
metadata = MetadataCatalog.get("coco_2017_val_panoptic")
|
29 |
+
for res in seg_info:
|
30 |
+
res['category_id'] = res.pop('label_id')
|
31 |
+
pred_class = res['category_id']
|
32 |
+
isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
|
33 |
+
res['isthing'] = bool(isthing)
|
34 |
+
|
35 |
+
visualizer = Visualizer(np.array(image)[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
|
36 |
+
out = visualizer.draw_panoptic_seg_predictions(
|
37 |
+
predicted_segmentation_map.cpu(), seg_info, alpha=0.5
|
38 |
+
)
|
39 |
+
output_img = Image.fromarray(out.get_image())
|
40 |
+
return output_img
|
41 |
+
|
42 |
+
def draw_semantic_segmentation(segmentation_map, image, palette):
|
43 |
+
color_segmentation_map = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8) # height, width, 3
|
44 |
+
for label, color in enumerate(palette):
|
45 |
+
color_segmentation_map[segmentation_map - 1 == label, :] = color
|
46 |
+
# Convert to BGR
|
47 |
+
ground_truth_color_seg = color_segmentation_map[..., ::-1]
|
48 |
+
|
49 |
+
img = np.array(image) * 0.5 + ground_truth_color_seg * 0.5
|
50 |
+
img = img.astype(np.uint8)
|
51 |
+
return img
|
52 |
+
|
53 |
+
def visualize_instance_seg_mask(mask):
|
54 |
+
image = np.zeros((mask.shape[0], mask.shape[1], 3))
|
55 |
+
labels = np.unique(mask)
|
56 |
+
label2color = {label: (random.randint(0, 1), random.randint(0, 255), random.randint(0, 255)) for label in labels}
|
57 |
+
for i in range(image.shape[0]):
|
58 |
+
for j in range(image.shape[1]):
|
59 |
+
image[i, j, :] = label2color[mask[i, j]]
|
60 |
+
image = image / 255
|
61 |
+
return image
|
62 |
+
|
63 |
+
def predict_masks(input_img_path: str, segmentation_task: str):
|
64 |
+
|
65 |
+
#load model and image processor
|
66 |
+
default_pretrained_ckpt = load_default_ckpt(segmentation_task)
|
67 |
+
model, image_processor = load_model_and_processor(default_pretrained_ckpt)
|
68 |
+
|
69 |
+
## pass input image through image processor
|
70 |
+
image = Image.open(input_img_path)
|
71 |
+
inputs = image_processor(images=image, return_tensors="pt")
|
72 |
+
|
73 |
+
## pass inputs to model for prediction
|
74 |
+
with torch.no_grad():
|
75 |
+
outputs = model(**inputs)
|
76 |
+
|
77 |
+
# pass outputs to processor for postprocessing
|
78 |
+
if segmentation_task == "semantic":
|
79 |
+
result = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
|
80 |
+
predicted_segmentation_map = result.cpu().numpy()
|
81 |
+
palette = ade_palette()
|
82 |
+
output_result = draw_semantic_segmentation(predicted_segmentation_map, image, palette)
|
83 |
+
|
84 |
+
elif segmentation_task == "instance":
|
85 |
+
pass
|
86 |
+
# result = image_processor.post_process_segmentation(outputs)[0].cpu().detach()
|
87 |
+
# predicted_segmentation_map = result["segmentation"]
|
88 |
+
# # predicted_segmentation_map = torch.argmax(result, dim=0).numpy()
|
89 |
+
# # results = torch.argmax(predicted_segmentation_map, dim=0).numpy()
|
90 |
+
# print("predicted_segmentation_map:",predicted_segmentation_map)
|
91 |
+
# print("type predicted_segmentation_map:", type(predicted_segmentation_map))
|
92 |
+
# output_result = visualize_instance_seg_mask(predicted_segmentation_map)
|
93 |
+
# # mask = plot_semantic_map(predicted_segmentation_map, image)
|
94 |
+
|
95 |
+
else:
|
96 |
+
result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
|
97 |
+
predicted_segmentation_map = result["segmentation"]
|
98 |
+
seg_info = result['segments_info']
|
99 |
+
output_result = draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image)
|
100 |
+
|
101 |
+
return output_result
|