shivi commited on
Commit
68a69f9
1 Parent(s): abe2586

mask2former app setup

Browse files
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from predict import predict_masks
3
+ import glob
4
+
5
+ ##Create list of examples to be loaded
6
+ example_list = glob.glob("examples/*")
7
+ example_list = list(map(lambda el:[el], example_list))
8
+
9
+ demo = gr.Blocks()
10
+
11
+ with demo:
12
+
13
+ gr.Markdown("# **<p align='center'>Mask2Former: Masked Attention Transformer for Universal Segmentation</p>**")
14
+ gr.Markdown("This space demonstrates the use of Mask2Former. It was introduced in the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) and first released in [this repository](https://github.com/facebookresearch/Mask2Former/). \
15
+ Before Mask2Former, you'd have to resort to using a specialized architecture designed for solving a particular kind of image segmentation task (i.e. semantic, instance or panoptic segmentation). On the other hand, in the form of Mask2Former, for the first time, we have a single architecture that is capable of solving any segmentation task and performs on par or better than specialized architectures.")
16
+
17
+ with gr.Box():
18
+
19
+ with gr.Row():
20
+ segmentation_task = gr.Dropdown(["semantic", "panoptic"], value="panoptic", label="Segmentation Task", show_label=True)
21
+ with gr.Box():
22
+ with gr.Row():
23
+ input_image = gr.Image(type='filepath',label="Input Image", show_label=True)
24
+ output_mask = gr.Image(label="Predicted Masks", show_label=True)
25
+
26
+ gr.Markdown("**Predict**")
27
+
28
+ with gr.Box():
29
+ with gr.Row():
30
+ submit_button = gr.Button("Submit")
31
+
32
+ gr.Markdown("**Examples:**")
33
+
34
+ with gr.Column():
35
+ gr.Examples(example_list, [input_image, segmentation_task], output_mask, predict_masks)
36
+
37
+
38
+ submit_button.click(predict_masks, inputs=[input_image, segmentation_task], outputs=output_mask)
39
+
40
+ gr.Markdown('\n Demo created by: <a href=\"https://www.linkedin.com/in/shivalika-singh/\">Shivalika Singh</a>')
41
+
42
+ demo.launch(debug=True)
color_palette.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # color palattes for COCO, cityscapes and ADE datasets
2
+
3
+ def coco_panoptic_palette():
4
+ return [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230),
5
+ (106, 0, 228), (0, 60, 100), (0, 80, 100), (0, 0, 70),
6
+ (0, 0, 192), (250, 170, 30), (100, 170, 30), (220, 220, 0),
7
+ (175, 116, 175), (250, 0, 30), (165, 42, 42), (255, 77, 255),
8
+ (0, 226, 252), (182, 182, 255), (0, 82, 0), (120, 166, 157),
9
+ (110, 76, 0), (174, 57, 255), (199, 100, 0), (72, 0, 118),
10
+ (255, 179, 240), (0, 125, 92), (209, 0, 151), (188, 208, 182),
11
+ (0, 220, 176), (255, 99, 164), (92, 0, 73), (133, 129, 255),
12
+ (78, 180, 255), (0, 228, 0), (174, 255, 243), (45, 89, 255),
13
+ (134, 134, 103), (145, 148, 174), (255, 208, 186),
14
+ (197, 226, 255), (171, 134, 1), (109, 63, 54), (207, 138, 255),
15
+ (151, 0, 95), (9, 80, 61), (84, 105, 51), (74, 65, 105),
16
+ (166, 196, 102), (208, 195, 210), (255, 109, 65), (0, 143, 149),
17
+ (179, 0, 194), (209, 99, 106), (5, 121, 0), (227, 255, 205),
18
+ (147, 186, 208), (153, 69, 1), (3, 95, 161), (163, 255, 0),
19
+ (119, 0, 170), (0, 182, 199), (0, 165, 120), (183, 130, 88),
20
+ (95, 32, 0), (130, 114, 135), (110, 129, 133), (166, 74, 118),
21
+ (219, 142, 185), (79, 210, 114), (178, 90, 62), (65, 70, 15),
22
+ (127, 167, 115), (59, 105, 106), (142, 108, 45), (196, 172, 0),
23
+ (95, 54, 80), (128, 76, 255), (201, 57, 1), (246, 0, 122),
24
+ (191, 162, 208), (255, 255, 128), (147, 211, 203),
25
+ (150, 100, 100), (168, 171, 172), (146, 112, 198),
26
+ (210, 170, 100), (92, 136, 89), (218, 88, 184), (241, 129, 0),
27
+ (217, 17, 255), (124, 74, 181), (70, 70, 70), (255, 228, 255),
28
+ (154, 208, 0), (193, 0, 92), (76, 91, 113), (255, 180, 195),
29
+ (106, 154, 176),
30
+ (230, 150, 140), (60, 143, 255), (128, 64, 128), (92, 82, 55),
31
+ (254, 212, 124), (73, 77, 174), (255, 160, 98), (255, 255, 255),
32
+ (104, 84, 109), (169, 164, 131), (225, 199, 255), (137, 54, 74),
33
+ (135, 158, 223), (7, 246, 231), (107, 255, 200), (58, 41, 149),
34
+ (183, 121, 142), (255, 73, 97), (107, 142, 35), (190, 153, 153),
35
+ (146, 139, 141),
36
+ (70, 130, 180), (134, 199, 156), (209, 226, 140), (96, 36, 108),
37
+ (96, 96, 96), (64, 170, 64), (152, 251, 152), (208, 229, 228),
38
+ (206, 186, 171), (152, 161, 64), (116, 112, 0), (0, 114, 143),
39
+ (102, 102, 156), (250, 141, 255)]
40
+
41
+ def cityscapes_palette():
42
+ return [[128, 64, 128],[244, 35, 232],[70, 70, 70],[102, 102, 156],[190, 153, 153],
43
+ [153, 153, 153],[250, 170, 30],[220, 220, 0],[107, 142, 35],[152, 251, 152],
44
+ [70, 130, 180], [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70],
45
+ [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32]]
46
+
47
+ def ade_palette():
48
+ """Color palette that maps each class to RGB values.
49
+
50
+ This one is actually taken from ADE20k.
51
+ """
52
+ return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
53
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
54
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
55
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
56
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
57
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
58
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
59
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
60
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
61
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
62
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
63
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
64
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
65
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
66
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
67
+ [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
68
+ [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
69
+ [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
70
+ [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
71
+ [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
72
+ [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
73
+ [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
74
+ [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
75
+ [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
76
+ [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
77
+ [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
78
+ [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
79
+ [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
80
+ [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
81
+ [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
82
+ [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
83
+ [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
84
+ [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
85
+ [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
86
+ [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
87
+ [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
88
+ [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
89
+ [102, 255, 0], [92, 0, 255]]
examples/armchair.jpg ADDED
examples/cat-dog.jpg ADDED
examples/person-bike.jpg ADDED
predict.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import random
3
+ import numpy as np
4
+ from PIL import Image
5
+ from collections import defaultdict
6
+ from detectron2.data import MetadataCatalog
7
+ from detectron2.utils.visualizer import ColorMode, Visualizer
8
+ from color_palette import ade_palette
9
+ from transformers import MaskFormerImageProcessor, Mask2FormerForUniversalSegmentation
10
+
11
+ def load_model_and_processor(model_ckpt: str):
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ model = Mask2FormerForUniversalSegmentation.from_pretrained(model_ckpt).to(torch.device(device))
14
+ model.eval()
15
+ image_preprocessor = MaskFormerImageProcessor.from_pretrained(model_ckpt)
16
+ return model, image_preprocessor
17
+
18
+ def load_default_ckpt(segmentation_task: str):
19
+ if segmentation_task == "semantic":
20
+ default_pretrained_ckpt = "facebook/mask2former-swin-tiny-ade-semantic"
21
+ elif segmentation_task == "instance":
22
+ default_pretrained_ckpt = "facebook/mask2former-swin-small-coco-instance"
23
+ else:
24
+ default_pretrained_ckpt = "facebook/mask2former-swin-tiny-coco-panoptic"
25
+ return default_pretrained_ckpt
26
+
27
+ def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
28
+ metadata = MetadataCatalog.get("coco_2017_val_panoptic")
29
+ for res in seg_info:
30
+ res['category_id'] = res.pop('label_id')
31
+ pred_class = res['category_id']
32
+ isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
33
+ res['isthing'] = bool(isthing)
34
+
35
+ visualizer = Visualizer(np.array(image)[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
36
+ out = visualizer.draw_panoptic_seg_predictions(
37
+ predicted_segmentation_map.cpu(), seg_info, alpha=0.5
38
+ )
39
+ output_img = Image.fromarray(out.get_image())
40
+ return output_img
41
+
42
+ def draw_semantic_segmentation(segmentation_map, image, palette):
43
+ color_segmentation_map = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8) # height, width, 3
44
+ for label, color in enumerate(palette):
45
+ color_segmentation_map[segmentation_map - 1 == label, :] = color
46
+ # Convert to BGR
47
+ ground_truth_color_seg = color_segmentation_map[..., ::-1]
48
+
49
+ img = np.array(image) * 0.5 + ground_truth_color_seg * 0.5
50
+ img = img.astype(np.uint8)
51
+ return img
52
+
53
+ def visualize_instance_seg_mask(mask):
54
+ image = np.zeros((mask.shape[0], mask.shape[1], 3))
55
+ labels = np.unique(mask)
56
+ label2color = {label: (random.randint(0, 1), random.randint(0, 255), random.randint(0, 255)) for label in labels}
57
+ for i in range(image.shape[0]):
58
+ for j in range(image.shape[1]):
59
+ image[i, j, :] = label2color[mask[i, j]]
60
+ image = image / 255
61
+ return image
62
+
63
+ def predict_masks(input_img_path: str, segmentation_task: str):
64
+
65
+ #load model and image processor
66
+ default_pretrained_ckpt = load_default_ckpt(segmentation_task)
67
+ model, image_processor = load_model_and_processor(default_pretrained_ckpt)
68
+
69
+ ## pass input image through image processor
70
+ image = Image.open(input_img_path)
71
+ inputs = image_processor(images=image, return_tensors="pt")
72
+
73
+ ## pass inputs to model for prediction
74
+ with torch.no_grad():
75
+ outputs = model(**inputs)
76
+
77
+ # pass outputs to processor for postprocessing
78
+ if segmentation_task == "semantic":
79
+ result = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
80
+ predicted_segmentation_map = result.cpu().numpy()
81
+ palette = ade_palette()
82
+ output_result = draw_semantic_segmentation(predicted_segmentation_map, image, palette)
83
+
84
+ elif segmentation_task == "instance":
85
+ pass
86
+ # result = image_processor.post_process_segmentation(outputs)[0].cpu().detach()
87
+ # predicted_segmentation_map = result["segmentation"]
88
+ # # predicted_segmentation_map = torch.argmax(result, dim=0).numpy()
89
+ # # results = torch.argmax(predicted_segmentation_map, dim=0).numpy()
90
+ # print("predicted_segmentation_map:",predicted_segmentation_map)
91
+ # print("type predicted_segmentation_map:", type(predicted_segmentation_map))
92
+ # output_result = visualize_instance_seg_mask(predicted_segmentation_map)
93
+ # # mask = plot_semantic_map(predicted_segmentation_map, image)
94
+
95
+ else:
96
+ result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
97
+ predicted_segmentation_map = result["segmentation"]
98
+ seg_info = result['segments_info']
99
+ output_result = draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image)
100
+
101
+ return output_result