Kevin Sun committed
Commit 6cd90b7 · 1 Parent(s): 8927fea

init commit
CLIP_as_RNN ADDED
@@ -0,0 +1 @@
+ Subproject commit 2457b49b339498af726408aa6673155de408c0f0
README.md CHANGED
@@ -1,13 +1,14 @@
- ---
- title: CLIP As RNN
- emoji: 🏢
- colorFrom: purple
- colorTo: purple
- sdk: gradio
- sdk_version: 4.29.0
- app_file: app.py
- pinned: false
- license: apache-2.0
- ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # CLIP as RNN: Segment Countless Visual Concepts without Training Endeavor (CaR)
+
+ This repo holds the implementation code of the paper [CLIP as RNN: Segment Countless Visual Concepts without Training Endeavor (CaR)](https://arxiv.org/abs/2312.07661) by Shuyang Sun, Runjia Li, Philip Torr, Xiuye Gu, and Siyang Li:
+
+ ```
+ @article{clip_as_rnn,
+   title={CLIP as RNN: Segment Countless Visual Concepts without Training Endeavor},
+   author={Shuyang Sun and Runjia Li and Philip Torr and Xiuye Gu and Siyang Li},
+   year={2023},
+   eprint={2312.07661},
+   archivePrefix={arXiv},
+   primaryClass={cs.CV}
+ }
+ ```
app.py ADDED
@@ -0,0 +1,285 @@
"""Run a Gradio demo of the CaR model on a single image."""

import argparse
import colorsys
import random

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image
import torch

from modeling.model import CaR
from modeling.post_process.post_process import match_masks, generate_masks_from_sam
from sam.sam import SAMPipeline
from sam.utils import build_sam_config
from utils.utils import Config, load_yaml

# Set random seeds for reproducible colors and model behavior.
random.seed(15)
np.random.seed(0)
torch.manual_seed(0)

CFG_PATH = "configs/demo/pokemon.yaml"


def generate_distinct_colors(n):
    """Generate n visually distinct RGB colors by spacing hues evenly."""
    colors = []
    # Random hue offset so the palette differs between runs.
    random_color_bias = random.random()
    for i in range(n):
        hue = (float(i) / n + random_color_bias) % 1.0
        rgb = colorsys.hsv_to_rgb(hue, 1.0, 1.0)
        # Convert RGB values from the [0, 1] range to [0, 255].
        colors.append(tuple(int(val * 255) for val in rgb))
    return colors


def overlap_masks(masks):
    """Resolve overlaps between per-class masks for visualization.

    Parameters:
    - masks: tensor of shape (N, H, W) with one binary mask per class.

    Returns:
    - clean_masks: tensor of shape (N, H, W) whose masks no longer overlap
      (later classes take precedence over earlier ones).
    """
    overlap_mask = torch.zeros_like(masks[0])
    for mask_idx, mask in enumerate(masks):
        overlap_mask[mask > 0] = mask_idx + 1

    clean_masks = [overlap_mask == mask_idx + 1 for mask_idx in range(len(masks))]
    clean_masks = torch.stack(clean_masks, dim=0)
    return clean_masks


def visualize_segmentation(image, masks, class_names, alpha=0.7,
                           y_list=None, x_list=None):
    """Visualize segmentation masks on an image.

    Parameters:
    - image: np.array of shape (H, W, 3) representing the RGB image.
    - masks: np.array of shape (N, H, W) with one binary mask per class.
    - class_names: list of strings with the name of each class.
    - alpha: float, transparency level of the masks over the image.
    - y_list, x_list: optional label positions; defaults to mask centroids.

    Returns:
    - fig: plt.figure object with the overlaid masks and class labels.
    """
    fig, ax = plt.subplots(1, figsize=(12, 9))
    # Accumulate a colored mask and the binary union of all masks.
    final_mask = np.zeros((masks.shape[1], masks.shape[2], 3), dtype=np.float32)
    binary_final_mask = np.zeros((masks.shape[1], masks.shape[2]), dtype=np.float32)
    colors = generate_distinct_colors(len(class_names))
    idx = 0
    for mask, color, class_name in zip(masks, colors, class_names):
        # Overlay the mask in its class color.
        final_mask += np.dstack([mask * c for c in color])
        binary_final_mask += mask
        # Find a representative point (e.g., the centroid) for placing the label.
        if y_list is None or x_list is None:
            y, x = np.argwhere(mask).mean(axis=0)
        else:
            y, x = y_list[idx], x_list[idx]
        ax.text(x, y, class_name, color='white',
                fontsize=22, va='center', ha='center',
                bbox=dict(facecolor='black', alpha=0.7, edgecolor='none'))
        idx += 1

    # Dim the masked regions of the image and blend in the colored masks.
    image[binary_final_mask > 0] = image[binary_final_mask > 0] * (1 - alpha)
    final_image = (image + final_mask * alpha).astype(np.uint8)
    ax.imshow(final_image)
    # Remove axis ticks and labels.
    ax.axis('off')
    return fig


def get_sam_masks(cfg, masks, image_path=None, img_sam=None, pipeline=None):
    """Refine CaR masks by matching them against SAM mask proposals."""
    print("generating sam masks online")
    if img_sam is None and image_path is None:
        raise ValueError(
            'Please provide either the image path or the image numpy array.')

    mask_tensor, mask_list = generate_masks_from_sam(
        image_path,
        save_path='./',
        pipeline=pipeline,
        img_sam=img_sam,
        visualize=False,
    )
    mask_tensor = mask_tensor.to(masks.device)

    # Only match SAM proposals against masks that are not all zero.
    attn_map, mask_ids = [], []
    for mask_id, mask in enumerate(masks):
        if torch.sum(mask) > 0:
            attn_map.append(mask.unsqueeze(0))
            mask_ids.append(mask_id)
    matched_masks = [
        match_masks(
            mask_tensor,
            attn,
            mask_list,
            iom_thres=cfg.car.iom_thres,
            min_pred_threshold=cfg.sam.min_pred_threshold)
        for attn in attn_map]
    # Replace each non-empty CaR mask with the union of its matched SAM masks.
    for matched_mask, mask_id in zip(matched_masks, mask_ids):
        sam_masks = np.array([item['segmentation'] for item in matched_mask])
        sam_mask = np.any(sam_masks, axis=0)
        masks[mask_id] = torch.from_numpy(sam_mask).to(masks.device)
    return masks


def load_sam(cfg, device):
    """Build the SAM pipeline from the checkpoint and thresholds in cfg."""
    sam_checkpoint, model_type = build_sam_config(cfg)
    pipeline = SAMPipeline(
        sam_checkpoint,
        model_type,
        device=device,
        points_per_side=cfg.sam.points_per_side,
        pred_iou_thresh=cfg.sam.pred_iou_thresh,
        stability_score_thresh=cfg.sam.stability_score_thresh,
        box_nms_thresh=cfg.sam.box_nms_thresh,
    )
    return pipeline


def generate(img, class_names, clip_thresh, mask_thresh, confidence_thresh,
             post_process, stability_score_thresh, box_nms_thresh,
             iom_thres, min_pred_threshold):
    """Gradio callback: segment the uploaded image with the given class names."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    cfg = Config(**load_yaml(CFG_PATH))
    # Override the config with the values chosen in the UI.
    cfg.car.clipes_threshold = clip_thresh
    cfg.car.mask_threshold = mask_thresh
    cfg.car.confidence_threshold = confidence_thresh
    cfg.sam.stability_score_thresh = stability_score_thresh
    cfg.sam.box_nms_thresh = box_nms_thresh
    cfg.car.iom_thres = iom_thres
    cfg.sam.min_pred_threshold = min_pred_threshold
    car_model = CaR(cfg, visualize=True, seg_mode='semantic', device=device)

    # Downsample the image by 2x if its width exceeds 1000 pixels.
    if img.size[0] > 1000:
        img = img.resize((img.size[0] // 2, img.size[1] // 2))

    y_list, x_list = None, None
    class_names = class_names.split(',')
    sentences = class_names

    pseudo_masks, _, _ = car_model(img, sentences, 1)

    if post_process == 'SAM':
        pipeline = load_sam(cfg, device)
        pseudo_masks = get_sam_masks(
            cfg,
            pseudo_masks,
            image_path=None,
            img_sam=np.array(img),
            pipeline=pipeline)
    pseudo_masks = overlap_masks(pseudo_masks)

    # Visualize the segmentation masks.
    demo_fig = visualize_segmentation(np.array(img),
                                      pseudo_masks.detach().cpu().numpy(),
                                      class_names,
                                      y_list=y_list,
                                      x_list=x_list)

    # Convert the demo figure to a PIL image.
    demo_fig.canvas.draw()
    demo_img = np.array(demo_fig.canvas.renderer._renderer)
    demo_img = Image.fromarray(demo_img)
    return demo_img


if __name__ == "__main__":
    parser = argparse.ArgumentParser('car')
    parser.add_argument("--cfg-path",
                        default='configs/local_car.yaml',
                        help="path to configuration file.")
    args = parser.parse_args()

    demo = gr.Interface(
        generate,
        inputs=[
            gr.Image(label="upload an image", type="pil"),
            "text",
            gr.Slider(label="clip thresh",
                      minimum=0, maximum=1, value=0.4, step=0.1,
                      info="the threshold for clip-es adversarial heatmap clipping"),
            gr.Slider(label="mask thresh",
                      minimum=0, maximum=1, value=0.6, step=0.1,
                      info="the binarization threshold for the mask to generate the visual prompt"),
            gr.Slider(label="confidence thresh",
                      minimum=0, maximum=1, value=0, step=0.1,
                      info="the threshold for filtering the proposed classes"),
            gr.Radio(["CRF", "SAM"], label="post process", value="CRF",
                     info="choose the post process method"),
            gr.Slider(label="stability score thresh for SAM mask proposal \n(only when SAM is chosen for post process)",
                      minimum=0, maximum=1, value=0.95, step=0.1),
            gr.Slider(label="box nms thresh for SAM mask proposal \n(only when SAM is chosen for post process)",
                      minimum=0, maximum=1, value=0.7, step=0.1),
            gr.Slider(label="intersection over mask threshold for SAM mask proposal \n(only when SAM is chosen for post process)",
                      minimum=0, maximum=1, value=0.5, step=0.1),
            gr.Slider(label="minimum prediction threshold for SAM mask proposal \n(only when SAM is chosen for post process)",
                      minimum=0, maximum=1, value=0.03, step=0.01)],
        outputs="image",
        title="CLIP as RNN: Segment Countless Visual Concepts without Training Endeavor",
        description=("This is the official demo for CLIP as RNN. Please upload an image and "
                     "type in the class names (separated by ',', e.g. cat,dog,human) you want "
                     "to segment. The model will generate the segmentation masks for the input "
                     "image. You can also adjust the clip thresh, mask thresh and confidence "
                     "thresh to get better results."),
        examples=[
            ["demo/pokemon1.jpg", "Charmander,Bulbasaur,Squirtle", 0.6, 0.6, 0, "SAM", 0.95, 0.7, 0.6, 0.01],
            ["demo/batman.jpg", "Batman,Joker,Cat Woman", 0.6, 0.6, 0, "SAM", 0.95, 0.7, 0.6, 0.01],
            ["demo/avengers1.jpg", "Thor,Captain America,Hulk,Iron Man", 0.6, 0.6, 0, "SAM", 0.89, 0.65, 0.5, 0.03],
        ])
    demo.launch(share=True)
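For reference, here is a minimal sketch of driving the same pipeline without the Gradio UI, assuming the repository's `CaR`, `Config`, and `load_yaml` interfaces behave exactly as used in `app.py` above; the config path and image file are the demo assets referenced there, and the printed shape is only an expectation, not a guaranteed API.

```python
# Minimal sketch: run CaR on one image outside Gradio.
# Assumes the same interfaces app.py uses; paths below are the demo assets.
import torch
import PIL.Image as Image

from modeling.model import CaR
from utils.utils import Config, load_yaml

cfg = Config(**load_yaml("configs/demo/pokemon.yaml"))
device = "cuda" if torch.cuda.is_available() else "cpu"
car_model = CaR(cfg, visualize=True, seg_mode="semantic", device=device)

img = Image.open("demo/pokemon1.jpg").convert("RGB")
class_names = ["Charmander", "Bulbasaur", "Squirtle"]

# One recurrent pass (the trailing `1` mirrors the call in app.py).
pseudo_masks, _, _ = car_model(img, class_names, 1)
print(pseudo_masks.shape)  # expected: one binary mask per class name
```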
configs/ade_150.yaml ADDED
@@ -0,0 +1,55 @@
1
+ clip:
2
+ semantic_clip_model_name: 'ViT-L/14'
3
+ semantic_pretrained_data: 'openai'
4
+ clip_model_name: "ViT-B/16"
5
+ pretrained_data: 'openai'
6
+
7
+ car:
8
+ iom_thres: 0.6
9
+ mask_threshold: 0.6
10
+ min_area_ratio: 0.2
11
+ num_iteration: 1
12
+ confidence_threshold: 0.25
13
+ clipes_threshold: 0.7
14
+ bg_factor: 1
15
+ stuff_bg_factor: 1
16
+ visual_prompt_type: ['gray', 'blur']
17
+ stuff_visual_prompt_type: ['gray', 'blur']
18
+ semantic_templates: ['a clean origami {}.',
19
+ 'a photo of a {}.',
20
+ 'This is a photo of a {}',
21
+ 'There is a {} in the scene',
22
+ 'There is the {} in the scene',
23
+ 'a photo of a {} in the scene',
24
+ 'a photo of a small {}.',
25
+ 'a photo of a medium {}.',
26
+ 'a photo of a large {}.',
27
+ 'This is a photo of a small {}.',
28
+ 'This is a photo of a medium {}.',
29
+ 'This is a photo of a large {}.',
30
+ 'There is a small {} in the scene.',
31
+ 'There is a medium {} in the scene.',
32
+ 'There is a large {} in the scene.']
33
+
34
+ bg_cls: ['ground', 'land', 'grass', 'tree', 'building',
35
+ 'wall', 'sky', 'lake', 'water', 'river', 'sea',
36
+ 'railway', 'railroad', 'helmet', 'cloud', 'house',
37
+ 'mountain', 'ocean', 'road', 'rock', 'street',
38
+ 'valley', 'bridge']
39
+
40
+ test:
41
+ algo: "car"
42
+ ds_name: "ade"
43
+ seg_mode: "semantic"
44
+ split: 'validation'
45
+ data_root: "$YOUR_ADE_DATA_DIR"
46
+ # You need to extract the sam mask for the ADE dataset if use_pseudo=False
47
+ sam_mask_root: "$YOUR_SAM_MASK_DIR"
48
+ output_path: "./outputs/"
49
+ use_pseudo: True
50
+ n_class: 151
51
+ num_chunks: 1
52
+ chunk_index: 0
53
+ ignore_background: True
54
+
55
+ save_path: "./outputs"
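The `semantic_templates` entries above are prompt templates with a `{}` slot for the class name. As a small illustration of how templates like these are typically expanded into per-class text prompts (plain `str.format`; the repo's own prompt-building helper may differ in details):

```python
# Sketch: expand prompt templates like the ones in the config for one class.
templates = [
    "a clean origami {}.",
    "a photo of a {}.",
    "There is a {} in the scene",
]

def expand_prompts(class_name, templates):
    """Return one text prompt per template for a single class name."""
    return [t.format(class_name) for t in templates]

print(expand_prompts("tree", templates))
# ['a clean origami tree.', 'a photo of a tree.', 'There is a tree in the scene']
```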
configs/ade_847.yaml ADDED
@@ -0,0 +1,55 @@
1
+ clip:
2
+ semantic_clip_model_name: 'ViT-L/14'
3
+ semantic_pretrained_data: 'openai'
4
+ clip_model_name: "ViT-B/16"
5
+ pretrained_data: 'openai'
6
+
7
+ car:
8
+ iom_thres: 0.6
9
+ mask_threshold: 0.6
10
+ min_area_ratio: 0.2
11
+ num_iteration: 1
12
+ confidence_threshold: 0.25
13
+ clipes_threshold: 0.7
14
+ bg_factor: 1
15
+ stuff_bg_factor: 1
16
+ visual_prompt_type: ['gray', 'blur']
17
+ stuff_visual_prompt_type: ['gray', 'blur']
18
+ semantic_templates: ['a clean origami {}.',
19
+ 'a photo of a {}.',
20
+ 'This is a photo of a {}',
21
+ 'There is a {} in the scene',
22
+ 'There is the {} in the scene',
23
+ 'a photo of a {} in the scene',
24
+ 'a photo of a small {}.',
25
+ 'a photo of a medium {}.',
26
+ 'a photo of a large {}.',
27
+ 'This is a photo of a small {}.',
28
+ 'This is a photo of a medium {}.',
29
+ 'This is a photo of a large {}.',
30
+ 'There is a small {} in the scene.',
31
+ 'There is a medium {} in the scene.',
32
+ 'There is a large {} in the scene.']
33
+
34
+ bg_cls: ['ground', 'land', 'grass', 'tree', 'building',
35
+ 'wall', 'sky', 'lake', 'water', 'river', 'sea',
36
+ 'railway', 'railroad', 'helmet', 'cloud', 'house',
37
+ 'mountain', 'ocean', 'road', 'rock', 'street',
38
+ 'valley', 'bridge']
39
+
40
+ test:
41
+ algo: "car"
42
+ ds_name: "ade_847"
43
+ seg_mode: "semantic"
44
+ split: 'validation'
45
+ data_root: "$YOUR_ADE_DATA_DIR"
46
+ # You need to extract the sam mask for the ADE dataset if use_pseudo=False
47
+ sam_mask_root: "$YOUR_SAM_MASK_DIR"
48
+ output_path: "./outputs/"
49
+ use_pseudo: True
50
+ n_class: 847
51
+ num_chunks: 1
52
+ chunk_index: 0
53
+ ignore_background: True
54
+
55
+ save_path: "./outputs"
configs/coco.yaml ADDED
@@ -0,0 +1,51 @@
1
+ clip:
2
+ semantic_clip_model_name: 'ViT-L/14'
3
+ semantic_pretrained_data: 'openai'
4
+ clip_model_name: "ViT-B/16"
5
+ pretrained_data: 'openai'
6
+
7
+
8
+ car:
9
+ iom_thres: 0.7
10
+ mask_threshold: 0.5
11
+ min_area_ratio: 0.2
12
+ num_iteration: 1
13
+ confidence_threshold: 0.3
14
+ clipes_threshold: 0.5
15
+ visual_prompt_type: ['blur', 'gray']
16
+ semantic_templates: ['a clean origami {}.',
17
+ 'a photo of a {}.',
18
+ 'This is a photo of a {}',
19
+ 'There is a {} in the scene',
20
+ 'There is the {} in the scene',
21
+ 'a photo of a {} in the scene',
22
+ 'a photo of a small {}.',
23
+ 'a photo of a medium {}.',
24
+ 'a photo of a large {}.',
25
+ 'This is a photo of a small {}.',
26
+ 'This is a photo of a medium {}.',
27
+ 'This is a photo of a large {}.',
28
+ 'There is a small {} in the scene.',
29
+ 'There is a medium {} in the scene.',
30
+ 'There is a large {} in the scene.']
31
+ bg_cls: ['ground', 'land', 'grass', 'tree', 'building',
32
+ 'wall', 'sky', 'lake', 'water', 'river', 'sea',
33
+ 'railway', 'railroad', 'helmet', 'cloud', 'house',
34
+ 'mountain', 'ocean', 'road', 'rock', 'street',
35
+ 'valley', 'bridge']
36
+
37
+ test:
38
+ algo: "car"
39
+ ds_name: "coco"
40
+ seg_mode: "semantic"
41
+ data_root: "$YOUR_DATA_DIR"
42
+ # You need to extract the SAM masks for this dataset if use_pseudo=False
43
+ sam_mask_root: "$YOUR_SAM_MASK_DIR"
44
+ output_path: "./outputs/"
45
+ use_pseudo: True
46
+ split: "val"
47
+ n_class: 81
48
+ num_chunks: 1
49
+ chunk_index: 0
50
+
51
+ save_path: "./outputs"
configs/gres.yaml ADDED
@@ -0,0 +1,38 @@
1
+ clip:
2
+ semantic_clip_model_name: 'ViT-L/14'
3
+ semantic_pretrained_data: 'openai'
4
+ clip_model_name: "ViT-B/16"
5
+ pretrained_data: 'openai'
6
+
7
+ car:
8
+ iom_thres: 0.5
9
+ mask_threshold: 0.5
10
+ confidence_threshold: 0
11
+ clipes_threshold: 0.3
12
+ cam_text_template: 'a clean origami {}.'
13
+ color: [255, 0, 0] # red
14
+ visual_prompt_type: ['circle']
15
+ bg_cls: ['ground', 'land', 'grass', 'tree', 'building',
16
+ 'wall', 'sky', 'lake', 'water', 'river', 'sea',
17
+ 'railway', 'railroad', 'helmet', 'cloud', 'house',
18
+ 'mountain', 'ocean', 'road', 'rock', 'street',
19
+ 'valley', 'bridge']
20
+
21
+
22
+ test:
23
+ algo: "car"
24
+ ds_name: "gres"
25
+ split: 'val'
26
+ seg_mode: "refer"
27
+ data_root: "$YOUR_ADE_DATA_DIR"
28
+ output_path: "./outputs/"
29
+ prompts_augment: False
30
+ use_pseudo: True
31
+ use_background: False
32
+ prompts_prefix: False
33
+ prompts_augment: False
34
+
35
+ sentence_process:
36
+ mixing_alpha: 0.
37
+
38
+ save_path: "./outputs"
configs/pascal_context.yaml ADDED
@@ -0,0 +1,60 @@
1
+ clip:
2
+ semantic_clip_model_name: 'ViT-L/14'
3
+ semantic_pretrained_data: 'openai'
4
+ clip_model_name: "ViT-B/16"
5
+ pretrained_data: 'openai'
6
+
7
+
8
+ car:
9
+ iom_thres: 0.5
10
+ mask_threshold: 0.6
11
+ stuff_mask_threshold: 0.6
12
+ min_area_ratio: 0.2
13
+ num_iteration: 1
14
+ confidence_threshold: 0.25
15
+ clipes_threshold: 0.4
16
+ bg_factor: 1
17
+ stuff_bg_factor: 1
18
+ has_pamr: False
19
+ visual_prompt_type: ['blur', 'circle']
20
+ stuff_visual_prompt_type: ['blur', 'gray']
21
+ semantic_templates: ['a clean origami {}.',
22
+ 'a photo of a {}.',
23
+ 'This is a photo of a {}',
24
+ 'There is a {} in the scene',
25
+ 'There is the {} in the scene',
26
+ 'a photo of a {} in the scene',
27
+ 'a photo of a small {}.',
28
+ 'a photo of a medium {}.',
29
+ 'a photo of a large {}.',
30
+ 'This is a photo of a small {}.',
31
+ 'This is a photo of a medium {}.',
32
+ 'This is a photo of a large {}.',
33
+ 'There is a small {} in the scene.',
34
+ 'There is a medium {} in the scene.',
35
+ 'There is a large {} in the scene.']
36
+
37
+
38
+
39
+ bg_cls: ['ground', 'land', 'grass', 'tree', 'building',
40
+ 'wall', 'sky', 'lake', 'water', 'river', 'sea',
41
+ 'railway', 'railroad', 'helmet', 'cloud', 'house',
42
+ 'mountain', 'ocean', 'road', 'rock', 'street',
43
+ 'valley', 'bridge']
44
+
45
+
46
+ test:
47
+ algo: "car"
48
+ ds_name: "context"
49
+ seg_mode: "semantic"
50
+ n_class: 60
51
+ data_root: "$YOUR_DATA_DIR"
52
+ output_path: "./outputs/"
53
+ use_pseudo: True
54
+ split: "val"
55
+ num_chunks: 1
56
+ chunk_index: 0
57
+ ignore_background: False
58
+
59
+
60
+ save_path: "./outputs"
configs/pascal_context_459.yaml ADDED
@@ -0,0 +1,55 @@
1
+ clip:
2
+ semantic_clip_model_name: 'ViT-L/14'
3
+ semantic_pretrained_data: 'openai'
4
+ clip_model_name: "ViT-B/16"
5
+ pretrained_data: 'openai'
6
+
7
+ car:
8
+ iom_thres: 0.6
9
+ mask_threshold: 0.4
10
+ min_area_ratio: 0.2
11
+ num_iteration: 1
12
+ confidence_threshold: 0.25 # 0.2
13
+ clipes_threshold: 0.7
14
+ bg_factor: 1
15
+ stuff_bg_factor: 1
16
+ visual_prompt_type: ['gray', 'blur']
17
+ stuff_visual_prompt_type: ['gray', 'blur']
18
+ semantic_templates: ['a clean origami {}.',
19
+ 'a photo of a {}.',
20
+ 'This is a photo of a {}',
21
+ 'There is a {} in the scene',
22
+ 'There is the {} in the scene',
23
+ 'a photo of a {} in the scene',
24
+ 'a photo of a small {}.',
25
+ 'a photo of a medium {}.',
26
+ 'a photo of a large {}.',
27
+ 'This is a photo of a small {}.',
28
+ 'This is a photo of a medium {}.',
29
+ 'This is a photo of a large {}.',
30
+ 'There is a small {} in the scene.',
31
+ 'There is a medium {} in the scene.',
32
+ 'There is a large {} in the scene.']
33
+
34
+ bg_cls: ['ground', 'land', 'grass', 'tree', 'building',
35
+ 'wall', 'sky', 'lake', 'water', 'river', 'sea',
36
+ 'railway', 'railroad', 'helmet', 'cloud', 'house',
37
+ 'mountain', 'ocean', 'road', 'rock', 'street',
38
+ 'valley', 'bridge']
39
+
40
+ test:
41
+ algo: "car"
42
+ ds_name: "pascal_459"
43
+ seg_mode: "semantic"
44
+ split: 'validation'
45
+ data_root: "$YOUR_DATA_DIR"
46
+ # You need to extract the sam mask for the ADE dataset if use_pseudo=False
47
+ sam_mask_root: "$YOUR_SAM_MASK_DIR"
48
+ output_path: "./outputs/"
49
+ use_pseudo: True
50
+ n_class: 460
51
+ num_chunks: 1
52
+ chunk_index: 0
53
+ ignore_background: True
54
+
55
+ save_path: "./outputs"
configs/refcoco+.yaml ADDED
@@ -0,0 +1,34 @@
1
+ clip:
2
+ semantic_clip_model_name: 'ViT-B/16'
3
+ semantic_pretrained_data: 'openai'
4
+ clip_model_name: "ViT-B/16"
5
+ pretrained_data: 'openai'
6
+
7
+ car:
8
+ iom_thres: 0.5
9
+ mask_threshold: 0.2
10
+ confidence_threshold: 0.1
11
+ clipes_threshold: 0.5 # refcocog: 0.6
12
+ color: [255, 0, 0] # red
13
+ visual_prompt_type: ['circle', 'blur']
14
+ min_area_ratio: 0.2
15
+ bg_cls: ['ground', 'land', 'grass', 'tree', 'building',
16
+ 'wall', 'sky', 'lake', 'water', 'river', 'sea',
17
+ 'railway', 'railroad', 'helmet', 'cloud', 'house',
18
+ 'mountain', 'ocean', 'road', 'rock', 'street',
19
+ 'valley', 'bridge']
20
+
21
+ test:
22
+ algo: "car"
23
+ ds_name: "refcoco+"
24
+ seg_mode: "refer"
25
+ split: 'val'
26
+ data_root: "$YOUR_DATA_DIR"
27
+ output_path: "./outputs/"
28
+ prompts_augment: False
29
+ use_pseudo: True
30
+
31
+ sentence_process:
32
+ mixing_alpha: 0.
33
+
34
+ save_path: "./outputs"
configs/refcoco.yaml ADDED
@@ -0,0 +1,34 @@
1
+ clip:
2
+ semantic_clip_model_name: 'ViT-B/16'
3
+ semantic_pretrained_data: 'openai'
4
+ clip_model_name: "ViT-B/16"
5
+ pretrained_data: 'openai'
6
+
7
+ car:
8
+ iom_thres: 0.5
9
+ mask_threshold: 0.5
10
+ confidence_threshold: 0.3
11
+ clipes_threshold: 0.5
12
+ color: [255, 0, 0] # red
13
+ visual_prompt_type: ['circle']
14
+ min_area_ratio: 0.2
15
+ bg_cls: ['ground', 'land', 'grass', 'tree', 'building',
16
+ 'wall', 'sky', 'lake', 'water', 'river', 'sea',
17
+ 'railway', 'railroad', 'helmet', 'cloud', 'house',
18
+ 'mountain', 'ocean', 'road', 'rock', 'street',
19
+ 'valley', 'bridge']
20
+
21
+ test:
22
+ algo: "car"
23
+ ds_name: "refcoco"
24
+ seg_mode: "refer"
25
+ split: 'val'
26
+ data_root: "$YOUR_DATA_DIR"
27
+ output_path: "./outputs/"
28
+ prompts_augment: False
29
+ use_pseudo: True
30
+
31
+ sentence_process:
32
+ mixing_alpha: 0.
33
+
34
+ save_path: "./outputs"
configs/refcocog.yaml ADDED
@@ -0,0 +1,37 @@
1
+ clip:
2
+ semantic_clip_model_name: 'ViT-B/16'
3
+ semantic_pretrained_data: 'openai'
4
+ clip_model_name: "ViT-B/16"
5
+ pretrained_data: 'openai'
6
+
7
+ car:
8
+ iom_thres: 0.5
9
+ mask_threshold: 0.5
10
+ confidence_threshold: 0.1
11
+ clipes_threshold: 0.6
12
+ color: [255, 0, 0] # red
13
+ visual_prompt_type: ['circle', 'blur']
14
+ min_area_ratio: 0.2
15
+ bg_cls: ['ground', 'land', 'grass', 'tree', 'building',
16
+ 'wall', 'sky', 'lake', 'water', 'river', 'sea',
17
+ 'railway', 'railroad', 'helmet', 'cloud', 'house',
18
+ 'mountain', 'ocean', 'road', 'rock', 'street',
19
+ 'valley', 'bridge']
20
+
21
+ test:
22
+ algo: "car"
23
+ ds_name: "refcocog"
24
+ seg_mode: "refer"
25
+ splitby: 'umd'
26
+ split: 'val'
27
+ data_root: "$YOUR_DATA_DIR"
28
+ output_path: "./outputs/"
29
+ prompts_augment: False
30
+ use_pseudo: True
31
+
32
+ sentence_process:
33
+ mixing_alpha: 0.
34
+
35
+ save_path: "./outputs"
36
+
37
+
configs/voc.yaml ADDED
@@ -0,0 +1,63 @@
1
+ clip:
2
+ semantic_clip_model_name: 'ViT-L/14'
3
+ semantic_pretrained_data: 'openai'
4
+ clip_model_name: "ViT-B/16"
5
+ pretrained_data: 'openai'
6
+
7
+ car:
8
+ iom_thres: 0.6
9
+ mask_threshold: 0.4
10
+ min_area_ratio: 0.2
11
+ confidence_threshold: 0.6 # 0.2
12
+ clipes_threshold: 0.4
13
+ visualize: False
14
+ visual_prompt_type: ['circle', 'blur']
15
+ semantic_templates: ['a clean origami {}.',
16
+ 'a photo of a {}.',
17
+ 'This is a photo of a {}',
18
+ 'There is a {} in the scene',
19
+ 'There is the {} in the scene',
20
+ 'a photo of a {} in the scene',
21
+ 'a photo of a small {}.',
22
+ 'a photo of a medium {}.',
23
+ 'a photo of a large {}.',
24
+ 'This is a photo of a small {}.',
25
+ 'This is a photo of a medium {}.',
26
+ 'This is a photo of a large {}.',
27
+ 'There is a small {} in the scene.',
28
+ 'There is a medium {} in the scene.',
29
+ 'There is a large {} in the scene.']
30
+
31
+ bg_cls: ['ground', 'land', 'grass', 'tree', 'building',
32
+ 'wall', 'sky', 'lake', 'water', 'river', 'sea',
33
+ 'railway', 'railroad', 'helmet', 'cloud', 'house',
34
+ 'mountain', 'ocean', 'road', 'rock', 'street',
35
+ 'valley', 'bridge']
36
+
37
+ # SAM is activated only if test.use_pseudo is False
38
+ sam:
39
+ model_dir: "$YOUR_SAM_MODEL_DIR"
40
+ sam_checkpoint: "$YOUR_SAM_MODEL_DIR/sam_hq_vit_h.pth"
41
+ model_type: "vit_h"
42
+ min_pred_threshold: 0.05
43
+ points_per_side:
44
+ pred_iou_thresh: 0.88
45
+ stability_score_thresh: 0.95
46
+ box_nms_thresh: 0.7
47
+
48
+ test:
49
+ algo: "car"
50
+ ds_name: "voc"
51
+ seg_mode: "semantic"
52
+ split: 'val'
53
+ data_root: "$YOUR_DATA_DIR"
54
+ # You need to extract the sam mask for the ADE dataset if use_pseudo=False
55
+ sam_mask_root: "$YOUR_SAM_MASK_DIR"
56
+ output_path: "./outputs/"
57
+ use_pseudo: True
58
+ n_class: 21
59
+ num_chunks: 1
60
+ chunk_index: 0
61
+ ignore_background: False
62
+
63
+ save_path: "./outputs"
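`app.py` loads YAML files like the ones above with `Config(**load_yaml(path))` and then overrides individual fields from the UI sliders. A short sketch of that pattern, assuming the same `Config`/`load_yaml` utilities (field names taken from the configs above):

```python
# Sketch: load a config such as configs/voc.yaml and override a few fields,
# following the pattern used in app.py.
from utils.utils import Config, load_yaml

cfg = Config(**load_yaml("configs/voc.yaml"))

# Values set in the YAML are available as nested attributes...
print(cfg.car.mask_threshold, cfg.sam.pred_iou_thresh)

# ...and can be overridden at runtime, e.g. from UI controls or CLI flags.
cfg.car.mask_threshold = 0.5
cfg.sam.stability_score_thresh = 0.9
```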
data/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
data/ade.py ADDED
@@ -0,0 +1,544 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ADE20K dataset."""
17
+
18
+ import os
19
+ import numpy as np
20
+ from PIL import Image
21
+ import torch
22
+
23
+
24
+ ADE_CLASSES = [
25
+ 'wall',
26
+ 'building, edifice',
27
+ 'sky',
28
+ 'floor, flooring',
29
+ 'tree',
30
+ 'ceiling',
31
+ 'road, route',
32
+ 'bed',
33
+ 'windowpane, window',
34
+ 'grass',
35
+ 'cabinet',
36
+ 'sidewalk, pavement',
37
+ 'person, individual, someone, somebody, mortal, soul',
38
+ 'earth, ground',
39
+ 'door, double, door',
40
+ 'table',
41
+ 'mountain, mount',
42
+ 'plant, flora, plant, life',
43
+ 'curtain, drape, drapery, mantle, pall',
44
+ 'chair',
45
+ 'car, auto, automobile, machine, motorcar',
46
+ 'water',
47
+ 'painting, picture',
48
+ 'sofa, couch, lounge',
49
+ 'shelf',
50
+ 'house',
51
+ 'sea',
52
+ 'mirror',
53
+ 'rug, carpet, carpeting',
54
+ 'field',
55
+ 'armchair',
56
+ 'seat',
57
+ 'fence, fencing',
58
+ 'desk',
59
+ 'rock, stone',
60
+ 'wardrobe, closet, press',
61
+ 'lamp',
62
+ 'bathtub, bathing, tub, bath, tub',
63
+ 'railing, rail',
64
+ 'cushion',
65
+ 'base, pedestal, stand',
66
+ 'box',
67
+ 'column, pillar',
68
+ 'signboard, sign',
69
+ 'chest, of, drawers, chest, bureau, dresser',
70
+ 'counter',
71
+ 'sand',
72
+ 'sink',
73
+ 'skyscraper',
74
+ 'fireplace, hearth, open, fireplace',
75
+ 'refrigerator, icebox',
76
+ 'grandstand, covered, stand',
77
+ 'path',
78
+ 'stairs, steps',
79
+ 'runway',
80
+ 'case, display, case, showcase, vitrine',
81
+ 'pool, table, billiard, table, snooker, table',
82
+ 'pillow',
83
+ 'screen, door, screen',
84
+ 'stairway, staircase',
85
+ 'river',
86
+ 'bridge, span',
87
+ 'bookcase',
88
+ 'blind, screen',
89
+ 'coffee, table, cocktail, table',
90
+ 'toilet, can, commode, crapper, pot, potty, stool, throne',
91
+ 'flower',
92
+ 'book',
93
+ 'hill',
94
+ 'bench',
95
+ 'countertop',
96
+ 'stove, kitchen, stove, range, kitchen, range, cooking, stove',
97
+ 'palm, palm, tree',
98
+ 'kitchen, island',
99
+ (
100
+ 'computer, computing, machine, computing, device, data, processor,'
101
+ ' electronic, computer, information, processing, system'
102
+ ),
103
+ 'swivel, chair',
104
+ 'boat',
105
+ 'bar',
106
+ 'arcade, machine',
107
+ 'hovel, hut, hutch, shack, shanty',
108
+ (
109
+ 'bus, autobus, coach, charabanc, double-decker, jitney, motorbus,'
110
+ ' motorcoach, omnibus, passenger, vehicle'
111
+ ),
112
+ 'towel',
113
+ 'light, light, source',
114
+ 'truck, motortruck',
115
+ 'tower',
116
+ 'chandelier, pendant, pendent',
117
+ 'awning, sunshade, sunblind',
118
+ 'streetlight, street, lamp',
119
+ 'booth, cubicle, stall, kiosk',
120
+ (
121
+ 'television, television, receiver, television, set, tv, tv, set, idiot,'
122
+ ' box, boob, tube, telly, goggle, box'
123
+ ),
124
+ 'airplane, aeroplane, plane',
125
+ 'dirt, track',
126
+ 'apparel, wearing, apparel, dress, clothes',
127
+ 'pole',
128
+ 'land, ground, soil',
129
+ 'bannister, banister, balustrade, balusters, handrail',
130
+ 'escalator, moving, staircase, moving, stairway',
131
+ 'ottoman, pouf, pouffe, puff, hassock',
132
+ 'bottle',
133
+ 'buffet, counter, sideboard',
134
+ 'poster, posting, placard, notice, bill, card',
135
+ 'stage',
136
+ 'van',
137
+ 'ship',
138
+ 'fountain',
139
+ 'conveyer, belt, conveyor, belt, conveyer, conveyor, transporter',
140
+ 'canopy',
141
+ 'washer, automatic, washer, washing, machine',
142
+ 'plaything, toy',
143
+ 'swimming, pool, swimming, bath, natatorium',
144
+ 'stool',
145
+ 'barrel, cask',
146
+ 'basket, handbasket',
147
+ 'waterfall, falls',
148
+ 'tent, collapsible, shelter',
149
+ 'bag',
150
+ 'minibike, motorbike',
151
+ 'cradle',
152
+ 'oven',
153
+ 'ball',
154
+ 'food, solid, food',
155
+ 'step, stair',
156
+ 'tank, storage, tank',
157
+ 'trade, name, brand, name, brand, marque',
158
+ 'microwave, microwave, oven',
159
+ 'pot, flowerpot',
160
+ 'animal, animate, being, beast, brute, creature, fauna',
161
+ 'bicycle, bike, wheel, cycle',
162
+ 'lake',
163
+ 'dishwasher, dish, washer, dishwashing, machine',
164
+ 'screen, silver, screen, projection, screen',
165
+ 'blanket, cover',
166
+ 'sculpture',
167
+ 'hood, exhaust, hood',
168
+ 'sconce',
169
+ 'vase',
170
+ 'traffic, light, traffic, signal, stoplight',
171
+ 'tray',
172
+ (
173
+ 'ashcan, trash, can, garbage, can, wastebin, ash, bin, ash-bin, ashbin,'
174
+ ' dustbin, trash, barrel, trash, bin'
175
+ ),
176
+ 'fan',
177
+ 'pier, wharf, wharfage, dock',
178
+ 'crt, screen',
179
+ 'plate',
180
+ 'monitor, monitoring, device',
181
+ 'bulletin, board, notice, board',
182
+ 'shower',
183
+ 'radiator',
184
+ 'glass, drinking, glass',
185
+ 'clock',
186
+ 'flag',
187
+ ]
188
+
189
+
190
+ ADE_STUFF_CLASS = [
191
+ 'wall',
192
+ 'sky',
193
+ 'floor, flooring',
194
+ 'tree',
195
+ 'ceiling',
196
+ 'road, route',
197
+ 'grass',
198
+ 'earth, ground',
199
+ 'mountain, mount',
200
+ 'plant, flora, plant, life',
201
+ 'water',
202
+ 'sea',
203
+ 'field',
204
+ 'sand',
205
+ 'skyscraper',
206
+ 'path',
207
+ 'river',
208
+ 'bridge, span',
209
+ 'flower',
210
+ 'hill',
211
+ 'land, ground, soil',
212
+ 'dirt, track',
213
+ 'apparel, wearing, apparel, dress, clothes',
214
+ 'lake',
215
+ 'waterfall, falls',
216
+ ]
217
+
218
+ ADE_THING_CLASS = [
219
+ 'building, edifice',
220
+ 'bed',
221
+ 'windowpane, window',
222
+ 'cabinet',
223
+ 'sidewalk, pavement',
224
+ 'person, individual, someone, somebody, mortal, soul',
225
+ 'door, double, door',
226
+ 'table',
227
+ 'curtain, drape, drapery, mantle, pall',
228
+ 'chair',
229
+ 'car, auto, automobile, machine, motorcar',
230
+ 'painting, picture',
231
+ 'sofa, couch, lounge',
232
+ 'shelf',
233
+ 'house',
234
+ 'mirror',
235
+ 'rug, carpet, carpeting',
236
+ 'armchair',
237
+ 'seat',
238
+ 'fence, fencing',
239
+ 'desk',
240
+ 'rock, stone',
241
+ 'wardrobe, closet, press',
242
+ 'lamp',
243
+ 'bathtub, bathing, tub, bath, tub',
244
+ 'railing, rail',
245
+ 'cushion',
246
+ 'base, pedestal, stand',
247
+ 'box',
248
+ 'column, pillar',
249
+ 'signboard, sign',
250
+ 'chest, of, drawers, chest, bureau, dresser',
251
+ 'counter',
252
+ 'sink',
253
+ 'fireplace, hearth, open, fireplace',
254
+ 'refrigerator, icebox',
255
+ 'grandstand, covered, stand',
256
+ 'stairs, steps',
257
+ 'runway',
258
+ 'case, display, case, showcase, vitrine',
259
+ 'pool, table, billiard, table, snooker, table',
260
+ 'pillow',
261
+ 'screen, door, screen',
262
+ 'stairway, staircase',
263
+ 'bookcase',
264
+ 'blind, screen',
265
+ 'coffee, table, cocktail, table',
266
+ 'toilet, can, commode, crapper, pot, potty, stool, throne',
267
+ 'book',
268
+ 'bench',
269
+ 'countertop',
270
+ 'stove, kitchen, stove, range, kitchen, range, cooking, stove',
271
+ 'palm, palm, tree',
272
+ 'kitchen, island',
273
+ (
274
+ 'computer, computing, machine, computing, device, data, processor,'
275
+ ' electronic, computer, information, processing, system'
276
+ ),
277
+ 'swivel, chair',
278
+ 'boat',
279
+ 'bar',
280
+ 'arcade, machine',
281
+ 'hovel, hut, hutch, shack, shanty',
282
+ (
283
+ 'bus, autobus, coach, charabanc, double-decker, jitney, motorbus,'
284
+ ' motorcoach, omnibus, passenger, vehicle'
285
+ ),
286
+ 'towel',
287
+ 'light, light, source',
288
+ 'truck, motortruck',
289
+ 'tower',
290
+ 'chandelier, pendant, pendent',
291
+ 'awning, sunshade, sunblind',
292
+ 'streetlight, street, lamp',
293
+ 'booth, cubicle, stall, kiosk',
294
+ (
295
+ 'television, television, receiver, television, set, tv, tv, set, idiot,'
296
+ ' box, boob, tube, telly, goggle, box'
297
+ ),
298
+ 'airplane, aeroplane, plane',
299
+ 'pole',
300
+ 'bannister, banister, balustrade, balusters, handrail',
301
+ 'escalator, moving, staircase, moving, stairway',
302
+ 'ottoman, pouf, pouffe, puff, hassock',
303
+ 'bottle',
304
+ 'buffet, counter, sideboard',
305
+ 'poster, posting, placard, notice, bill, card',
306
+ 'stage',
307
+ 'van',
308
+ 'ship',
309
+ 'fountain',
310
+ 'conveyer, belt, conveyor, belt, conveyer, conveyor, transporter',
311
+ 'canopy',
312
+ 'washer, automatic, washer, washing, machine',
313
+ 'plaything, toy',
314
+ 'swimming, pool, swimming, bath, natatorium',
315
+ 'stool',
316
+ 'barrel, cask',
317
+ 'basket, handbasket',
318
+ 'tent, collapsible, shelter',
319
+ 'bag',
320
+ 'minibike, motorbike',
321
+ 'cradle',
322
+ 'oven',
323
+ 'ball',
324
+ 'food, solid, food',
325
+ 'step, stair',
326
+ 'tank, storage, tank',
327
+ 'trade, name, brand, name, brand, marque',
328
+ 'microwave, microwave, oven',
329
+ 'pot, flowerpot',
330
+ 'animal, animate, being, beast, brute, creature, fauna',
331
+ 'bicycle, bike, wheel, cycle',
332
+ 'dishwasher, dish, washer, dishwashing, machine',
333
+ 'screen, silver, screen, projection, screen',
334
+ 'blanket, cover',
335
+ 'sculpture',
336
+ 'hood, exhaust, hood',
337
+ 'sconce',
338
+ 'vase',
339
+ 'traffic, light, traffic, signal, stoplight',
340
+ 'tray',
341
+ (
342
+ 'ashcan, trash, can, garbage, can, wastebin, ash, bin, ash-bin, ashbin,'
343
+ ' dustbin, trash, barrel, trash, bin'
344
+ ),
345
+ 'fan',
346
+ 'pier, wharf, wharfage, dock',
347
+ 'crt, screen',
348
+ 'plate',
349
+ 'monitor, monitoring, device',
350
+ 'bulletin, board, notice, board',
351
+ 'shower',
352
+ 'radiator',
353
+ 'glass, drinking, glass',
354
+ 'clock',
355
+ 'flag',
356
+ ]
357
+
358
+
359
+ ADE_STUFF_CLASS_ID = [
360
+ 0,
361
+ 2,
362
+ 3,
363
+ 4,
364
+ 5,
365
+ 6,
366
+ 9,
367
+ 13,
368
+ 16,
369
+ 17,
370
+ 21,
371
+ 26,
372
+ 29,
373
+ 46,
374
+ 48,
375
+ 52,
376
+ 60,
377
+ 61,
378
+ 66,
379
+ 68,
380
+ 94,
381
+ 91,
382
+ 92,
383
+ 128,
384
+ 113,
385
+ ]
386
+
387
+ ADE_THING_CLASS_ID = [
388
+ 1,
389
+ 7,
390
+ 8,
391
+ 10,
392
+ 11,
393
+ 12,
394
+ 14,
395
+ 15,
396
+ 18,
397
+ 19,
398
+ 20,
399
+ 22,
400
+ 23,
401
+ 24,
402
+ 25,
403
+ 27,
404
+ 28,
405
+ 30,
406
+ 31,
407
+ 32,
408
+ 33,
409
+ 34,
410
+ 35,
411
+ 36,
412
+ 37,
413
+ 38,
414
+ 39,
415
+ 40,
416
+ 41,
417
+ 42,
418
+ 43,
419
+ 44,
420
+ 45,
421
+ 47,
422
+ 49,
423
+ 50,
424
+ 51,
425
+ 53,
426
+ 54,
427
+ 55,
428
+ 56,
429
+ 57,
430
+ 58,
431
+ 59,
432
+ 62,
433
+ 63,
434
+ 64,
435
+ 65,
436
+ 67,
437
+ 69,
438
+ 70,
439
+ 71,
440
+ 72,
441
+ 73,
442
+ 74,
443
+ 75,
444
+ 76,
445
+ 77,
446
+ 78,
447
+ 79,
448
+ 80,
449
+ 81,
450
+ 82,
451
+ 83,
452
+ 84,
453
+ 85,
454
+ 86,
455
+ 87,
456
+ 88,
457
+ 89,
458
+ 90,
459
+ 93,
460
+ 95,
461
+ 96,
462
+ 97,
463
+ 98,
464
+ 99,
465
+ 100,
466
+ 101,
467
+ 102,
468
+ 103,
469
+ 104,
470
+ 105,
471
+ 106,
472
+ 107,
473
+ 108,
474
+ 109,
475
+ 110,
476
+ 111,
477
+ 112,
478
+ 114,
479
+ 115,
480
+ 116,
481
+ 117,
482
+ 118,
483
+ 119,
484
+ 120,
485
+ 121,
486
+ 122,
487
+ 123,
488
+ 124,
489
+ 125,
490
+ 126,
491
+ 127,
492
+ 129,
493
+ 130,
494
+ 131,
495
+ 132,
496
+ 133,
497
+ 134,
498
+ 135,
499
+ 136,
500
+ 137,
501
+ 138,
502
+ 139,
503
+ 140,
504
+ 141,
505
+ 142,
506
+ 143,
507
+ 144,
508
+ 145,
509
+ 146,
510
+ 147,
511
+ 148,
512
+ 149,
513
+ ]
514
+
515
+
516
+ class ADEDataset(torch.utils.data.Dataset):
517
+ """ADE dataset."""
518
+
519
+ def __init__(self, root, split='validation', transform=None):
520
+ """Construct ADE dataset.
521
+
522
+ Args:
523
+ root (string): Root directory where images are downloaded.
524
+ split (string): The split of the dataset.
525
+ transform (callable, optional): Optional transform to be applied on a
526
+ sample.
527
+ """
528
+ self.root = root
529
+ self.image_dir = os.path.join(root, 'images', split)
530
+ self.ann_dir = os.path.join(root, 'annotations', split)
531
+ self.images = os.listdir(self.image_dir)
532
+ self.transform = transform
533
+
534
+ def __getitem__(self, index):
535
+ img_path = os.path.join(self.image_dir, self.images[index])
536
+ img = Image.open(img_path).convert('RGB')
537
+ img = np.asarray(img)
538
+ idx = self.images[index].split('.')[0]
539
+ ann_path = os.path.join(self.ann_dir, f'{idx}.png')
540
+ ann = np.asarray(Image.open(ann_path), dtype=np.int32)
541
+ return img, img_path, ann, idx
542
+
543
+ def __len__(self):
544
+ return len(self.images)
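A small usage sketch for the `ADEDataset` defined above: each item is the raw RGB array, its path, the integer annotation map, and the image id. The data root below is the same placeholder used in the configs.

```python
# Sketch: iterate over the ADE20K validation split with ADEDataset.
from data.ade import ADEDataset, ADE_CLASSES

dataset = ADEDataset(root="$YOUR_ADE_DATA_DIR", split="validation")
print(len(dataset), "images,", len(ADE_CLASSES), "classes")

img, img_path, ann, idx = dataset[0]
print(img.shape, ann.shape, img_path, idx)  # (H, W, 3), (H, W), path, image id
```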
data/ade847.py ADDED
@@ -0,0 +1,1827 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ADE-847 dataset."""
17
+
18
+ import os
19
+ import numpy as np
20
+ from PIL import Image
21
+ # pylint: disable=g-importing-member
22
+ from torch.utils.data import Dataset
23
+
24
+
25
+ ADE_847_CLASSES = [
26
+ 'wall',
27
+ 'building, edifice',
28
+ 'sky',
29
+ 'tree',
30
+ 'road, route',
31
+ 'floor, flooring',
32
+ 'ceiling',
33
+ 'bed',
34
+ 'sidewalk, pavement',
35
+ 'earth, ground',
36
+ 'cabinet',
37
+ 'person, individual, someone, somebody, mortal, soul',
38
+ 'grass',
39
+ 'windowpane, window',
40
+ 'car, auto, automobile, machine, motorcar',
41
+ 'mountain, mount',
42
+ 'plant, flora, plant life',
43
+ 'table',
44
+ 'chair',
45
+ 'curtain, drape, drapery, mantle, pall',
46
+ 'door',
47
+ 'sofa, couch, lounge',
48
+ 'sea',
49
+ 'painting, picture',
50
+ 'water',
51
+ 'mirror',
52
+ 'house',
53
+ 'rug, carpet, carpeting',
54
+ 'shelf',
55
+ 'armchair',
56
+ 'fence, fencing',
57
+ 'field',
58
+ 'lamp',
59
+ 'rock, stone',
60
+ 'seat',
61
+ 'river',
62
+ 'desk',
63
+ 'bathtub, bathing tub, bath, tub',
64
+ 'railing, rail',
65
+ 'signboard, sign',
66
+ 'cushion',
67
+ 'path',
68
+ 'work surface',
69
+ 'stairs, steps',
70
+ 'column, pillar',
71
+ 'sink',
72
+ 'wardrobe, closet, press',
73
+ 'snow',
74
+ 'refrigerator, icebox',
75
+ 'base, pedestal, stand',
76
+ 'bridge, span',
77
+ 'blind, screen',
78
+ 'runway',
79
+ 'cliff, drop, drop-off',
80
+ 'sand',
81
+ 'fireplace, hearth, open fireplace',
82
+ 'pillow',
83
+ 'screen door, screen',
84
+ 'toilet, can, commode, crapper, pot, potty, stool, throne',
85
+ 'skyscraper',
86
+ 'grandstand, covered stand',
87
+ 'box',
88
+ 'pool table, billiard table, snooker table',
89
+ 'palm, palm tree',
90
+ 'double door',
91
+ 'coffee table, cocktail table',
92
+ 'counter',
93
+ 'countertop',
94
+ 'chest of drawers, chest, bureau, dresser',
95
+ 'kitchen island',
96
+ 'boat',
97
+ 'waterfall, falls',
98
+ 'stove, kitchen stove, range, kitchen range, cooking stove',
99
+ 'flower',
100
+ 'bookcase',
101
+ 'controls',
102
+ 'book',
103
+ 'stairway, staircase',
104
+ 'streetlight, street lamp',
105
+ (
106
+ 'computer, computing machine, computing device, data processor,'
107
+ ' electronic computer, information processing system'
108
+ ),
109
+ (
110
+ 'bus, autobus, coach, charabanc, double-decker, jitney, motorbus,'
111
+ ' motorcoach, omnibus, passenger vehicle'
112
+ ),
113
+ 'swivel chair',
114
+ 'light, light source',
115
+ 'bench',
116
+ 'case, display case, showcase, vitrine',
117
+ 'towel',
118
+ 'fountain',
119
+ 'embankment',
120
+ (
121
+ 'television receiver, television, television set, tv, tv set, idiot'
122
+ ' box, boob tube, telly, goggle box'
123
+ ),
124
+ 'van',
125
+ 'hill',
126
+ 'awning, sunshade, sunblind',
127
+ 'poster, posting, placard, notice, bill, card',
128
+ 'truck, motortruck',
129
+ 'airplane, aeroplane, plane',
130
+ 'pole',
131
+ 'tower',
132
+ 'court',
133
+ 'ball',
134
+ 'aircraft carrier, carrier, flattop, attack aircraft carrier',
135
+ 'buffet, counter, sideboard',
136
+ 'hovel, hut, hutch, shack, shanty',
137
+ 'apparel, wearing apparel, dress, clothes',
138
+ 'minibike, motorbike',
139
+ 'animal, animate being, beast, brute, creature, fauna',
140
+ 'chandelier, pendant, pendent',
141
+ 'step, stair',
142
+ 'booth, cubicle, stall, kiosk',
143
+ 'bicycle, bike, wheel, cycle',
144
+ 'doorframe, doorcase',
145
+ 'sconce',
146
+ 'pond',
147
+ 'trade name, brand name, brand, marque',
148
+ 'bannister, banister, balustrade, balusters, handrail',
149
+ 'bag',
150
+ 'traffic light, traffic signal, stoplight',
151
+ 'gazebo',
152
+ 'escalator, moving staircase, moving stairway',
153
+ 'land, ground, soil',
154
+ 'board, plank',
155
+ 'arcade machine',
156
+ 'eiderdown, duvet, continental quilt',
157
+ 'bar',
158
+ 'stall, stand, sales booth',
159
+ 'playground',
160
+ 'ship',
161
+ 'ottoman, pouf, pouffe, puff, hassock',
162
+ (
163
+ 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin,'
164
+ ' dustbin, trash barrel, trash bin'
165
+ ),
166
+ 'bottle',
167
+ 'cradle',
168
+ 'pot, flowerpot',
169
+ 'conveyer belt, conveyor belt, conveyer, conveyor, transporter',
170
+ 'train, railroad train',
171
+ 'stool',
172
+ 'lake',
173
+ 'tank, storage tank',
174
+ 'ice, water ice',
175
+ 'basket, handbasket',
176
+ 'manhole',
177
+ 'tent, collapsible shelter',
178
+ 'canopy',
179
+ 'microwave, microwave oven',
180
+ 'barrel, cask',
181
+ 'dirt track',
182
+ 'beam',
183
+ 'dishwasher, dish washer, dishwashing machine',
184
+ 'plate',
185
+ 'screen, crt screen',
186
+ 'ruins',
187
+ 'washer, automatic washer, washing machine',
188
+ 'blanket, cover',
189
+ 'plaything, toy',
190
+ 'food, solid food',
191
+ 'screen, silver screen, projection screen',
192
+ 'oven',
193
+ 'stage',
194
+ 'beacon, lighthouse, beacon light, pharos',
195
+ 'umbrella',
196
+ 'sculpture',
197
+ 'aqueduct',
198
+ 'container',
199
+ 'scaffolding, staging',
200
+ 'hood, exhaust hood',
201
+ 'curb, curbing, kerb',
202
+ 'roller coaster',
203
+ 'horse, equus caballus',
204
+ 'catwalk',
205
+ 'glass, drinking glass',
206
+ 'vase',
207
+ 'central reservation',
208
+ 'carousel',
209
+ 'radiator',
210
+ 'closet',
211
+ 'machine',
212
+ 'pier, wharf, wharfage, dock',
213
+ 'fan',
214
+ 'inflatable bounce game',
215
+ 'pitch',
216
+ 'paper',
217
+ 'arcade, colonnade',
218
+ 'hot tub',
219
+ 'helicopter',
220
+ 'tray',
221
+ 'partition, divider',
222
+ 'vineyard',
223
+ 'bowl',
224
+ 'bullring',
225
+ 'flag',
226
+ 'pot',
227
+ 'footbridge, overcrossing, pedestrian bridge',
228
+ 'shower',
229
+ 'bag, traveling bag, travelling bag, grip, suitcase',
230
+ 'bulletin board, notice board',
231
+ 'confessional booth',
232
+ 'trunk, tree trunk, bole',
233
+ 'forest',
234
+ 'elevator door',
235
+ 'laptop, laptop computer',
236
+ 'instrument panel',
237
+ 'bucket, pail',
238
+ 'tapestry, tapis',
239
+ 'platform',
240
+ 'jacket',
241
+ 'gate',
242
+ 'monitor, monitoring device',
243
+ 'telephone booth, phone booth, call box, telephone box, telephone kiosk',
244
+ 'spotlight, spot',
245
+ 'ring',
246
+ 'control panel',
247
+ 'blackboard, chalkboard',
248
+ 'air conditioner, air conditioning',
249
+ 'chest',
250
+ 'clock',
251
+ 'sand dune',
252
+ 'pipe, pipage, piping',
253
+ 'vault',
254
+ 'table football',
255
+ 'cannon',
256
+ 'swimming pool, swimming bath, natatorium',
257
+ 'fluorescent, fluorescent fixture',
258
+ 'statue',
259
+ 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
260
+ 'exhibitor',
261
+ 'ladder',
262
+ 'carport',
263
+ 'dam',
264
+ 'pulpit',
265
+ 'skylight, fanlight',
266
+ 'water tower',
267
+ 'grill, grille, grillwork',
268
+ 'display board',
269
+ 'pane, pane of glass, window glass',
270
+ 'rubbish, trash, scrap',
271
+ 'ice rink',
272
+ 'fruit',
273
+ 'patio',
274
+ 'vending machine',
275
+ 'telephone, phone, telephone set',
276
+ 'net',
277
+ 'backpack, back pack, knapsack, packsack, rucksack, haversack',
278
+ 'jar',
279
+ 'track',
280
+ 'magazine',
281
+ 'shutter',
282
+ 'roof',
283
+ 'banner, streamer',
284
+ 'landfill',
285
+ 'post',
286
+ 'altarpiece, reredos',
287
+ 'hat, chapeau, lid',
288
+ 'arch, archway',
289
+ 'table game',
290
+ 'bag, handbag, pocketbook, purse',
291
+ 'document, written document, papers',
292
+ 'dome',
293
+ 'pier',
294
+ 'shanties',
295
+ 'forecourt',
296
+ 'crane',
297
+ 'dog, domestic dog, canis familiaris',
298
+ 'piano, pianoforte, forte-piano',
299
+ 'drawing',
300
+ 'cabin',
301
+ 'ad, advertisement, advertizement, advertising, advertizing, advert',
302
+ 'amphitheater, amphitheatre, coliseum',
303
+ 'monument',
304
+ 'henhouse',
305
+ 'cockpit',
306
+ 'heater, warmer',
307
+ 'windmill, aerogenerator, wind generator',
308
+ 'pool',
309
+ 'elevator, lift',
310
+ 'decoration, ornament, ornamentation',
311
+ 'labyrinth',
312
+ 'text, textual matter',
313
+ 'printer',
314
+ 'mezzanine, first balcony',
315
+ 'mattress',
316
+ 'straw',
317
+ 'stalls',
318
+ 'patio, terrace',
319
+ 'billboard, hoarding',
320
+ 'bus stop',
321
+ 'trouser, pant',
322
+ 'console table, console',
323
+ 'rack',
324
+ 'notebook',
325
+ 'shrine',
326
+ 'pantry',
327
+ 'cart',
328
+ 'steam shovel',
329
+ 'porch',
330
+ 'postbox, mailbox, letter box',
331
+ 'figurine, statuette',
332
+ 'recycling bin',
333
+ 'folding screen',
334
+ 'telescope',
335
+ 'deck chair, beach chair',
336
+ 'kennel',
337
+ 'coffee maker',
338
+ "altar, communion table, lord's table",
339
+ 'fish',
340
+ 'easel',
341
+ 'artificial golf green',
342
+ 'iceberg',
343
+ 'candlestick, candle holder',
344
+ 'shower stall, shower bath',
345
+ 'television stand',
346
+ (
347
+ 'wall socket, wall plug, electric outlet, electrical outlet, outlet,'
348
+ ' electric receptacle'
349
+ ),
350
+ 'skeleton',
351
+ 'grand piano, grand',
352
+ 'candy, confect',
353
+ 'grille door',
354
+ 'pedestal, plinth, footstall',
355
+ 'jersey, t-shirt, tee shirt',
356
+ 'shoe',
357
+ 'gravestone, headstone, tombstone',
358
+ 'shanty',
359
+ 'structure',
360
+ 'rocking chair, rocker',
361
+ 'bird',
362
+ 'place mat',
363
+ 'tomb',
364
+ 'big top',
365
+ 'gas pump, gasoline pump, petrol pump, island dispenser',
366
+ 'lockers',
367
+ 'cage',
368
+ 'finger',
369
+ 'bleachers',
370
+ 'ferris wheel',
371
+ 'hairdresser chair',
372
+ 'mat',
373
+ 'stands',
374
+ 'aquarium, fish tank, marine museum',
375
+ 'streetcar, tram, tramcar, trolley, trolley car',
376
+ 'napkin, table napkin, serviette',
377
+ 'dummy',
378
+ 'booklet, brochure, folder, leaflet, pamphlet',
379
+ 'sand trap',
380
+ 'shop, store',
381
+ 'table cloth',
382
+ 'service station',
383
+ 'coffin',
384
+ 'drawer',
385
+ 'cages',
386
+ 'slot machine, coin machine',
387
+ 'balcony',
388
+ 'volleyball court',
389
+ 'table tennis',
390
+ 'control table',
391
+ 'shirt',
392
+ 'merchandise, ware, product',
393
+ 'railway',
394
+ 'parterre',
395
+ 'chimney',
396
+ 'can, tin, tin can',
397
+ 'tanks',
398
+ 'fabric, cloth, material, textile',
399
+ 'alga, algae',
400
+ 'system',
401
+ 'map',
402
+ 'greenhouse',
403
+ 'mug',
404
+ 'barbecue',
405
+ 'trailer',
406
+ 'toilet tissue, toilet paper, bathroom tissue',
407
+ 'organ',
408
+ 'dishrag, dishcloth',
409
+ 'island',
410
+ 'keyboard',
411
+ 'trench',
412
+ 'basket, basketball hoop, hoop',
413
+ 'steering wheel, wheel',
414
+ 'pitcher, ewer',
415
+ 'goal',
416
+ 'bread, breadstuff, staff of life',
417
+ 'beds',
418
+ 'wood',
419
+ 'file cabinet',
420
+ 'newspaper, paper',
421
+ 'motorboat',
422
+ 'rope',
423
+ 'guitar',
424
+ 'rubble',
425
+ 'scarf',
426
+ 'barrels',
427
+ 'cap',
428
+ 'leaves',
429
+ 'control tower',
430
+ 'dashboard',
431
+ 'bandstand',
432
+ 'lectern',
433
+ 'switch, electric switch, electrical switch',
434
+ 'baseboard, mopboard, skirting board',
435
+ 'shower room',
436
+ 'smoke',
437
+ 'faucet, spigot',
438
+ 'bulldozer',
439
+ 'saucepan',
440
+ 'shops',
441
+ 'meter',
442
+ 'crevasse',
443
+ 'gear',
444
+ 'candelabrum, candelabra',
445
+ 'sofa bed',
446
+ 'tunnel',
447
+ 'pallet',
448
+ 'wire, conducting wire',
449
+ 'kettle, boiler',
450
+ 'bidet',
451
+ (
452
+ 'baby buggy, baby carriage, carriage, perambulator, pram, stroller,'
453
+ ' go-cart, pushchair, pusher'
454
+ ),
455
+ 'music stand',
456
+ 'pipe, tube',
457
+ 'cup',
458
+ 'parking meter',
459
+ 'ice hockey rink',
460
+ 'shelter',
461
+ 'weeds',
462
+ 'temple',
463
+ 'patty, cake',
464
+ 'ski slope',
465
+ 'panel',
466
+ 'wallet',
467
+ 'wheel',
468
+ 'towel rack, towel horse',
469
+ 'roundabout',
470
+ 'canister, cannister, tin',
471
+ 'rod',
472
+ 'soap dispenser',
473
+ 'bell',
474
+ 'canvas',
475
+ 'box office, ticket office, ticket booth',
476
+ 'teacup',
477
+ 'trellis',
478
+ 'workbench',
479
+ 'valley, vale',
480
+ 'toaster',
481
+ 'knife',
482
+ 'podium',
483
+ 'ramp',
484
+ 'tumble dryer',
485
+ 'fireplug, fire hydrant, plug',
486
+ 'gym shoe, sneaker, tennis shoe',
487
+ 'lab bench',
488
+ 'equipment',
489
+ 'rocky formation',
490
+ 'plastic',
491
+ 'calendar',
492
+ 'caravan',
493
+ 'check-in-desk',
494
+ 'ticket counter',
495
+ 'brush',
496
+ 'mill',
497
+ 'covered bridge',
498
+ 'bowling alley',
499
+ 'hanger',
500
+ 'excavator',
501
+ 'trestle',
502
+ 'revolving door',
503
+ 'blast furnace',
504
+ 'scale, weighing machine',
505
+ 'projector',
506
+ 'soap',
507
+ 'locker',
508
+ 'tractor',
509
+ 'stretcher',
510
+ 'frame',
511
+ 'grating',
512
+ 'alembic',
513
+ 'candle, taper, wax light',
514
+ 'barrier',
515
+ 'cardboard',
516
+ 'cave',
517
+ 'puddle',
518
+ 'tarp',
519
+ 'price tag',
520
+ 'watchtower',
521
+ 'meters',
522
+ (
523
+ 'light bulb, lightbulb, bulb, incandescent lamp, electric light,'
524
+ ' electric-light bulb'
525
+ ),
526
+ 'tracks',
527
+ 'hair dryer',
528
+ 'skirt',
529
+ 'viaduct',
530
+ 'paper towel',
531
+ 'coat',
532
+ 'sheet',
533
+ 'fire extinguisher, extinguisher, asphyxiator',
534
+ 'water wheel',
535
+ 'pottery, clayware',
536
+ 'magazine rack',
537
+ 'teapot',
538
+ 'microphone, mike',
539
+ 'support',
540
+ 'forklift',
541
+ 'canyon',
542
+ 'cash register, register',
543
+ 'leaf, leafage, foliage',
544
+ 'remote control, remote',
545
+ 'soap dish',
546
+ 'windshield, windscreen',
547
+ 'cat',
548
+ 'cue, cue stick, pool cue, pool stick',
549
+ 'vent, venthole, vent-hole, blowhole',
550
+ 'videos',
551
+ 'shovel',
552
+ 'eaves',
553
+ 'antenna, aerial, transmitting aerial',
554
+ 'shipyard',
555
+ 'hen, biddy',
556
+ 'traffic cone',
557
+ 'washing machines',
558
+ 'truck crane',
559
+ 'cds',
560
+ 'niche',
561
+ 'scoreboard',
562
+ 'briefcase',
563
+ 'boot',
564
+ 'sweater, jumper',
565
+ 'hay',
566
+ 'pack',
567
+ 'bottle rack',
568
+ 'glacier',
569
+ 'pergola',
570
+ 'building materials',
571
+ 'television camera',
572
+ 'first floor',
573
+ 'rifle',
574
+ 'tennis table',
575
+ 'stadium',
576
+ 'safety belt',
577
+ 'cover',
578
+ 'dish rack',
579
+ 'synthesizer',
580
+ 'pumpkin',
581
+ 'gutter',
582
+ 'fruit stand',
583
+ 'ice floe, floe',
584
+ 'handle, grip, handgrip, hold',
585
+ 'wheelchair',
586
+ 'mousepad, mouse mat',
587
+ 'diploma',
588
+ 'fairground ride',
589
+ 'radio',
590
+ 'hotplate',
591
+ 'junk',
592
+ 'wheelbarrow',
593
+ 'stream',
594
+ 'toll plaza',
595
+ 'punching bag',
596
+ 'trough',
597
+ 'throne',
598
+ 'chair desk',
599
+ 'weighbridge',
600
+ 'extractor fan',
601
+ 'hanging clothes',
602
+ 'dish, dish aerial, dish antenna, saucer',
603
+ 'alarm clock, alarm',
604
+ 'ski lift',
605
+ 'chain',
606
+ 'garage',
607
+ 'mechanical shovel',
608
+ 'wine rack',
609
+ 'tramway',
610
+ 'treadmill',
611
+ 'menu',
612
+ 'block',
613
+ 'well',
614
+ 'witness stand',
615
+ 'branch',
616
+ 'duck',
617
+ 'casserole',
618
+ 'frying pan',
619
+ 'desk organizer',
620
+ 'mast',
621
+ 'spectacles, specs, eyeglasses, glasses',
622
+ 'service elevator',
623
+ 'dollhouse',
624
+ 'hammock',
625
+ 'clothes hanging',
626
+ 'photocopier',
627
+ 'notepad',
628
+ 'golf cart',
629
+ 'footpath',
630
+ 'cross',
631
+ 'baptismal font',
632
+ 'boiler',
633
+ 'skip',
634
+ 'rotisserie',
635
+ 'tables',
636
+ 'water mill',
637
+ 'helmet',
638
+ 'cover curtain',
639
+ 'brick',
640
+ 'table runner',
641
+ 'ashtray',
642
+ 'street box',
643
+ 'stick',
644
+ 'hangers',
645
+ 'cells',
646
+ 'urinal',
647
+ 'centerpiece',
648
+ 'portable fridge',
649
+ 'dvds',
650
+ 'golf club',
651
+ 'skirting board',
652
+ 'water cooler',
653
+ 'clipboard',
654
+ 'camera, photographic camera',
655
+ 'pigeonhole',
656
+ 'chips',
657
+ 'food processor',
658
+ 'post box',
659
+ 'lid',
660
+ 'drum',
661
+ 'blender',
662
+ 'cave entrance',
663
+ 'dental chair',
664
+ 'obelisk',
665
+ 'canoe',
666
+ 'mobile',
667
+ 'monitors',
668
+ 'pool ball',
669
+ 'cue rack',
670
+ 'baggage carts',
671
+ 'shore',
672
+ 'fork',
673
+ 'paper filer',
674
+ 'bicycle rack',
675
+ 'coat rack',
676
+ 'garland',
677
+ 'sports bag',
678
+ 'fish tank',
679
+ 'towel dispenser',
680
+ 'carriage',
681
+ 'brochure',
682
+ 'plaque',
683
+ 'stringer',
684
+ 'iron',
685
+ 'spoon',
686
+ 'flag pole',
687
+ 'toilet brush',
688
+ 'book stand',
689
+ 'water faucet, water tap, tap, hydrant',
690
+ 'ticket office',
691
+ 'broom',
692
+ 'dvd',
693
+ 'ice bucket',
694
+ 'carapace, shell, cuticle, shield',
695
+ 'tureen',
696
+ 'folders',
697
+ 'chess',
698
+ 'root',
699
+ 'sewing machine',
700
+ 'model',
701
+ 'pen',
702
+ 'violin',
703
+ 'sweatshirt',
704
+ 'recycling materials',
705
+ 'mitten',
706
+ 'chopping board, cutting board',
707
+ 'mask',
708
+ 'log',
709
+ 'mouse, computer mouse',
710
+ 'grill',
711
+ 'hole',
712
+ 'target',
713
+ 'trash bag',
714
+ 'chalk',
715
+ 'sticks',
716
+ 'balloon',
717
+ 'score',
718
+ 'hair spray',
719
+ 'roll',
720
+ 'runner',
721
+ 'engine',
722
+ 'inflatable glove',
723
+ 'games',
724
+ 'pallets',
725
+ 'baskets',
726
+ 'coop',
727
+ 'dvd player',
728
+ 'rocking horse',
729
+ 'buckets',
730
+ 'bread rolls',
731
+ 'shawl',
732
+ 'watering can',
733
+ 'spotlights',
734
+ 'post-it',
735
+ 'bowls',
736
+ 'security camera',
737
+ 'runner cloth',
738
+ 'lock',
739
+ 'alarm, warning device, alarm system',
740
+ 'side',
741
+ 'roulette',
742
+ 'bone',
743
+ 'cutlery',
744
+ 'pool balls',
745
+ 'wheels',
746
+ 'spice rack',
747
+ 'plant pots',
748
+ 'towel ring',
749
+ 'bread box',
750
+ 'video',
751
+ 'funfair',
752
+ 'breads',
753
+ 'tripod',
754
+ 'ironing board',
755
+ 'skimmer',
756
+ 'hollow',
757
+ 'scratching post',
758
+ 'tricycle',
759
+ 'file box',
760
+ 'mountain pass',
761
+ 'tombstones',
762
+ 'cooker',
763
+ 'card game, cards',
764
+ 'golf bag',
765
+ 'towel paper',
766
+ 'chaise lounge',
767
+ 'sun',
768
+ 'toilet paper holder',
769
+ 'rake',
770
+ 'key',
771
+ 'umbrella stand',
772
+ 'dartboard',
773
+ 'transformer',
774
+ 'fireplace utensils',
775
+ 'sweatshirts',
776
+ 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
777
+ 'tallboy',
778
+ 'stapler',
779
+ 'sauna',
780
+ 'test tube',
781
+ 'palette',
782
+ 'shopping carts',
783
+ 'tools',
784
+ 'push button, push, button',
785
+ 'star',
786
+ 'roof rack',
787
+ 'barbed wire',
788
+ 'spray',
789
+ 'ear',
790
+ 'sponge',
791
+ 'racket',
792
+ 'tins',
793
+ 'eyeglasses',
794
+ 'file',
795
+ 'scarfs',
796
+ 'sugar bowl',
797
+ 'flip flop',
798
+ 'headstones',
799
+ 'laptop bag',
800
+ 'leash',
801
+ 'climbing frame',
802
+ 'suit hanger',
803
+ 'floor spotlight',
804
+ 'plate rack',
805
+ 'sewer',
806
+ 'hard drive',
807
+ 'sprinkler',
808
+ 'tools box',
809
+ 'necklace',
810
+ 'bulbs',
811
+ 'steel industry',
812
+ 'club',
813
+ 'jack',
814
+ 'door bars',
815
+ 'control panel, instrument panel, control board, board, panel',
816
+ 'hairbrush',
817
+ 'napkin holder',
818
+ 'office',
819
+ 'smoke detector',
820
+ 'utensils',
821
+ 'apron',
822
+ 'scissors',
823
+ 'terminal',
824
+ 'grinder',
825
+ 'entry phone',
826
+ 'newspaper stand',
827
+ 'pepper shaker',
828
+ 'onions',
829
+ (
830
+ 'central processing unit, cpu, c p u , central processor, processor,'
831
+ ' mainframe'
832
+ ),
833
+ 'tape',
834
+ 'bat',
835
+ 'coaster',
836
+ 'calculator',
837
+ 'potatoes',
838
+ 'luggage rack',
839
+ 'salt',
840
+ 'street number',
841
+ 'viewpoint',
842
+ 'sword',
843
+ 'cd',
844
+ 'rowing machine',
845
+ 'plug',
846
+ 'andiron, firedog, dog, dog-iron',
847
+ 'pepper',
848
+ 'tongs',
849
+ 'bonfire',
850
+ 'dog dish',
851
+ 'belt',
852
+ 'dumbbells',
853
+ 'videocassette recorder, vcr',
854
+ 'hook',
855
+ 'envelopes',
856
+ 'shower faucet',
857
+ 'watch',
858
+ 'padlock',
859
+ 'swimming pool ladder',
860
+ 'spanners',
861
+ 'gravy boat',
862
+ 'notice board',
863
+ 'trash bags',
864
+ 'fire alarm',
865
+ 'ladle',
866
+ 'stethoscope',
867
+ 'rocket',
868
+ 'funnel',
869
+ 'bowling pins',
870
+ 'valve',
871
+ 'thermometer',
872
+ 'cups',
873
+ 'spice jar',
874
+ 'night light',
875
+ 'soaps',
876
+ 'games table',
877
+ 'slotted spoon',
878
+ 'reel',
879
+ 'scourer',
880
+ 'sleeping robe',
881
+ 'desk mat',
882
+ 'dumbbell',
883
+ 'hammer',
884
+ 'tie',
885
+ 'typewriter',
886
+ 'shaker',
887
+ 'cheese dish',
888
+ 'sea star',
889
+ 'racquet',
890
+ 'butane gas cylinder',
891
+ 'paper weight',
892
+ 'shaving brush',
893
+ 'sunglasses',
894
+ 'gear shift',
895
+ 'towel rail',
896
+ 'adding machine, totalizer, totaliser',
897
+ ]
898
+
899
+ ADE_847_CLASS_ID = list(range(847))
900
+
901
+ ADE_847_STUFF_CLASS = [
902
+ 'wall',
903
+ 'sky',
904
+ 'tree',
905
+ 'road, route',
906
+ 'floor, flooring',
907
+ 'sidewalk, pavement',
908
+ 'earth, ground',
909
+ 'grass',
910
+ 'mountain, mount',
911
+ 'plant, flora, plant life',
912
+ 'sea',
913
+ 'water',
914
+ 'rock, stone',
915
+ 'snow',
916
+ 'sand',
917
+ 'island',
918
+ 'field',
919
+ 'forest',
920
+ 'land, ground, soil',
921
+ 'lake',
922
+ 'ice, water ice',
923
+ 'cliff, drop, drop-off',
924
+ 'dirt track',
925
+ 'hill',
926
+ 'valley, vale',
927
+ 'stream',
928
+ 'shore',
929
+ 'pond',
930
+ 'iceberg',
931
+ ]
932
+
933
+ ADE_847_THING_CLASS = [
934
+ 'building, edifice',
935
+ 'ceiling',
936
+ 'bed',
937
+ 'cabinet',
938
+ 'person, individual, someone, somebody, mortal, soul',
939
+ 'windowpane, window',
940
+ 'car, auto, automobile, machine, motorcar',
941
+ 'table',
942
+ 'chair',
943
+ 'curtain, drape, drapery, mantle, pall',
944
+ 'door',
945
+ 'sofa, couch, lounge',
946
+ 'painting, picture',
947
+ 'mirror',
948
+ 'house',
949
+ 'rug, carpet, carpeting',
950
+ 'shelf',
951
+ 'armchair',
952
+ 'fence, fencing',
953
+ 'lamp',
954
+ 'seat',
955
+ 'river',
956
+ 'desk',
957
+ 'bathtub, bathing tub, bath, tub',
958
+ 'railing, rail',
959
+ 'signboard, sign',
960
+ 'cushion',
961
+ 'path',
962
+ 'work surface',
963
+ 'stairs, steps',
964
+ 'column, pillar',
965
+ 'sink',
966
+ 'wardrobe, closet, press',
967
+ 'refrigerator, icebox',
968
+ 'base, pedestal, stand',
969
+ 'bridge, span',
970
+ 'blind, screen',
971
+ 'runway',
972
+ 'fireplace, hearth, open fireplace',
973
+ 'pillow',
974
+ 'screen door, screen',
975
+ 'toilet, can, commode, crapper, pot, potty, stool, throne',
976
+ 'skyscraper',
977
+ 'grandstand, covered stand',
978
+ 'box',
979
+ 'pool table, billiard table, snooker table',
980
+ 'palm, palm tree',
981
+ 'double door',
982
+ 'coffee table, cocktail table',
983
+ 'counter',
984
+ 'countertop',
985
+ 'chest of drawers, chest, bureau, dresser',
986
+ 'kitchen island',
987
+ 'boat',
988
+ 'waterfall, falls',
989
+ 'stove, kitchen stove, range, kitchen range, cooking stove',
990
+ 'flower',
991
+ 'bookcase',
992
+ 'controls',
993
+ 'book',
994
+ 'stairway, staircase',
995
+ 'streetlight, street lamp',
996
+ (
997
+ 'computer, computing machine, computing device, data processor,'
998
+ ' electronic computer, information processing system'
999
+ ),
1000
+ (
1001
+ 'bus, autobus, coach, charabanc, double-decker, jitney, motorbus,'
1002
+ ' motorcoach, omnibus, passenger vehicle'
1003
+ ),
1004
+ 'swivel chair',
1005
+ 'light, light source',
1006
+ 'bench',
1007
+ 'case, display case, showcase, vitrine',
1008
+ 'towel',
1009
+ 'fountain',
1010
+ 'embankment',
1011
+ (
1012
+ 'television receiver, television, television set, tv, tv set, idiot'
1013
+ ' box, boob tube, telly, goggle box'
1014
+ ),
1015
+ 'van',
1016
+ 'awning, sunshade, sunblind',
1017
+ 'poster, posting, placard, notice, bill, card',
1018
+ 'truck, motortruck',
1019
+ 'airplane, aeroplane, plane',
1020
+ 'pole',
1021
+ 'tower',
1022
+ 'court',
1023
+ 'ball',
1024
+ 'aircraft carrier, carrier, flattop, attack aircraft carrier',
1025
+ 'buffet, counter, sideboard',
1026
+ 'hovel, hut, hutch, shack, shanty',
1027
+ 'apparel, wearing apparel, dress, clothes',
1028
+ 'minibike, motorbike',
1029
+ 'animal, animate being, beast, brute, creature, fauna',
1030
+ 'chandelier, pendant, pendent',
1031
+ 'step, stair',
1032
+ 'booth, cubicle, stall, kiosk',
1033
+ 'bicycle, bike, wheel, cycle',
1034
+ 'doorframe, doorcase',
1035
+ 'sconce',
1036
+ 'trade name, brand name, brand, marque',
1037
+ 'bannister, banister, balustrade, balusters, handrail',
1038
+ 'bag',
1039
+ 'traffic light, traffic signal, stoplight',
1040
+ 'gazebo',
1041
+ 'escalator, moving staircase, moving stairway',
1042
+ 'board, plank',
1043
+ 'arcade machine',
1044
+ 'eiderdown, duvet, continental quilt',
1045
+ 'bar',
1046
+ 'stall, stand, sales booth',
1047
+ 'playground',
1048
+ 'ship',
1049
+ 'ottoman, pouf, pouffe, puff, hassock',
1050
+ (
1051
+ 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin,'
1052
+ ' dustbin, trash barrel, trash bin'
1053
+ ),
1054
+ 'bottle',
1055
+ 'cradle',
1056
+ 'pot, flowerpot',
1057
+ 'conveyer belt, conveyor belt, conveyer, conveyor, transporter',
1058
+ 'train, railroad train',
1059
+ 'stool',
1060
+ 'tank, storage tank',
1061
+ 'basket, handbasket',
1062
+ 'manhole',
1063
+ 'tent, collapsible shelter',
1064
+ 'canopy',
1065
+ 'microwave, microwave oven',
1066
+ 'barrel, cask',
1067
+ 'beam',
1068
+ 'dishwasher, dish washer, dishwashing machine',
1069
+ 'plate',
1070
+ 'screen, crt screen',
1071
+ 'ruins',
1072
+ 'washer, automatic washer, washing machine',
1073
+ 'blanket, cover',
1074
+ 'plaything, toy',
1075
+ 'food, solid food',
1076
+ 'screen, silver screen, projection screen',
1077
+ 'oven',
1078
+ 'stage',
1079
+ 'beacon, lighthouse, beacon light, pharos',
1080
+ 'umbrella',
1081
+ 'sculpture',
1082
+ 'aqueduct',
1083
+ 'container',
1084
+ 'scaffolding, staging',
1085
+ 'hood, exhaust hood',
1086
+ 'curb, curbing, kerb',
1087
+ 'roller coaster',
1088
+ 'horse, equus caballus',
1089
+ 'catwalk',
1090
+ 'glass, drinking glass',
1091
+ 'vase',
1092
+ 'central reservation',
1093
+ 'carousel',
1094
+ 'radiator',
1095
+ 'closet',
1096
+ 'machine',
1097
+ 'pier, wharf, wharfage, dock',
1098
+ 'fan',
1099
+ 'inflatable bounce game',
1100
+ 'pitch',
1101
+ 'paper',
1102
+ 'arcade, colonnade',
1103
+ 'hot tub',
1104
+ 'helicopter',
1105
+ 'tray',
1106
+ 'partition, divider',
1107
+ 'vineyard',
1108
+ 'bowl',
1109
+ 'bullring',
1110
+ 'flag',
1111
+ 'pot',
1112
+ 'footbridge, overcrossing, pedestrian bridge',
1113
+ 'shower',
1114
+ 'bag, traveling bag, travelling bag, grip, suitcase',
1115
+ 'bulletin board, notice board',
1116
+ 'confessional booth',
1117
+ 'trunk, tree trunk, bole',
1118
+ 'elevator door',
1119
+ 'laptop, laptop computer',
1120
+ 'instrument panel',
1121
+ 'bucket, pail',
1122
+ 'tapestry, tapis',
1123
+ 'platform',
1124
+ 'jacket',
1125
+ 'gate',
1126
+ 'monitor, monitoring device',
1127
+ 'telephone booth, phone booth, call box, telephone box, telephone kiosk',
1128
+ 'spotlight, spot',
1129
+ 'ring',
1130
+ 'control panel',
1131
+ 'blackboard, chalkboard',
1132
+ 'air conditioner, air conditioning',
1133
+ 'chest',
1134
+ 'clock',
1135
+ 'sand dune',
1136
+ 'pipe, pipage, piping',
1137
+ 'vault',
1138
+ 'table football',
1139
+ 'cannon',
1140
+ 'swimming pool, swimming bath, natatorium',
1141
+ 'fluorescent, fluorescent fixture',
1142
+ 'statue',
1143
+ 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
1144
+ 'exhibitor',
1145
+ 'ladder',
1146
+ 'carport',
1147
+ 'dam',
1148
+ 'pulpit',
1149
+ 'skylight, fanlight',
1150
+ 'water tower',
1151
+ 'grill, grille, grillwork',
1152
+ 'display board',
1153
+ 'pane, pane of glass, window glass',
1154
+ 'rubbish, trash, scrap',
1155
+ 'ice rink',
1156
+ 'fruit',
1157
+ 'patio',
1158
+ 'vending machine',
1159
+ 'telephone, phone, telephone set',
1160
+ 'net',
1161
+ 'backpack, back pack, knapsack, packsack, rucksack, haversack',
1162
+ 'jar',
1163
+ 'track',
1164
+ 'magazine',
1165
+ 'shutter',
1166
+ 'roof',
1167
+ 'banner, streamer',
1168
+ 'landfill',
1169
+ 'post',
1170
+ 'altarpiece, reredos',
1171
+ 'hat, chapeau, lid',
1172
+ 'arch, archway',
1173
+ 'table game',
1174
+ 'bag, handbag, pocketbook, purse',
1175
+ 'document, written document, papers',
1176
+ 'dome',
1177
+ 'pier',
1178
+ 'shanties',
1179
+ 'forecourt',
1180
+ 'crane',
1181
+ 'dog, domestic dog, canis familiaris',
1182
+ 'piano, pianoforte, forte-piano',
1183
+ 'drawing',
1184
+ 'cabin',
1185
+ 'ad, advertisement, advertizement, advertising, advertizing, advert',
1186
+ 'amphitheater, amphitheatre, coliseum',
1187
+ 'monument',
1188
+ 'henhouse',
1189
+ 'cockpit',
1190
+ 'heater, warmer',
1191
+ 'windmill, aerogenerator, wind generator',
1192
+ 'pool',
1193
+ 'elevator, lift',
1194
+ 'decoration, ornament, ornamentation',
1195
+ 'labyrinth',
1196
+ 'text, textual matter',
1197
+ 'printer',
1198
+ 'mezzanine, first balcony',
1199
+ 'mattress',
1200
+ 'straw',
1201
+ 'stalls',
1202
+ 'patio, terrace',
1203
+ 'billboard, hoarding',
1204
+ 'bus stop',
1205
+ 'trouser, pant',
1206
+ 'console table, console',
1207
+ 'rack',
1208
+ 'notebook',
1209
+ 'shrine',
1210
+ 'pantry',
1211
+ 'cart',
1212
+ 'steam shovel',
1213
+ 'porch',
1214
+ 'postbox, mailbox, letter box',
1215
+ 'figurine, statuette',
1216
+ 'recycling bin',
1217
+ 'folding screen',
1218
+ 'telescope',
1219
+ 'deck chair, beach chair',
1220
+ 'kennel',
1221
+ 'coffee maker',
1222
+ "altar, communion table, lord's table",
1223
+ 'fish',
1224
+ 'easel',
1225
+ 'artificial golf green',
1226
+ 'candlestick, candle holder',
1227
+ 'shower stall, shower bath',
1228
+ 'television stand',
1229
+ (
1230
+ 'wall socket, wall plug, electric outlet, electrical outlet, outlet,'
1231
+ ' electric receptacle'
1232
+ ),
1233
+ 'skeleton',
1234
+ 'grand piano, grand',
1235
+ 'candy, confect',
1236
+ 'grille door',
1237
+ 'pedestal, plinth, footstall',
1238
+ 'jersey, t-shirt, tee shirt',
1239
+ 'shoe',
1240
+ 'gravestone, headstone, tombstone',
1241
+ 'shanty',
1242
+ 'structure',
1243
+ 'rocking chair, rocker',
1244
+ 'bird',
1245
+ 'place mat',
1246
+ 'tomb',
1247
+ 'big top',
1248
+ 'gas pump, gasoline pump, petrol pump, island dispenser',
1249
+ 'lockers',
1250
+ 'cage',
1251
+ 'finger',
1252
+ 'bleachers',
1253
+ 'ferris wheel',
1254
+ 'hairdresser chair',
1255
+ 'mat',
1256
+ 'stands',
1257
+ 'aquarium, fish tank, marine museum',
1258
+ 'streetcar, tram, tramcar, trolley, trolley car',
1259
+ 'napkin, table napkin, serviette',
1260
+ 'dummy',
1261
+ 'booklet, brochure, folder, leaflet, pamphlet',
1262
+ 'sand trap',
1263
+ 'shop, store',
1264
+ 'table cloth',
1265
+ 'service station',
1266
+ 'coffin',
1267
+ 'drawer',
1268
+ 'cages',
1269
+ 'slot machine, coin machine',
1270
+ 'balcony',
1271
+ 'volleyball court',
1272
+ 'table tennis',
1273
+ 'control table',
1274
+ 'shirt',
1275
+ 'merchandise, ware, product',
1276
+ 'railway',
1277
+ 'parterre',
1278
+ 'chimney',
1279
+ 'can, tin, tin can',
1280
+ 'tanks',
1281
+ 'fabric, cloth, material, textile',
1282
+ 'alga, algae',
1283
+ 'system',
1284
+ 'map',
1285
+ 'greenhouse',
1286
+ 'mug',
1287
+ 'barbecue',
1288
+ 'trailer',
1289
+ 'toilet tissue, toilet paper, bathroom tissue',
1290
+ 'organ',
1291
+ 'dishrag, dishcloth',
1292
+ 'keyboard',
1293
+ 'trench',
1294
+ 'basket, basketball hoop, hoop',
1295
+ 'steering wheel, wheel',
1296
+ 'pitcher, ewer',
1297
+ 'goal',
1298
+ 'bread, breadstuff, staff of life',
1299
+ 'beds',
1300
+ 'wood',
1301
+ 'file cabinet',
1302
+ 'newspaper, paper',
1303
+ 'motorboat',
1304
+ 'rope',
1305
+ 'guitar',
1306
+ 'rubble',
1307
+ 'scarf',
1308
+ 'barrels',
1309
+ 'cap',
1310
+ 'leaves',
1311
+ 'control tower',
1312
+ 'dashboard',
1313
+ 'bandstand',
1314
+ 'lectern',
1315
+ 'switch, electric switch, electrical switch',
1316
+ 'baseboard, mopboard, skirting board',
1317
+ 'shower room',
1318
+ 'smoke',
1319
+ 'faucet, spigot',
1320
+ 'bulldozer',
1321
+ 'saucepan',
1322
+ 'shops',
1323
+ 'meter',
1324
+ 'crevasse',
1325
+ 'gear',
1326
+ 'candelabrum, candelabra',
1327
+ 'sofa bed',
1328
+ 'tunnel',
1329
+ 'pallet',
1330
+ 'wire, conducting wire',
1331
+ 'kettle, boiler',
1332
+ 'bidet',
1333
+ (
1334
+ 'baby buggy, baby carriage, carriage, perambulator, pram, stroller,'
1335
+ ' go-cart, pushchair, pusher'
1336
+ ),
1337
+ 'music stand',
1338
+ 'pipe, tube',
1339
+ 'cup',
1340
+ 'parking meter',
1341
+ 'ice hockey rink',
1342
+ 'shelter',
1343
+ 'weeds',
1344
+ 'temple',
1345
+ 'patty, cake',
1346
+ 'ski slope',
1347
+ 'panel',
1348
+ 'wallet',
1349
+ 'wheel',
1350
+ 'towel rack, towel horse',
1351
+ 'roundabout',
1352
+ 'canister, cannister, tin',
1353
+ 'rod',
1354
+ 'soap dispenser',
1355
+ 'bell',
1356
+ 'canvas',
1357
+ 'box office, ticket office, ticket booth',
1358
+ 'teacup',
1359
+ 'trellis',
1360
+ 'workbench',
1361
+ 'toaster',
1362
+ 'knife',
1363
+ 'podium',
1364
+ 'ramp',
1365
+ 'tumble dryer',
1366
+ 'fireplug, fire hydrant, plug',
1367
+ 'gym shoe, sneaker, tennis shoe',
1368
+ 'lab bench',
1369
+ 'equipment',
1370
+ 'rocky formation',
1371
+ 'plastic',
1372
+ 'calendar',
1373
+ 'caravan',
1374
+ 'check-in-desk',
1375
+ 'ticket counter',
1376
+ 'brush',
1377
+ 'mill',
1378
+ 'covered bridge',
1379
+ 'bowling alley',
1380
+ 'hanger',
1381
+ 'excavator',
1382
+ 'trestle',
1383
+ 'revolving door',
1384
+ 'blast furnace',
1385
+ 'scale, weighing machine',
1386
+ 'projector',
1387
+ 'soap',
1388
+ 'locker',
1389
+ 'tractor',
1390
+ 'stretcher',
1391
+ 'frame',
1392
+ 'grating',
1393
+ 'alembic',
1394
+ 'candle, taper, wax light',
1395
+ 'barrier',
1396
+ 'cardboard',
1397
+ 'cave',
1398
+ 'puddle',
1399
+ 'tarp',
1400
+ 'price tag',
1401
+ 'watchtower',
1402
+ 'meters',
1403
+ (
1404
+ 'light bulb, lightbulb, bulb, incandescent lamp, electric light,'
1405
+ ' electric-light bulb'
1406
+ ),
1407
+ 'tracks',
1408
+ 'hair dryer',
1409
+ 'skirt',
1410
+ 'viaduct',
1411
+ 'paper towel',
1412
+ 'coat',
1413
+ 'sheet',
1414
+ 'fire extinguisher, extinguisher, asphyxiator',
1415
+ 'water wheel',
1416
+ 'pottery, clayware',
1417
+ 'magazine rack',
1418
+ 'teapot',
1419
+ 'microphone, mike',
1420
+ 'support',
1421
+ 'forklift',
1422
+ 'canyon',
1423
+ 'cash register, register',
1424
+ 'leaf, leafage, foliage',
1425
+ 'remote control, remote',
1426
+ 'soap dish',
1427
+ 'windshield, windscreen',
1428
+ 'cat',
1429
+ 'cue, cue stick, pool cue, pool stick',
1430
+ 'vent, venthole, vent-hole, blowhole',
1431
+ 'videos',
1432
+ 'shovel',
1433
+ 'eaves',
1434
+ 'antenna, aerial, transmitting aerial',
1435
+ 'shipyard',
1436
+ 'hen, biddy',
1437
+ 'traffic cone',
1438
+ 'washing machines',
1439
+ 'truck crane',
1440
+ 'cds',
1441
+ 'niche',
1442
+ 'scoreboard',
1443
+ 'briefcase',
1444
+ 'boot',
1445
+ 'sweater, jumper',
1446
+ 'hay',
1447
+ 'pack',
1448
+ 'bottle rack',
1449
+ 'glacier',
1450
+ 'pergola',
1451
+ 'building materials',
1452
+ 'television camera',
1453
+ 'first floor',
1454
+ 'rifle',
1455
+ 'tennis table',
1456
+ 'stadium',
1457
+ 'safety belt',
1458
+ 'cover',
1459
+ 'dish rack',
1460
+ 'synthesizer',
1461
+ 'pumpkin',
1462
+ 'gutter',
1463
+ 'fruit stand',
1464
+ 'ice floe, floe',
1465
+ 'handle, grip, handgrip, hold',
1466
+ 'wheelchair',
1467
+ 'mousepad, mouse mat',
1468
+ 'diploma',
1469
+ 'fairground ride',
1470
+ 'radio',
1471
+ 'hotplate',
1472
+ 'junk',
1473
+ 'wheelbarrow',
1474
+ 'toll plaza',
1475
+ 'punching bag',
1476
+ 'trough',
1477
+ 'throne',
1478
+ 'chair desk',
1479
+ 'weighbridge',
1480
+ 'extractor fan',
1481
+ 'hanging clothes',
1482
+ 'dish, dish aerial, dish antenna, saucer',
1483
+ 'alarm clock, alarm',
1484
+ 'ski lift',
1485
+ 'chain',
1486
+ 'garage',
1487
+ 'mechanical shovel',
1488
+ 'wine rack',
1489
+ 'tramway',
1490
+ 'treadmill',
1491
+ 'menu',
1492
+ 'block',
1493
+ 'well',
1494
+ 'witness stand',
1495
+ 'branch',
1496
+ 'duck',
1497
+ 'casserole',
1498
+ 'frying pan',
1499
+ 'desk organizer',
1500
+ 'mast',
1501
+ 'spectacles, specs, eyeglasses, glasses',
1502
+ 'service elevator',
1503
+ 'dollhouse',
1504
+ 'hammock',
1505
+ 'clothes hanging',
1506
+ 'photocopier',
1507
+ 'notepad',
1508
+ 'golf cart',
1509
+ 'footpath',
1510
+ 'cross',
1511
+ 'baptismal font',
1512
+ 'boiler',
1513
+ 'skip',
1514
+ 'rotisserie',
1515
+ 'tables',
1516
+ 'water mill',
1517
+ 'helmet',
1518
+ 'cover curtain',
1519
+ 'brick',
1520
+ 'table runner',
1521
+ 'ashtray',
1522
+ 'street box',
1523
+ 'stick',
1524
+ 'hangers',
1525
+ 'cells',
1526
+ 'urinal',
1527
+ 'centerpiece',
1528
+ 'portable fridge',
1529
+ 'dvds',
1530
+ 'golf club',
1531
+ 'skirting board',
1532
+ 'water cooler',
1533
+ 'clipboard',
1534
+ 'camera, photographic camera',
1535
+ 'pigeonhole',
1536
+ 'chips',
1537
+ 'food processor',
1538
+ 'post box',
1539
+ 'lid',
1540
+ 'drum',
1541
+ 'blender',
1542
+ 'cave entrance',
1543
+ 'dental chair',
1544
+ 'obelisk',
1545
+ 'canoe',
1546
+ 'mobile',
1547
+ 'monitors',
1548
+ 'pool ball',
1549
+ 'cue rack',
1550
+ 'baggage carts',
1551
+ 'fork',
1552
+ 'paper filer',
1553
+ 'bicycle rack',
1554
+ 'coat rack',
1555
+ 'garland',
1556
+ 'sports bag',
1557
+ 'fish tank',
1558
+ 'towel dispenser',
1559
+ 'carriage',
1560
+ 'brochure',
1561
+ 'plaque',
1562
+ 'stringer',
1563
+ 'iron',
1564
+ 'spoon',
1565
+ 'flag pole',
1566
+ 'toilet brush',
1567
+ 'book stand',
1568
+ 'water faucet, water tap, tap, hydrant',
1569
+ 'ticket office',
1570
+ 'broom',
1571
+ 'dvd',
1572
+ 'ice bucket',
1573
+ 'carapace, shell, cuticle, shield',
1574
+ 'tureen',
1575
+ 'folders',
1576
+ 'chess',
1577
+ 'root',
1578
+ 'sewing machine',
1579
+ 'model',
1580
+ 'pen',
1581
+ 'violin',
1582
+ 'sweatshirt',
1583
+ 'recycling materials',
1584
+ 'mitten',
1585
+ 'chopping board, cutting board',
1586
+ 'mask',
1587
+ 'log',
1588
+ 'mouse, computer mouse',
1589
+ 'grill',
1590
+ 'hole',
1591
+ 'target',
1592
+ 'trash bag',
1593
+ 'chalk',
1594
+ 'sticks',
1595
+ 'balloon',
1596
+ 'score',
1597
+ 'hair spray',
1598
+ 'roll',
1599
+ 'runner',
1600
+ 'engine',
1601
+ 'inflatable glove',
1602
+ 'games',
1603
+ 'pallets',
1604
+ 'baskets',
1605
+ 'coop',
1606
+ 'dvd player',
1607
+ 'rocking horse',
1608
+ 'buckets',
1609
+ 'bread rolls',
1610
+ 'shawl',
1611
+ 'watering can',
1612
+ 'spotlights',
1613
+ 'post-it',
1614
+ 'bowls',
1615
+ 'security camera',
1616
+ 'runner cloth',
1617
+ 'lock',
1618
+ 'alarm, warning device, alarm system',
1619
+ 'side',
1620
+ 'roulette',
1621
+ 'bone',
1622
+ 'cutlery',
1623
+ 'pool balls',
1624
+ 'wheels',
1625
+ 'spice rack',
1626
+ 'plant pots',
1627
+ 'towel ring',
1628
+ 'bread box',
1629
+ 'video',
1630
+ 'funfair',
1631
+ 'breads',
1632
+ 'tripod',
1633
+ 'ironing board',
1634
+ 'skimmer',
1635
+ 'hollow',
1636
+ 'scratching post',
1637
+ 'tricycle',
1638
+ 'file box',
1639
+ 'mountain pass',
1640
+ 'tombstones',
1641
+ 'cooker',
1642
+ 'card game, cards',
1643
+ 'golf bag',
1644
+ 'towel paper',
1645
+ 'chaise lounge',
1646
+ 'sun',
1647
+ 'toilet paper holder',
1648
+ 'rake',
1649
+ 'key',
1650
+ 'umbrella stand',
1651
+ 'dartboard',
1652
+ 'transformer',
1653
+ 'fireplace utensils',
1654
+ 'sweatshirts',
1655
+ 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
1656
+ 'tallboy',
1657
+ 'stapler',
1658
+ 'sauna',
1659
+ 'test tube',
1660
+ 'palette',
1661
+ 'shopping carts',
1662
+ 'tools',
1663
+ 'push button, push, button',
1664
+ 'star',
1665
+ 'roof rack',
1666
+ 'barbed wire',
1667
+ 'spray',
1668
+ 'ear',
1669
+ 'sponge',
1670
+ 'racket',
1671
+ 'tins',
1672
+ 'eyeglasses',
1673
+ 'file',
1674
+ 'scarfs',
1675
+ 'sugar bowl',
1676
+ 'flip flop',
1677
+ 'headstones',
1678
+ 'laptop bag',
1679
+ 'leash',
1680
+ 'climbing frame',
1681
+ 'suit hanger',
1682
+ 'floor spotlight',
1683
+ 'plate rack',
1684
+ 'sewer',
1685
+ 'hard drive',
1686
+ 'sprinkler',
1687
+ 'tools box',
1688
+ 'necklace',
1689
+ 'bulbs',
1690
+ 'steel industry',
1691
+ 'club',
1692
+ 'jack',
1693
+ 'door bars',
1694
+ 'control panel, instrument panel, control board, board, panel',
1695
+ 'hairbrush',
1696
+ 'napkin holder',
1697
+ 'office',
1698
+ 'smoke detector',
1699
+ 'utensils',
1700
+ 'apron',
1701
+ 'scissors',
1702
+ 'terminal',
1703
+ 'grinder',
1704
+ 'entry phone',
1705
+ 'newspaper stand',
1706
+ 'pepper shaker',
1707
+ 'onions',
1708
+ (
1709
+ 'central processing unit, cpu, c p u , central processor, processor,'
1710
+ ' mainframe'
1711
+ ),
1712
+ 'tape',
1713
+ 'bat',
1714
+ 'coaster',
1715
+ 'calculator',
1716
+ 'potatoes',
1717
+ 'luggage rack',
1718
+ 'salt',
1719
+ 'street number',
1720
+ 'viewpoint',
1721
+ 'sword',
1722
+ 'cd',
1723
+ 'rowing machine',
1724
+ 'plug',
1725
+ 'andiron, firedog, dog, dog-iron',
1726
+ 'pepper',
1727
+ 'tongs',
1728
+ 'bonfire',
1729
+ 'dog dish',
1730
+ 'belt',
1731
+ 'dumbbells',
1732
+ 'videocassette recorder, vcr',
1733
+ 'hook',
1734
+ 'envelopes',
1735
+ 'shower faucet',
1736
+ 'watch',
1737
+ 'padlock',
1738
+ 'swimming pool ladder',
1739
+ 'spanners',
1740
+ 'gravy boat',
1741
+ 'notice board',
1742
+ 'trash bags',
1743
+ 'fire alarm',
1744
+ 'ladle',
1745
+ 'stethoscope',
1746
+ 'rocket',
1747
+ 'funnel',
1748
+ 'bowling pins',
1749
+ 'valve',
1750
+ 'thermometer',
1751
+ 'cups',
1752
+ 'spice jar',
1753
+ 'night light',
1754
+ 'soaps',
1755
+ 'games table',
1756
+ 'slotted spoon',
1757
+ 'reel',
1758
+ 'scourer',
1759
+ 'sleeping robe',
1760
+ 'desk mat',
1761
+ 'dumbbell',
1762
+ 'hammer',
1763
+ 'tie',
1764
+ 'typewriter',
1765
+ 'shaker',
1766
+ 'cheese dish',
1767
+ 'sea star',
1768
+ 'racquet',
1769
+ 'butane gas cylinder',
1770
+ 'paper weight',
1771
+ 'shaving brush',
1772
+ 'sunglasses',
1773
+ 'gear shift',
1774
+ 'towel rail',
1775
+ 'adding machine, totalizer, totaliser',
1776
+ ]
1777
+
1778
+ ADE_847_STUFF_CLASS_ID = [
1779
+ 0, 2, 3, 4, 5, 8, 9, 12, 15, 16, 22, 24, 33, 47, 54, 368, 31, 195, 118, 134,
1780
+ 136, 53, 143, 90, 435, 546, 624, 111, 304,
1781
+ ]
1782
+
1783
+ ADE_847_THING_CLASS_ID = [
1784
+ i for i in ADE_847_CLASS_ID if i not in ADE_847_STUFF_CLASS_ID
1785
+ ]
1786
+
1787
+
1788
+ class ADE847Dataset(Dataset):
1789
+ """ADE847 dataset."""
1790
+
1791
+ def __init__(self, root, split='validation', transform=None):
1792
+ super(ADE847Dataset, self).__init__()
1793
+ self.root = root
1794
+ self.split = split
1795
+ self.transforms = transform
1796
+ self.image_dir = os.path.join(root, 'images_detectron2', split)
1797
+ self.mask_dir = os.path.join(root, 'annotations_detectron2', split)
1798
+ self.images = os.listdir(self.image_dir)
1799
+
1800
+ def process_mask(self, mask):
1801
+ mask = np.array(mask)
1802
+ mask[mask > 847] = 0
1803
+ return mask
1804
+
1805
+ def __getitem__(self, index):
1806
+ image_path = os.path.join(self.image_dir, self.images[index])
1807
+ image = Image.open(image_path).convert('RGB')
1808
+ target = (
1809
+ np.asarray(
1810
+ Image.open(
1811
+ os.path.join(
1812
+ self.mask_dir, self.images[index].replace('jpg', 'tif')
1813
+ )
1814
+ ),
1815
+ dtype=np.int32,
1816
+ )
1817
+ + 1
1818
+ )
1819
+ target = self.process_mask(target)
1820
+
1821
+ if self.transforms:
1822
+ image = self.transforms(image)
1823
+
1824
+ return image, image_path, target, index
1825
+
1826
+ def __len__(self):
1827
+ return len(self.images)
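
A minimal usage sketch for the ADE-847 loader above (assuming the file is importable as `data.ade847`; the dataset root and the resize size are placeholders, not values prescribed by the repo):

```python
# Hypothetical usage sketch -- paths and module name are placeholders.
from torchvision import transforms

from data.ade847 import ADE847Dataset, ADE_847_STUFF_CLASS_ID  # assumed module path

preprocess = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])
dataset = ADE847Dataset(
    root='/path/to/ADE20K_847', split='validation', transform=preprocess)
image, image_path, target, index = dataset[0]
# `image` is a 3x512x512 tensor; `target` is the original-resolution integer
# mask in which labels above 847 have been zeroed out by `process_mask`.
print(image.shape, target.shape, len(ADE_847_STUFF_CLASS_ID))
```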
data/coco.py ADDED
@@ -0,0 +1,137 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """COCO Stuff Dataset."""
17
+
18
+ import os
19
+ import numpy as np
20
+ from PIL import Image
21
+ import torch
22
+
23
+
24
+ COCO_OBJECT_CLASSES = [
25
+ 'person with clothes,people,human',
26
+ 'bicycle',
27
+ 'car',
28
+ 'motorbike',
29
+ 'aeroplane',
30
+ 'bus',
31
+ 'train',
32
+ 'truck',
33
+ 'boat',
34
+ 'traffic light',
35
+ 'fire hydrant',
36
+ 'stop sign',
37
+ 'parking meter',
38
+ 'bench',
39
+ 'bird avian',
40
+ 'cat',
41
+ 'dog',
42
+ 'horse',
43
+ 'sheep',
44
+ 'cow',
45
+ 'elephant',
46
+ 'bear',
47
+ 'zebra',
48
+ 'giraffe',
49
+ 'backpack,bag',
50
+ 'umbrella,parasol',
51
+ 'handbag,purse',
52
+ 'necktie',
53
+ 'suitcase',
54
+ 'frisbee',
55
+ 'skis',
56
+ 'snowboard',
57
+ 'sports ball',
58
+ 'kite',
59
+ 'baseball bat',
60
+ 'glove',
61
+ 'skateboard',
62
+ 'surfboard',
63
+ 'tennis racket',
64
+ 'bottle',
65
+ 'wine glass',
66
+ 'cup',
67
+ 'fork',
68
+ 'knife',
69
+ 'dessertspoon',
70
+ 'bowl',
71
+ 'banana',
72
+ 'apple',
73
+ 'sandwich',
74
+ 'orange',
75
+ 'broccoli',
76
+ 'carrot',
77
+ 'hot dog',
78
+ 'pizza',
79
+ 'donut',
80
+ 'cake',
81
+ 'chair seat',
82
+ 'sofa',
83
+ 'pottedplant',
84
+ 'bed',
85
+ 'diningtable',
86
+ 'toilet',
87
+ 'tvmonitor screen',
88
+ 'laptop',
89
+ 'mouse',
90
+ 'remote control',
91
+ 'keyboard',
92
+ 'cell phone',
93
+ 'microwave',
94
+ 'oven',
95
+ 'toaster',
96
+ 'sink',
97
+ 'refrigerator',
98
+ 'book',
99
+ 'clock',
100
+ 'vase',
101
+ 'scissors',
102
+ 'teddy bear',
103
+ 'hairdrier,blowdrier',
104
+ 'toothbrush',
105
+ ]
106
+
107
+
108
+ class COCODataset(torch.utils.data.Dataset):
109
+ """COCO Object Dataset."""
110
+
111
+ def __init__(self, root, split='val', transform=None):
112
+ """Construct COCO Object Dataset.
113
+
114
+ Args:
115
+ root (string): Root directory where images are downloaded.
116
+ split (string): Dataset split to load, e.g. 'train' or 'val'.
117
+ transform (callable, optional): Optional transform to be applied on a
118
+ sample.
119
+ """
120
+ self.root = root
121
+ self.image_dir = os.path.join(root, 'images', f'{split}2017')
122
+ self.ann_dir = os.path.join(root, 'annotations', f'{split}2017')
123
+ self.images = os.listdir(self.image_dir)
124
+ self.transform = transform
125
+
126
+ def __getitem__(self, index):
127
+ img_path = os.path.join(self.image_dir, self.images[index])
128
+ img = Image.open(img_path).convert('RGB')
129
+ img = np.asarray(img)
130
+ idx = self.images[index].split('.')[0]
131
+ ann_path = os.path.join(self.ann_dir, f'{idx}_instanceTrainIds.png')
132
+ ann = np.asarray(Image.open(ann_path), dtype=np.int32)
133
+
134
+ return img, img_path, ann, idx
135
+
136
+ def __len__(self):
137
+ return len(self.images)
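
A hedged sketch of how the COCO object loader defined in `data/coco.py` might be consumed; the root path is a placeholder and assumes the `images/{split}2017` and `annotations/{split}2017` layout expected by the constructor:

```python
# Hypothetical usage sketch -- the COCO root path is a placeholder.
from data.coco import COCODataset, COCO_OBJECT_CLASSES  # assumed module path

dataset = COCODataset(root='/path/to/coco_object', split='val', transform=None)
img, img_path, ann, idx = dataset[0]
# `img` is an HxWx3 uint8 array, `ann` an HxW int32 map loaded from
# '<image id>_instanceTrainIds.png'; COCO_OBJECT_CLASSES holds the 80
# class-name prompts paired with this dataset.
print(img.shape, ann.shape, len(COCO_OBJECT_CLASSES))
```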
data/context.py ADDED
@@ -0,0 +1,126 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Pascal Context Dataset."""
17
+
18
+ from typing import Any, List, Tuple
19
+
20
+ import numpy as np
21
+ from PIL import Image
22
+ # pylint: disable=g-importing-member
23
+ from torchvision.datasets.voc import _VOCBase
24
+
25
+
26
+ PASCAL_CONTEXT_CLASSES = [
27
+ 'airplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', 'bird', 'boat',
28
+ 'book', 'bottle', 'building', 'bus', 'cabinet', 'car', 'cat', 'ceiling',
29
+ 'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain', 'dog', 'door',
30
+ 'fence', 'floor', 'flower', 'food', 'grass', 'ground', 'horse', 'keyboard',
31
+ 'light', 'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform',
32
+ 'plant', 'road', 'rock', 'sheep', 'shelves', 'sidewalk', 'sign', 'sky',
33
+ 'snow', 'sofa', 'table', 'track', 'train', 'tree', 'truck', 'monitor',
34
+ 'wall', 'water', 'window', 'wood']
35
+
36
+ PASCAL_CONTEXT_STUFF_CLASS = [
37
+ 'bedclothes', 'ceiling', 'cloth', 'curtain', 'floor', 'grass', 'ground',
38
+ 'light', 'mountain', 'platform', 'road', 'sidewalk', 'sky', 'snow', 'wall',
39
+ 'water', 'window', 'wood', 'door', 'fence', 'rock']
40
+
41
+ PASCAL_CONTEXT_THING_CLASS = [
42
+ 'airplane', 'bag', 'bed', 'bench', 'bicycle', 'bird', 'boat', 'book',
43
+ 'bottle', 'building', 'bus', 'cabinet', 'car', 'cat', 'chair', 'computer',
44
+ 'cow', 'cup', 'dog', 'flower', 'food', 'horse', 'keyboard', 'motorbike',
45
+ 'mouse', 'person', 'plate', 'plant', 'sheep', 'shelves', 'sign', 'sofa',
46
+ 'table', 'track', 'train', 'tree', 'truck', 'monitor']
47
+
48
+ PASCAL_CONTEXT_STUFF_CLASS_ID = [
49
+ 3, 15, 17, 21, 25, 28, 29, 32, 34, 38, 40, 44, 46, 47, 55, 56, 57, 58, 23,
50
+ 24, 41]
51
+
52
+ PASCAL_CONTEXT_THING_CLASS_ID = [
53
+ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 18, 19, 20, 22, 26, 27,
54
+ 30, 31, 33, 35, 36, 37, 39, 42, 43, 45, 48, 49, 50, 51, 52, 53, 54]
55
+
56
+
57
+ class CONTEXTSegmentation(_VOCBase):
58
+ """Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/> Segmentation Dataset.
59
+
60
+ Attributes:
61
+ root (string): Root directory of the VOC Dataset.
62
+ year (string, optional): The dataset year, supports years ``"2007"`` to
63
+ ``"2012"``.
64
+ image_set (string, optional): Select the image_set to use, ``"train"``,
65
+ ``"trainval"`` or ``"val"``. If ``year=="2007"``, can also be
66
+ ``"test"``.
67
+ download (bool, optional): If true, downloads the dataset from the
68
+ internet and puts it in root directory. If dataset is already
69
+ downloaded, it is not downloaded again.
70
+ transform (callable, optional): A function/transform that takes in an PIL
71
+ image and returns a transformed version. E.g, ``transforms.RandomCrop``
72
+ target_transform (callable, optional): A function/transform that takes in
73
+ the target and transforms it.
74
+ transforms (callable, optional): A function/transform that takes input
75
+ sample and its target as entry and returns a transformed version.
76
+ """
77
+
78
+ _SPLITS_DIR = 'SegmentationContext'
79
+ _TARGET_DIR = 'SegmentationClassContext'
80
+ _TARGET_FILE_EXT = '.png'
81
+
82
+ @property
83
+ def masks(self):
84
+ return self.targets
85
+
86
+ def __getitem__(self, index):
87
+ """Get a sample of image and segmentation.
88
+
89
+ Args:
90
+ index (int): Index
91
+ Returns:
92
+ tuple: (image, target) where target is the image segmentation.
93
+ """
94
+ img = Image.open(self.images[index]).convert('RGB')
95
+ target = Image.open(self.masks[index])
96
+
97
+ if self.transforms is not None:
98
+ img, target = self.transforms(img, target)
99
+
100
+ return img, target
101
+
102
+
103
+ class CONTEXTDataset(CONTEXTSegmentation):
104
+ """Pascal Context Dataset."""
105
+
106
+ def __init__(self, root, year='2012', split='val', transform=None):
107
+ super(CONTEXTDataset, self).__init__(
108
+ root=root,
109
+ image_set=split,
110
+ year=year,
111
+ transform=transform,
112
+ download=False,
113
+ )
114
+ # self.idx_to_class = {val: key for (key, val) in CLASS2ID.items()}
115
+
116
+ def __getitem__(self, index):
117
+ image_path = self.images[index]
118
+ image = Image.open(image_path).convert('RGB')
119
+ target = np.asarray(Image.open(self.masks[index]), dtype=np.int32)
120
+ # transpose the target width and height
121
+ # target = target.transpose(1, 0)
122
+
123
+ if self.transforms:
124
+ image = self.transform(image)
125
+
126
+ return image, str(image_path), target, index
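
A hedged sketch of how `CONTEXTDataset` and the stuff/thing id lists might be used together (the VOC root is a placeholder and must contain the `SegmentationContext` split files the class expects):

```python
# Hypothetical usage sketch -- the VOC root is a placeholder.
from data.context import (  # assumed module path
    CONTEXTDataset, PASCAL_CONTEXT_CLASSES, PASCAL_CONTEXT_STUFF_CLASS_ID)

dataset = CONTEXTDataset(root='/path/to/pascal_voc', year='2012', split='val')
image, image_path, target, index = dataset[0]
# `target` holds Pascal Context label ids; the id lists above partition the
# 59 class names into 21 stuff and 38 thing categories.
num_stuff = sum(
    i in PASCAL_CONTEXT_STUFF_CLASS_ID
    for i in range(len(PASCAL_CONTEXT_CLASSES)))
print(len(PASCAL_CONTEXT_CLASSES), num_stuff)  # 59, 21
```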
data/gres.py ADDED
@@ -0,0 +1,455 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """grefer v0.1.
17
+
18
+ This interface provides access to gRefCOCO.
19
+
20
+ The following API functions are defined:
21
+ G_REFER - REFER api class
22
+ getRefIds - get ref ids that satisfy given filter conditions.
23
+ getAnnIds - get ann ids that satisfy given filter conditions.
24
+ getImgIds - get image ids that satisfy given filter conditions.
25
+ getCatIds - get category ids that satisfy given filter conditions.
26
+ loadRefs - load refs with the specified ref ids.
27
+ loadAnns - load anns with the specified ann ids.
28
+ loadImgs - load images with the specified image ids.
29
+ loadCats - load category names with the specified category ids.
30
+ getRefBox - get ref's bounding box [x, y, w, h] given the ref_id
31
+ showRef - show image, segmentation or box of the referred object with the
32
+ ref
33
+ getMaskByRef - get mask and area of the referred object given ref or ref ids
34
+ getMask - get mask and area of the referred object given ref
35
+ showMask - show mask of the referred object given ref
36
+ """
37
+ # Adapted from
38
+ # https://github.com/yz93/LAVT-RIS/blob/main/data/dataset_refer_bert.py
39
+
40
+ # pylint: disable=all
41
+ import itertools
42
+ import json
43
+ import os
44
+ import os.path as osp
45
+ import pickle
46
+ import time
47
+ # pylint: disable=g-importing-member
48
+ from matplotlib.collections import PatchCollection
49
+ from matplotlib.patches import Polygon
50
+ from matplotlib.patches import Rectangle
51
+ import matplotlib.pyplot as plt
52
+ import numpy as np
53
+ from PIL import Image
54
+ from pycocotools import mask
55
+ from skimage import io
56
+ import torch
57
+ from torch.utils import data
58
+
59
+
60
+ class G_REFER:
61
+ """GRES dataset."""
62
+
63
+ def __init__(self, data_root, dataset='grefcoco', splitBy='unc'):
64
+ # provide data_root folder which contains grefcoco
65
+ print('loading dataset %s into memory...' % dataset)
66
+ self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
67
+ self.DATA_DIR = osp.join(data_root, dataset)
68
+ if dataset in ['grefcoco']:
69
+ self.IMAGE_DIR = osp.join(data_root, 'images/mscoco/images/train2014')
70
+ else:
71
+ raise KeyError('No refer dataset is called [%s]' % dataset)
72
+
73
+ tic = time.time()
74
+
75
+ # load refs from data/dataset/refs(dataset).json
76
+ self.data = {}
77
+ self.data['dataset'] = dataset
78
+
79
+ ref_file = osp.join(self.DATA_DIR, f'grefs({splitBy}).p')
80
+ if osp.exists(ref_file):
81
+ self.data['refs'] = pickle.load(open(ref_file, 'rb'), fix_imports=True)
82
+ else:
83
+ ref_file = osp.join(self.DATA_DIR, f'grefs({splitBy}).json')
84
+ if osp.exists(ref_file):
85
+ self.data['refs'] = json.load(open(ref_file, 'rb'))
86
+ else:
87
+ raise FileNotFoundError('JSON file not found')
88
+
89
+ # load annotations from data/dataset/instances.json
90
+ instances_file = osp.join(self.DATA_DIR, 'instances.json')
91
+ instances = json.load(open(instances_file, 'r'))
92
+ self.data['images'] = instances['images']
93
+ self.data['annotations'] = instances['annotations']
94
+ self.data['categories'] = instances['categories']
95
+
96
+ # create index
97
+ self.createIndex()
98
+ print('DONE (t=%.2fs)' % (time.time() - tic))
99
+
100
+ @staticmethod
101
+ def _toList(x):
102
+ return x if isinstance(x, list) else [x]
103
+
104
+ @staticmethod
105
+ def match_any(a, b):
106
+ a = a if isinstance(a, list) else [a]
107
+ b = b if isinstance(b, list) else [b]
108
+ return set(a) & set(b)
109
+
110
+ def createIndex(self):
111
+ # create sets of mapping
112
+ # 1) Refs: {ref_id: ref}
113
+ # 2) Anns: {ann_id: ann}
114
+ # 3) Imgs: {image_id: image}
115
+ # 4) Cats: {category_id: category_name}
116
+ # 5) Sents: {sent_id: sent}
117
+ # 6) imgToRefs: {image_id: refs}
118
+ # 7) imgToAnns: {image_id: anns}
119
+ # 8) refToAnn: {ref_id: ann}
120
+ # 9) annToRef: {ann_id: ref}
121
+ # 10) catToRefs: {category_id: refs}
122
+ # 11) sentToRef: {sent_id: ref}
123
+ # 12) sentToTokens: {sent_id: tokens}
124
+ print('creating index...')
125
+ # fetch info from instances
126
+ Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
127
+ Anns[-1] = None
128
+ for ann in self.data['annotations']:
129
+ Anns[ann['id']] = ann
130
+ imgToAnns[ann['image_id']] = imgToAnns.get(ann['image_id'], []) + [ann]
131
+ for img in self.data['images']:
132
+ Imgs[img['id']] = img
133
+ for cat in self.data['categories']:
134
+ Cats[cat['id']] = cat['name']
135
+
136
+ # fetch info from refs
137
+ Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
138
+ Sents, sentToRef, sentToTokens = {}, {}, {}
139
+ availableSplits = []
140
+ for ref in self.data['refs']:
141
+ # ids
142
+ ref_id = ref['ref_id']
143
+ ann_id = ref['ann_id']
144
+ category_id = ref['category_id']
145
+ image_id = ref['image_id']
146
+
147
+ if ref['split'] not in availableSplits:
148
+ availableSplits.append(ref['split'])
149
+
150
+ # add mapping related to ref
151
+ if ref_id in Refs:
152
+ print('Duplicate ref id')
153
+ Refs[ref_id] = ref
154
+ imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
155
+
156
+ category_id = self._toList(category_id)
157
+ added_cats = []
158
+ for cat in category_id:
159
+ if cat not in added_cats:
160
+ added_cats.append(cat)
161
+ catToRefs[cat] = catToRefs.get(cat, []) + [ref]
162
+
163
+ ann_id = self._toList(ann_id)
164
+ refToAnn[ref_id] = [Anns[ann] for ann in ann_id]
165
+ for ann_id_n in ann_id:
166
+ annToRef[ann_id_n] = annToRef.get(ann_id_n, []) + [ref]
167
+
168
+ # add mapping of sent
169
+ for sent in ref['sentences']:
170
+ Sents[sent['sent_id']] = sent
171
+ sentToRef[sent['sent_id']] = ref
172
+ sentToTokens[sent['sent_id']] = sent['tokens']
173
+
174
+ # create class members
175
+ self.Refs = Refs
176
+ self.Anns = Anns
177
+ self.Imgs = Imgs
178
+ self.Cats = Cats
179
+ self.Sents = Sents
180
+ self.imgToRefs = imgToRefs
181
+ self.imgToAnns = imgToAnns
182
+ self.refToAnn = refToAnn
183
+ self.annToRef = annToRef
184
+ self.catToRefs = catToRefs
185
+ self.sentToRef = sentToRef
186
+ self.sentToTokens = sentToTokens
187
+ self.availableSplits = availableSplits
188
+ print('index created.')
189
+
190
+ def getRefIds(self, image_ids=[], cat_ids=[], split=[]):
191
+ image_ids = self._toList(image_ids)
192
+ cat_ids = self._toList(cat_ids)
193
+ split = self._toList(split)
194
+
195
+ for s in split:
196
+ if s not in self.availableSplits:
197
+ raise ValueError(f'Invalid split name: {s}')
198
+
199
+ refs = self.data['refs']
200
+
201
+ if len(image_ids) > 0:
202
+ lists = [self.imgToRefs[image_id] for image_id in image_ids]
203
+ refs = list(itertools.chain.from_iterable(lists))
204
+ if len(cat_ids) > 0:
205
+ refs = [
206
+ ref for ref in refs if self.match_any(ref['category_id'], cat_ids)
207
+ ]
208
+ if len(split) > 0:
209
+ refs = [ref for ref in refs if ref['split'] in split]
210
+
211
+ ref_ids = [ref['ref_id'] for ref in refs]
212
+ return ref_ids
213
+
214
+ def getAnnIds(self, image_ids=[], ref_ids=[]):
215
+ image_ids = self._toList(image_ids)
216
+ ref_ids = self._toList(ref_ids)
217
+
218
+ if any([len(image_ids), len(ref_ids)]):
219
+ if len(image_ids) > 0:
220
+ lists = [
221
+ self.imgToAnns[image_id]
222
+ for image_id in image_ids
223
+ if image_id in self.imgToAnns
224
+ ]
225
+ anns = list(itertools.chain.from_iterable(lists))
226
+ else:
227
+ anns = self.data['annotations']
228
+ ann_ids = [ann['id'] for ann in anns]
229
+ if len(ref_ids) > 0:
230
+ lists = [self.Refs[ref_id]['ann_id'] for ref_id in ref_ids]
231
+ anns_by_ref_id = list(itertools.chain.from_iterable(lists))
232
+ ann_ids = list(set(ann_ids).intersection(set(anns_by_ref_id)))
233
+ else:
234
+ ann_ids = [ann['id'] for ann in self.data['annotations']]
235
+
236
+ return ann_ids
237
+
238
+ def getImgIds(self, ref_ids=[]):
239
+ ref_ids = self._toList(ref_ids)
240
+
241
+ if len(ref_ids) > 0:
242
+ image_ids = list(
243
+ set([self.Refs[ref_id]['image_id'] for ref_id in ref_ids])
244
+ )
245
+ else:
246
+ image_ids = self.Imgs.keys()
247
+ return image_ids
248
+
249
+ def getCatIds(self):
250
+ return self.Cats.keys()
251
+
252
+ def loadRefs(self, ref_ids=[]):
253
+ return [self.Refs[ref_id] for ref_id in self._toList(ref_ids)]
254
+
255
+ def loadAnns(self, ann_ids=[]):
256
+ if isinstance(ann_ids, str):
257
+ ann_ids = int(ann_ids)
258
+ return [self.Anns[ann_id] for ann_id in self._toList(ann_ids)]
259
+
260
+ def loadImgs(self, image_ids=[]):
261
+ return [self.Imgs[image_id] for image_id in self._toList(image_ids)]
262
+
263
+ def loadCats(self, cat_ids=[]):
264
+ return [self.Cats[cat_id] for cat_id in self._toList(cat_ids)]
265
+
266
+ def getRefBox(self, ref_id):
267
+ anns = self.refToAnn[ref_id]
268
+ return [ann['bbox'] for ann in anns] # [x, y, w, h]
269
+
270
+ def showRef(self, ref, seg_box='seg'):
271
+ ax = plt.gca()
272
+ # show image
273
+ image = self.Imgs[ref['image_id']]
274
+ I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
275
+ ax.imshow(I)
276
+ # show refer expression
277
+ for sid, sent in enumerate(ref['sentences']):
278
+ print('%s. %s' % (sid + 1, sent['sent']))
279
+ # show segmentations
280
+ if seg_box == 'seg':
281
+ ann_id = ref['ann_id']
282
+ ann = self.Anns[ann_id]
283
+ polygons = []
284
+ color = []
285
+ c = 'none'
286
+ if type(ann['segmentation'][0]) == list:
287
+ # polygon used for refcoco*
288
+ for seg in ann['segmentation']:
289
+ poly = np.array(seg).reshape((len(seg) // 2, 2))
290
+ polygons.append(Polygon(poly, True, alpha=0.4))
291
+ color.append(c)
292
+ p = PatchCollection(
293
+ polygons,
294
+ facecolors=color,
295
+ edgecolors=(1, 1, 0, 0),
296
+ linewidths=3,
297
+ alpha=1,
298
+ )
299
+ ax.add_collection(p) # thick yellow polygon
300
+ p = PatchCollection(
301
+ polygons,
302
+ facecolors=color,
303
+ edgecolors=(1, 0, 0, 0),
304
+ linewidths=1,
305
+ alpha=1,
306
+ )
307
+ ax.add_collection(p) # thin red polygon
308
+ else:
309
+ # mask used for refclef
310
+ rle = ann['segmentation']
311
+ m = mask.decode(rle)
312
+ img = np.ones((m.shape[0], m.shape[1], 3))
313
+ color_mask = np.array([2.0, 166.0, 101.0]) / 255
314
+ for i in range(3):
315
+ img[:, :, i] = color_mask[i]
316
+ ax.imshow(np.dstack((img, m * 0.5)))
317
+ # show bounding-box
318
+ elif seg_box == 'box':
319
+ # ann_id = ref['ann_id']
320
+ # ann = self.Anns[ann_id]
321
+ bbox = self.getRefBox(ref['ref_id'])
322
+ box_plot = Rectangle(
323
+ (bbox[0], bbox[1]),
324
+ bbox[2],
325
+ bbox[3],
326
+ fill=False,
327
+ edgecolor='green',
328
+ linewidth=3,
329
+ )
330
+ ax.add_patch(box_plot)
331
+
332
+ def getMask(self, ann):
333
+ if not ann:
334
+ return None
335
+ if ann['iscrowd']:
336
+ raise ValueError('Crowd object')
337
+ image = self.Imgs[ann['image_id']]
338
+ if type(ann['segmentation'][0]) == list: # polygon
339
+ rle = mask.frPyObjects(
340
+ ann['segmentation'], image['height'], image['width']
341
+ )
342
+ else:
343
+ rle = ann['segmentation']
344
+
345
+ m = mask.decode(rle)
346
+ # sometimes there are multiple binary map (corresponding to multiple segs)
347
+ m = np.sum(m, axis=2)
348
+ m = m.astype(np.uint8) # convert to np.uint8
349
+ # compute area
350
+ area = sum(mask.area(rle)) # should be close to ann['area']
351
+ return {'mask': m, 'area': area}
352
+
353
+ def getMaskByRef(self, ref=None, ref_id=None, merge=False):
354
+ if not ref and not ref_id:
355
+ raise ValueError
356
+ if ref:
357
+ ann_ids = ref['ann_id']
358
+ ref_id = ref['ref_id']
359
+ else:
360
+ ann_ids = self.getAnnIds(ref_ids=ref_id)
361
+
362
+ if ann_ids == [-1]:
363
+ img = self.Imgs[self.Refs[ref_id]['image_id']]
364
+ return {
365
+ 'mask': np.zeros([img['height'], img['width']], dtype=np.uint8),
366
+ 'empty': True,
367
+ }
368
+
369
+ anns = self.loadAnns(ann_ids)
370
+ mask_list = [self.getMask(ann) for ann in anns if not ann['iscrowd']]
371
+
372
+ if merge:
373
+ merged_masks = sum([mask['mask'] for mask in mask_list])
374
+ merged_masks[np.where(merged_masks > 1)] = 1
375
+ return {'mask': merged_masks, 'empty': False}
376
+ else:
377
+ return mask_list
378
+
379
+ def showMask(self, ref):
380
+ M = self.getMask(ref)
381
+ msk = M['mask']
382
+ ax = plt.gca()
383
+ ax.imshow(msk)
384
+
385
+
386
+ class GReferDataset(data.Dataset):
387
+
388
+ def __init__(self, root, transform=None, split='val'):
389
+
390
+ self.classes = []
391
+ self.image_transforms = transform
392
+ self.split = split
393
+ self.refer = G_REFER(root)
394
+
395
+ ref_ids = self.refer.getRefIds(split=self.split)
396
+ img_ids = self.refer.getImgIds(ref_ids)
397
+
398
+ all_imgs = self.refer.Imgs
399
+ self.imgs = list(all_imgs[i] for i in img_ids)
400
+ self.ref_ids = []
401
+ # print(len(ref_ids))
402
+ # print(len(self.imgs))
403
+ self.sentence_raw = []
404
+ # if we are testing on a dataset, test all sentences of an object;
405
+ # o/w, we are validating during training, randomly sample one sentence
406
+ # for efficiency
407
+ for r in ref_ids:
408
+ ref = self.refer.Refs[r]
409
+ # ref_sentences = []
410
+ # for i, (el, sent_id) in enumerate(zip(ref['sentences'],
411
+ # ref['sent_ids'])):
412
+ for el in ref['sentences']:
413
+ sentence_raw = el['raw']
414
+ if len(sentence_raw) == 0:
415
+ continue
416
+ self.sentence_raw.append(sentence_raw)
417
+ self.ref_ids.append(r)
418
+
419
+ # print(len(self.sentence_raw))
420
+
421
+ def get_classes(self):
422
+ return self.classes
423
+
424
+ def __len__(self):
425
+ return len(self.ref_ids)
426
+
427
+ def __getitem__(self, index):
428
+ this_ref_id = self.ref_ids[index]
429
+ this_img_id = self.refer.getImgIds(this_ref_id)
430
+ this_img = self.refer.Imgs[this_img_id[0]]
431
+ # print(this_ref_id, this_img_id)
432
+ # print(len(self.ref_ids))
433
+ img_path = os.path.join(self.refer.IMAGE_DIR, this_img['file_name'])
434
+ img = Image.open(img_path).convert('RGB')
435
+ ref = self.refer.loadRefs(this_ref_id)
436
+ # print("ref",ref)
437
+
438
+ ref_mask_ann = self.refer.getMaskByRef(ref[0])
439
+ if type(ref_mask_ann) == list:
440
+ ref_mask_ann = ref_mask_ann[0]
441
+ ref_mask = ref_mask_ann['mask']
442
+ annot = np.zeros(ref_mask.shape)
443
+ annot[ref_mask == 1] = 1
444
+
445
+ target = Image.fromarray(annot.astype(np.uint8), mode='P')
446
+ # print(np.array(target), np.unique(np.array(target).flatten()))
447
+ if self.image_transforms is not None:
448
+ # resize, from PIL to tensor, and mean and std normalization
449
+ img = self.image_transforms(img)
450
+ # target = self.target_transforms(target)
451
+ target = torch.as_tensor(np.array(target, copy=True))
452
+ # target = target.permute((2, 0, 1))
453
+ sentence = self.sentence_raw[index]
454
+
455
+ return img, img_path, target, sentence
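
For context, here is a minimal usage sketch of the `GReferDataset` defined above. It is a hedged example, not part of the commit: the data root `/datasets/grefcoco` and the 512x512 resize are assumptions, and each item yields the image tensor, its file path, the binary referring mask, and one raw sentence.

```python
# Hedged usage sketch of GReferDataset; the data root below is an assumption.
import torch
from torchvision import transforms
from data.gres import GReferDataset

transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])

dataset = GReferDataset(root='/datasets/grefcoco', transform=transform, split='val')
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)

# Each batch holds (image tensor, image path, target mask, referring sentence).
for img, img_path, target, sentence in loader:
    print(img.shape, target.shape, sentence)
    break
```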
data/pascal459.py ADDED
@@ -0,0 +1,998 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Pascal-459 Dataset."""
17
+
18
+ import os
19
+ import numpy as np
20
+ from PIL import Image
21
+ # pylint: disable=g-importing-member
22
+ from torch.utils.data import Dataset
23
+
24
+
25
+ PASCAL_459_CLASSES = [
26
+ 'accordion',
27
+ 'aeroplane',
28
+ 'air conditioner',
29
+ 'antenna',
30
+ 'artillery',
31
+ 'ashtray',
32
+ 'atrium',
33
+ 'baby carriage',
34
+ 'bag',
35
+ 'ball',
36
+ 'balloon',
37
+ 'bamboo weaving',
38
+ 'barrel',
39
+ 'baseball bat',
40
+ 'basket',
41
+ 'basketball backboard',
42
+ 'bathtub',
43
+ 'bed',
44
+ 'bedclothes',
45
+ 'beer',
46
+ 'bell',
47
+ 'bench',
48
+ 'bicycle',
49
+ 'binoculars',
50
+ 'bird',
51
+ 'bird cage',
52
+ 'bird feeder',
53
+ 'bird nest',
54
+ 'blackboard',
55
+ 'board',
56
+ 'boat',
57
+ 'bone',
58
+ 'book',
59
+ 'bottle',
60
+ 'bottle opener',
61
+ 'bowl',
62
+ 'box',
63
+ 'bracelet',
64
+ 'brick',
65
+ 'bridge',
66
+ 'broom',
67
+ 'brush',
68
+ 'bucket',
69
+ 'building',
70
+ 'bus',
71
+ 'cabinet',
72
+ 'cabinet door',
73
+ 'cage',
74
+ 'cake',
75
+ 'calculator',
76
+ 'calendar',
77
+ 'camel',
78
+ 'camera',
79
+ 'camera lens',
80
+ 'can',
81
+ 'candle',
82
+ 'candle holder',
83
+ 'cap',
84
+ 'car',
85
+ 'card',
86
+ 'cart',
87
+ 'case',
88
+ 'casette recorder',
89
+ 'cash register',
90
+ 'cat',
91
+ 'cd',
92
+ 'cd player',
93
+ 'ceiling',
94
+ 'cell phone',
95
+ 'cello',
96
+ 'chain',
97
+ 'chair',
98
+ 'chessboard',
99
+ 'chicken',
100
+ 'chopstick',
101
+ 'clip',
102
+ 'clippers',
103
+ 'clock',
104
+ 'closet',
105
+ 'cloth',
106
+ 'clothes tree',
107
+ 'coffee',
108
+ 'coffee machine',
109
+ 'comb',
110
+ 'computer',
111
+ 'concrete',
112
+ 'cone',
113
+ 'container',
114
+ 'control booth',
115
+ 'controller',
116
+ 'cooker',
117
+ 'copying machine',
118
+ 'coral',
119
+ 'cork',
120
+ 'corkscrew',
121
+ 'counter',
122
+ 'court',
123
+ 'cow',
124
+ 'crabstick',
125
+ 'crane',
126
+ 'crate',
127
+ 'cross',
128
+ 'crutch',
129
+ 'cup',
130
+ 'curtain',
131
+ 'cushion',
132
+ 'cutting board',
133
+ 'dais',
134
+ 'disc',
135
+ 'disc case',
136
+ 'dishwasher',
137
+ 'dock',
138
+ 'dog',
139
+ 'dolphin',
140
+ 'door',
141
+ 'drainer',
142
+ 'dray',
143
+ 'drink dispenser',
144
+ 'drinking machine',
145
+ 'drop',
146
+ 'drug',
147
+ 'drum',
148
+ 'drum kit',
149
+ 'duck',
150
+ 'dumbbell',
151
+ 'earphone',
152
+ 'earrings',
153
+ 'egg',
154
+ 'electric fan',
155
+ 'electric iron',
156
+ 'electric pot',
157
+ 'electric saw',
158
+ 'electronic keyboard',
159
+ 'engine',
160
+ 'envelope',
161
+ 'equipment',
162
+ 'escalator',
163
+ 'exhibition booth',
164
+ 'extinguisher',
165
+ 'eyeglass',
166
+ 'fan',
167
+ 'faucet',
168
+ 'fax machine',
169
+ 'fence',
170
+ 'ferris wheel',
171
+ 'fire extinguisher',
172
+ 'fire hydrant',
173
+ 'fire place',
174
+ 'fish',
175
+ 'fish tank',
176
+ 'fishbowl',
177
+ 'fishing net',
178
+ 'fishing pole',
179
+ 'flag',
180
+ 'flagstaff',
181
+ 'flame',
182
+ 'flashlight',
183
+ 'floor',
184
+ 'flower',
185
+ 'fly',
186
+ 'foam',
187
+ 'food',
188
+ 'footbridge',
189
+ 'forceps',
190
+ 'fork',
191
+ 'forklift',
192
+ 'fountain',
193
+ 'fox',
194
+ 'frame',
195
+ 'fridge',
196
+ 'frog',
197
+ 'fruit',
198
+ 'funnel',
199
+ 'furnace',
200
+ 'game controller',
201
+ 'game machine',
202
+ 'gas cylinder',
203
+ 'gas hood',
204
+ 'gas stove',
205
+ 'gift box',
206
+ 'glass',
207
+ 'glass marble',
208
+ 'globe',
209
+ 'glove',
210
+ 'goal',
211
+ 'grandstand',
212
+ 'grass',
213
+ 'gravestone',
214
+ 'ground',
215
+ 'guardrail',
216
+ 'guitar',
217
+ 'gun',
218
+ 'hammer',
219
+ 'hand cart',
220
+ 'handle',
221
+ 'handrail',
222
+ 'hanger',
223
+ 'hard disk drive',
224
+ 'hat',
225
+ 'hay',
226
+ 'headphone',
227
+ 'heater',
228
+ 'helicopter',
229
+ 'helmet',
230
+ 'holder',
231
+ 'hook',
232
+ 'horse',
233
+ 'horse-drawn carriage',
234
+ 'hot-air balloon',
235
+ 'hydrovalve',
236
+ 'ice',
237
+ 'inflator pump',
238
+ 'ipod',
239
+ 'iron',
240
+ 'ironing board',
241
+ 'jar',
242
+ 'kart',
243
+ 'kettle',
244
+ 'key',
245
+ 'keyboard',
246
+ 'kitchen range',
247
+ 'kite',
248
+ 'knife',
249
+ 'knife block',
250
+ 'ladder',
251
+ 'ladder truck',
252
+ 'ladle',
253
+ 'laptop',
254
+ 'leaves',
255
+ 'lid',
256
+ 'life buoy',
257
+ 'light',
258
+ 'light bulb',
259
+ 'lighter',
260
+ 'line',
261
+ 'lion',
262
+ 'lobster',
263
+ 'lock',
264
+ 'machine',
265
+ 'mailbox',
266
+ 'mannequin',
267
+ 'map',
268
+ 'mask',
269
+ 'mat',
270
+ 'match book',
271
+ 'mattress',
272
+ 'menu',
273
+ 'metal',
274
+ 'meter box',
275
+ 'microphone',
276
+ 'microwave',
277
+ 'mirror',
278
+ 'missile',
279
+ 'model',
280
+ 'money',
281
+ 'monkey',
282
+ 'mop',
283
+ 'motorbike',
284
+ 'mountain',
285
+ 'mouse',
286
+ 'mouse pad',
287
+ 'musical instrument',
288
+ 'napkin',
289
+ 'net',
290
+ 'newspaper',
291
+ 'oar',
292
+ 'ornament',
293
+ 'outlet',
294
+ 'oven',
295
+ 'oxygen bottle',
296
+ 'pack',
297
+ 'pan',
298
+ 'paper',
299
+ 'paper box',
300
+ 'paper cutter',
301
+ 'parachute',
302
+ 'parasol',
303
+ 'parterre',
304
+ 'patio',
305
+ 'pelage',
306
+ 'pen',
307
+ 'pen container',
308
+ 'pencil',
309
+ 'person',
310
+ 'photo',
311
+ 'piano',
312
+ 'picture',
313
+ 'pig',
314
+ 'pillar',
315
+ 'pillow',
316
+ 'pipe',
317
+ 'pitcher',
318
+ 'plant',
319
+ 'plastic',
320
+ 'plate',
321
+ 'platform',
322
+ 'player',
323
+ 'playground',
324
+ 'pliers',
325
+ 'plume',
326
+ 'poker',
327
+ 'poker chip',
328
+ 'pole',
329
+ 'pool table',
330
+ 'postcard',
331
+ 'poster',
332
+ 'pot',
333
+ 'pottedplant',
334
+ 'printer',
335
+ 'projector',
336
+ 'pumpkin',
337
+ 'rabbit',
338
+ 'racket',
339
+ 'radiator',
340
+ 'radio',
341
+ 'rail',
342
+ 'rake',
343
+ 'ramp',
344
+ 'range hood',
345
+ 'receiver',
346
+ 'recorder',
347
+ 'recreational machines',
348
+ 'remote control',
349
+ 'road',
350
+ 'robot',
351
+ 'rock',
352
+ 'rocket',
353
+ 'rocking horse',
354
+ 'rope',
355
+ 'rug',
356
+ 'ruler',
357
+ 'runway',
358
+ 'saddle',
359
+ 'sand',
360
+ 'saw',
361
+ 'scale',
362
+ 'scanner',
363
+ 'scissors',
364
+ 'scoop',
365
+ 'screen',
366
+ 'screwdriver',
367
+ 'sculpture',
368
+ 'scythe',
369
+ 'sewer',
370
+ 'sewing machine',
371
+ 'shed',
372
+ 'sheep',
373
+ 'shell',
374
+ 'shelves',
375
+ 'shoe',
376
+ 'shopping cart',
377
+ 'shovel',
378
+ 'sidecar',
379
+ 'sidewalk',
380
+ 'sign',
381
+ 'signal light',
382
+ 'sink',
383
+ 'skateboard',
384
+ 'ski',
385
+ 'sky',
386
+ 'sled',
387
+ 'slippers',
388
+ 'smoke',
389
+ 'snail',
390
+ 'snake',
391
+ 'snow',
392
+ 'snowmobiles',
393
+ 'sofa',
394
+ 'spanner',
395
+ 'spatula',
396
+ 'speaker',
397
+ 'speed bump',
398
+ 'spice container',
399
+ 'spoon',
400
+ 'sprayer',
401
+ 'squirrel',
402
+ 'stage',
403
+ 'stair',
404
+ 'stapler',
405
+ 'stick',
406
+ 'sticky note',
407
+ 'stone',
408
+ 'stool',
409
+ 'stove',
410
+ 'straw',
411
+ 'stretcher',
412
+ 'sun',
413
+ 'sunglass',
414
+ 'sunshade',
415
+ 'surveillance camera',
416
+ 'swan',
417
+ 'sweeper',
418
+ 'swim ring',
419
+ 'swimming pool',
420
+ 'swing',
421
+ 'switch',
422
+ 'table',
423
+ 'tableware',
424
+ 'tank',
425
+ 'tap',
426
+ 'tape',
427
+ 'tarp',
428
+ 'telephone',
429
+ 'telephone booth',
430
+ 'tent',
431
+ 'tire',
432
+ 'toaster',
433
+ 'toilet',
434
+ 'tong',
435
+ 'tool',
436
+ 'toothbrush',
437
+ 'towel',
438
+ 'toy',
439
+ 'toy car',
440
+ 'track',
441
+ 'train',
442
+ 'trampoline',
443
+ 'trash bin',
444
+ 'tray',
445
+ 'tree',
446
+ 'tricycle',
447
+ 'tripod',
448
+ 'trophy',
449
+ 'truck',
450
+ 'tube',
451
+ 'turtle',
452
+ 'tvmonitor',
453
+ 'tweezers',
454
+ 'typewriter',
455
+ 'umbrella',
456
+ 'unknown',
457
+ 'vacuum cleaner',
458
+ 'vending machine',
459
+ 'video camera',
460
+ 'video game console',
461
+ 'video player',
462
+ 'video tape',
463
+ 'violin',
464
+ 'wakeboard',
465
+ 'wall',
466
+ 'wallet',
467
+ 'wardrobe',
468
+ 'washing machine',
469
+ 'watch',
470
+ 'water',
471
+ 'water dispenser',
472
+ 'water pipe',
473
+ 'water skate board',
474
+ 'watermelon',
475
+ 'whale',
476
+ 'wharf',
477
+ 'wheel',
478
+ 'wheelchair',
479
+ 'window',
480
+ 'window blinds',
481
+ 'wineglass',
482
+ 'wire',
483
+ 'wood',
484
+ 'wool',
485
+ ]
486
+
487
+ PASCAL_459_CLASSE_ID = list(range(459))
488
+
489
+
490
+ PASCAL_459_STUFF_CLASS = [
491
+ 'atrium',
492
+ 'ceiling',
493
+ 'concrete',
494
+ 'coral',
495
+ 'court',
496
+ 'dock',
497
+ 'floor',
498
+ 'foam',
499
+ 'grass',
500
+ 'ground',
501
+ 'ice',
502
+ 'leaves',
503
+ 'mountain',
504
+ 'parterre',
505
+ 'patio',
506
+ 'road',
507
+ 'rock',
508
+ 'rug',
509
+ 'sand',
510
+ 'sky',
511
+ 'snow',
512
+ 'stone',
513
+ 'sun',
514
+ 'wall',
515
+ 'water',
516
+ 'wood',
517
+ ]
518
+
519
+ PASCAL_459_THING_CLASS = [
520
+ 'accordion',
521
+ 'aeroplane',
522
+ 'air conditioner',
523
+ 'antenna',
524
+ 'artillery',
525
+ 'ashtray',
526
+ 'baby carriage',
527
+ 'bag',
528
+ 'ball',
529
+ 'balloon',
530
+ 'bamboo weaving',
531
+ 'barrel',
532
+ 'baseball bat',
533
+ 'basket',
534
+ 'basketball backboard',
535
+ 'bathtub',
536
+ 'bed',
537
+ 'bedclothes',
538
+ 'beer',
539
+ 'bell',
540
+ 'bench',
541
+ 'bicycle',
542
+ 'binoculars',
543
+ 'bird',
544
+ 'bird cage',
545
+ 'bird feeder',
546
+ 'bird nest',
547
+ 'blackboard',
548
+ 'board',
549
+ 'boat',
550
+ 'bone',
551
+ 'book',
552
+ 'bottle',
553
+ 'bottle opener',
554
+ 'bowl',
555
+ 'box',
556
+ 'bracelet',
557
+ 'brick',
558
+ 'bridge',
559
+ 'broom',
560
+ 'brush',
561
+ 'bucket',
562
+ 'building',
563
+ 'bus',
564
+ 'cabinet',
565
+ 'cabinet door',
566
+ 'cage',
567
+ 'cake',
568
+ 'calculator',
569
+ 'calendar',
570
+ 'camel',
571
+ 'camera',
572
+ 'camera lens',
573
+ 'can',
574
+ 'candle',
575
+ 'candle holder',
576
+ 'cap',
577
+ 'car',
578
+ 'card',
579
+ 'cart',
580
+ 'case',
581
+ 'casette recorder',
582
+ 'cash register',
583
+ 'cat',
584
+ 'cd',
585
+ 'cd player',
586
+ 'cell phone',
587
+ 'cello',
588
+ 'chain',
589
+ 'chair',
590
+ 'chessboard',
591
+ 'chicken',
592
+ 'chopstick',
593
+ 'clip',
594
+ 'clippers',
595
+ 'clock',
596
+ 'closet',
597
+ 'cloth',
598
+ 'clothes tree',
599
+ 'coffee',
600
+ 'coffee machine',
601
+ 'comb',
602
+ 'computer',
603
+ 'cone',
604
+ 'container',
605
+ 'control booth',
606
+ 'controller',
607
+ 'cooker',
608
+ 'copying machine',
609
+ 'cork',
610
+ 'corkscrew',
611
+ 'counter',
612
+ 'cow',
613
+ 'crabstick',
614
+ 'crane',
615
+ 'crate',
616
+ 'cross',
617
+ 'crutch',
618
+ 'cup',
619
+ 'curtain',
620
+ 'cushion',
621
+ 'cutting board',
622
+ 'dais',
623
+ 'disc',
624
+ 'disc case',
625
+ 'dishwasher',
626
+ 'dog',
627
+ 'dolphin',
628
+ 'door',
629
+ 'drainer',
630
+ 'dray',
631
+ 'drink dispenser',
632
+ 'drinking machine',
633
+ 'drop',
634
+ 'drug',
635
+ 'drum',
636
+ 'drum kit',
637
+ 'duck',
638
+ 'dumbbell',
639
+ 'earphone',
640
+ 'earrings',
641
+ 'egg',
642
+ 'electric fan',
643
+ 'electric iron',
644
+ 'electric pot',
645
+ 'electric saw',
646
+ 'electronic keyboard',
647
+ 'engine',
648
+ 'envelope',
649
+ 'equipment',
650
+ 'escalator',
651
+ 'exhibition booth',
652
+ 'extinguisher',
653
+ 'eyeglass',
654
+ 'fan',
655
+ 'faucet',
656
+ 'fax machine',
657
+ 'fence',
658
+ 'ferris wheel',
659
+ 'fire extinguisher',
660
+ 'fire hydrant',
661
+ 'fire place',
662
+ 'fish',
663
+ 'fish tank',
664
+ 'fishbowl',
665
+ 'fishing net',
666
+ 'fishing pole',
667
+ 'flag',
668
+ 'flagstaff',
669
+ 'flame',
670
+ 'flashlight',
671
+ 'flower',
672
+ 'fly',
673
+ 'food',
674
+ 'footbridge',
675
+ 'forceps',
676
+ 'fork',
677
+ 'forklift',
678
+ 'fountain',
679
+ 'fox',
680
+ 'frame',
681
+ 'fridge',
682
+ 'frog',
683
+ 'fruit',
684
+ 'funnel',
685
+ 'furnace',
686
+ 'game controller',
687
+ 'game machine',
688
+ 'gas cylinder',
689
+ 'gas hood',
690
+ 'gas stove',
691
+ 'gift box',
692
+ 'glass',
693
+ 'glass marble',
694
+ 'globe',
695
+ 'glove',
696
+ 'goal',
697
+ 'grandstand',
698
+ 'gravestone',
699
+ 'guardrail',
700
+ 'guitar',
701
+ 'gun',
702
+ 'hammer',
703
+ 'hand cart',
704
+ 'handle',
705
+ 'handrail',
706
+ 'hanger',
707
+ 'hard disk drive',
708
+ 'hat',
709
+ 'hay',
710
+ 'headphone',
711
+ 'heater',
712
+ 'helicopter',
713
+ 'helmet',
714
+ 'holder',
715
+ 'hook',
716
+ 'horse',
717
+ 'horse-drawn carriage',
718
+ 'hot-air balloon',
719
+ 'hydrovalve',
720
+ 'inflator pump',
721
+ 'ipod',
722
+ 'iron',
723
+ 'ironing board',
724
+ 'jar',
725
+ 'kart',
726
+ 'kettle',
727
+ 'key',
728
+ 'keyboard',
729
+ 'kitchen range',
730
+ 'kite',
731
+ 'knife',
732
+ 'knife block',
733
+ 'ladder',
734
+ 'ladder truck',
735
+ 'ladle',
736
+ 'laptop',
737
+ 'lid',
738
+ 'life buoy',
739
+ 'light',
740
+ 'light bulb',
741
+ 'lighter',
742
+ 'line',
743
+ 'lion',
744
+ 'lobster',
745
+ 'lock',
746
+ 'machine',
747
+ 'mailbox',
748
+ 'mannequin',
749
+ 'map',
750
+ 'mask',
751
+ 'mat',
752
+ 'match book',
753
+ 'mattress',
754
+ 'menu',
755
+ 'metal',
756
+ 'meter box',
757
+ 'microphone',
758
+ 'microwave',
759
+ 'mirror',
760
+ 'missile',
761
+ 'model',
762
+ 'money',
763
+ 'monkey',
764
+ 'mop',
765
+ 'motorbike',
766
+ 'mouse',
767
+ 'mouse pad',
768
+ 'musical instrument',
769
+ 'napkin',
770
+ 'net',
771
+ 'newspaper',
772
+ 'oar',
773
+ 'ornament',
774
+ 'outlet',
775
+ 'oven',
776
+ 'oxygen bottle',
777
+ 'pack',
778
+ 'pan',
779
+ 'paper',
780
+ 'paper box',
781
+ 'paper cutter',
782
+ 'parachute',
783
+ 'parasol',
784
+ 'pelage',
785
+ 'pen',
786
+ 'pen container',
787
+ 'pencil',
788
+ 'person',
789
+ 'photo',
790
+ 'piano',
791
+ 'picture',
792
+ 'pig',
793
+ 'pillar',
794
+ 'pillow',
795
+ 'pipe',
796
+ 'pitcher',
797
+ 'plant',
798
+ 'plastic',
799
+ 'plate',
800
+ 'platform',
801
+ 'player',
802
+ 'playground',
803
+ 'pliers',
804
+ 'plume',
805
+ 'poker',
806
+ 'poker chip',
807
+ 'pole',
808
+ 'pool table',
809
+ 'postcard',
810
+ 'poster',
811
+ 'pot',
812
+ 'pottedplant',
813
+ 'printer',
814
+ 'projector',
815
+ 'pumpkin',
816
+ 'rabbit',
817
+ 'racket',
818
+ 'radiator',
819
+ 'radio',
820
+ 'rail',
821
+ 'rake',
822
+ 'ramp',
823
+ 'range hood',
824
+ 'receiver',
825
+ 'recorder',
826
+ 'recreational machines',
827
+ 'remote control',
828
+ 'robot',
829
+ 'rocket',
830
+ 'rocking horse',
831
+ 'rope',
832
+ 'ruler',
833
+ 'runway',
834
+ 'saddle',
835
+ 'saw',
836
+ 'scale',
837
+ 'scanner',
838
+ 'scissors',
839
+ 'scoop',
840
+ 'screen',
841
+ 'screwdriver',
842
+ 'sculpture',
843
+ 'scythe',
844
+ 'sewer',
845
+ 'sewing machine',
846
+ 'shed',
847
+ 'sheep',
848
+ 'shell',
849
+ 'shelves',
850
+ 'shoe',
851
+ 'shopping cart',
852
+ 'shovel',
853
+ 'sidecar',
854
+ 'sidewalk',
855
+ 'sign',
856
+ 'signal light',
857
+ 'sink',
858
+ 'skateboard',
859
+ 'ski',
860
+ 'sled',
861
+ 'slippers',
862
+ 'smoke',
863
+ 'snail',
864
+ 'snake',
865
+ 'snowmobiles',
866
+ 'sofa',
867
+ 'spanner',
868
+ 'spatula',
869
+ 'speaker',
870
+ 'speed bump',
871
+ 'spice container',
872
+ 'spoon',
873
+ 'sprayer',
874
+ 'squirrel',
875
+ 'stage',
876
+ 'stair',
877
+ 'stapler',
878
+ 'stick',
879
+ 'sticky note',
880
+ 'stool',
881
+ 'stove',
882
+ 'straw',
883
+ 'stretcher',
884
+ 'sunglass',
885
+ 'sunshade',
886
+ 'surveillance camera',
887
+ 'swan',
888
+ 'sweeper',
889
+ 'swim ring',
890
+ 'swimming pool',
891
+ 'swing',
892
+ 'switch',
893
+ 'table',
894
+ 'tableware',
895
+ 'tank',
896
+ 'tap',
897
+ 'tape',
898
+ 'tarp',
899
+ 'telephone',
900
+ 'telephone booth',
901
+ 'tent',
902
+ 'tire',
903
+ 'toaster',
904
+ 'toilet',
905
+ 'tong',
906
+ 'tool',
907
+ 'toothbrush',
908
+ 'towel',
909
+ 'toy',
910
+ 'toy car',
911
+ 'track',
912
+ 'train',
913
+ 'trampoline',
914
+ 'trash bin',
915
+ 'tray',
916
+ 'tree',
917
+ 'tricycle',
918
+ 'tripod',
919
+ 'trophy',
920
+ 'truck',
921
+ 'tube',
922
+ 'turtle',
923
+ 'tvmonitor',
924
+ 'tweezers',
925
+ 'typewriter',
926
+ 'umbrella',
927
+ 'unknown',
928
+ 'vacuum cleaner',
929
+ 'vending machine',
930
+ 'video camera',
931
+ 'video game console',
932
+ 'video player',
933
+ 'video tape',
934
+ 'violin',
935
+ 'wakeboard',
936
+ 'wallet',
937
+ 'wardrobe',
938
+ 'washing machine',
939
+ 'watch',
940
+ 'water dispenser',
941
+ 'water pipe',
942
+ 'water skate board',
943
+ 'watermelon',
944
+ 'whale',
945
+ 'wharf',
946
+ 'wheel',
947
+ 'wheelchair',
948
+ 'window',
949
+ 'window blinds',
950
+ 'wineglass',
951
+ 'wire',
952
+ 'wool',
953
+ ]
954
+
955
+ PASCAL_459_STUFF_CLASS_ID = [
956
+ 6, 67, 85, 92, 96, 111, 157, 160, 186, 188, 210, 228, 258, 277, 278, 323,
957
+ 325, 329, 333, 359, 365, 381, 386, 439, 444, 457,
958
+ ]
959
+
960
+ PASCAL_459_THING_CLASS_ID = [
961
+ i for i in range(459) if i not in PASCAL_459_STUFF_CLASS_ID
962
+ ]
963
+
964
+
965
+ class Pascal459Dataset(Dataset):
966
+ """PASCAL 459 dataset."""
967
+
968
+ def __init__(self, root, split='validation', transform=None):
969
+ super(Pascal459Dataset, self).__init__()
970
+ self.root = root
971
+ self.split = split
972
+ self.transforms = transform
973
+ self.image_dir = os.path.join(root, 'images', split)
974
+ self.mask_dir = os.path.join(root, 'annotations_ctx459', split)
975
+ self.images = os.listdir(self.image_dir)
976
+
977
+ def __getitem__(self, index):
978
+ image_path = os.path.join(self.image_dir, self.images[index])
979
+ image = Image.open(image_path).convert('RGB')
980
+ target = (
981
+ np.asarray(
982
+ Image.open(
983
+ os.path.join(
984
+ self.mask_dir, self.images[index].replace('jpg', 'tif')
985
+ )
986
+ ),
987
+ dtype=np.int32,
988
+ )
989
+ + 1
990
+ )
991
+
992
+ if self.transforms:
993
+ image = self.transforms(image)
994
+
995
+ return image, image_path, target, index
996
+
997
+ def __len__(self):
998
+ return len(self.images)
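
A short, hedged usage sketch of `Pascal459Dataset` follows; the data root is an assumption, and it relies on the `images/<split>` and `annotations_ctx459/<split>` layout with `.tif` label maps shifted by +1, as implemented above.

```python
# Hedged usage sketch of Pascal459Dataset; the root path is an assumption.
from torchvision import transforms
from data.pascal459 import PASCAL_459_CLASSES, Pascal459Dataset

transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])

dataset = Pascal459Dataset(
    root='/datasets/pascal459', split='validation', transform=transform
)
image, image_path, target, index = dataset[0]
# Labels are shifted by +1, so id 1 corresponds to PASCAL_459_CLASSES[0].
print(image.shape, target.shape, PASCAL_459_CLASSES[int(target.max()) - 1])
```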
data/preprocess.py ADDED
@@ -0,0 +1,110 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Preprocess for referring datasets.
17
+
18
+ Adapted from
19
+ https://github.com/yz93/LAVT-RIS/blob/main/data/dataset_refer_bert.py
20
+ """
21
+ # pylint: disable=all
22
+ from refer.refer import REFER
23
+ from torch.utils import data
24
+
25
+
26
+ class ReferDataset(data.Dataset):
27
+ """Refer dataset."""
28
+
29
+ def __init__(
30
+ self,
31
+ root,
32
+ dataset='refcoco',
33
+ splitBy='unc',
34
+ image_transforms=None,
35
+ target_transforms=None,
36
+ split='train',
37
+ eval_mode=False,
38
+ ):
39
+
40
+ self.classes = []
41
+ self.image_transforms = image_transforms
42
+ self.target_transforms = target_transforms
43
+ self.split = split
44
+ self.refer = REFER(root, dataset=dataset, splitBy=splitBy)
45
+
46
+ ref_ids = self.refer.getRefIds(split=self.split)
47
+ img_ids = self.refer.getImgIds(ref_ids)
48
+
49
+ all_imgs = self.refer.Imgs
50
+ self.imgs = list(all_imgs[i] for i in img_ids)
51
+ self.ref_ids = ref_ids
52
+ print(len(ref_ids))
53
+ print(len(self.imgs))
54
+ # print(self.imgs)
55
+ self.sentence_raw = []
56
+
57
+ self.eval_mode = eval_mode
58
+ # if we are testing on a dataset, test all sentences of an object;
59
+ # o/w, we are validating during training, randomly sample one sentence for
60
+ # efficiency
61
+ for r in ref_ids:
62
+ ref = self.refer.Refs[r]
63
+ ref_sentences = []
64
+ for el, _ in zip(ref['sentences'], ref['sent_ids']):
65
+ sentence_raw = el['raw']
66
+ ref_sentences.append(sentence_raw)
67
+
68
+ self.sentence_raw.append(ref_sentences)
69
+ # print(len(self.sentence_raw))
70
+
71
+ def get_classes(self):
72
+ return self.classes
73
+
74
+ def __len__(self):
75
+ return len(self.imgs)
76
+
77
+ def __getitem__(self, index):
78
+ this_img_id = self.imgs[index]['id']
79
+ this_ref_ids = self.refer.getRefIds(this_img_id)
80
+ this_img = self.refer.Imgs[this_img_id]
81
+ refs = [self.refer.loadRefs(this_ref_id) for this_ref_id in this_ref_ids]
82
+
83
+ batch_sentences = {}
84
+ # batch_targets = {}
85
+ for ref in refs:
86
+ # Get sentence
87
+ sentence_lis = []
88
+ for el, _ in zip(ref[0]['sentences'], ref[0]['sent_ids']):
89
+ sentence_raw = el['raw']
90
+ sentence_lis.append(sentence_raw)
91
+ batch_sentences.update({ref[0]['ref_id']: sentence_lis})
92
+
93
+ return [this_img['file_name']], batch_sentences
94
+
95
+ def get_ref(self):
96
+ name_lis = []
97
+ for i in range(len(self.ref_ids)):
98
+ rid = self.ref_ids[i]
99
+ # print(rid)
100
+ ref = self.refer.loadRefs(rid)
101
+ if ref[0]['file_name'] == '':
102
+ print(1)
103
+ # print(ref[0]['file_name'])
104
+ # if ref[0]['file_name'] in name_lis:
105
+ # print("md")
106
+ name_lis.append(ref[0]['file_name'])
107
+ print(ref[0]['file_name'])
108
+ # print(name_lis)
109
+ print(len(name_lis))
110
+ print(len(list(set(name_lis))))
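
A hedged sketch of how this preprocessing `ReferDataset` is meant to be consumed (the root path is an assumption): each item pairs one image file name with a dictionary mapping every `ref_id` in that image to its list of raw referring sentences.

```python
# Hedged usage sketch of the preprocessing ReferDataset; the root is an assumption.
from data.preprocess import ReferDataset

dataset = ReferDataset(
    root='/datasets/refseg', dataset='refcoco', splitBy='unc', split='val'
)
file_names, sentences_by_ref = dataset[0]
print(file_names[0])
for ref_id, sentences in sentences_by_ref.items():
    print(ref_id, sentences)
```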
data/refcoco.py ADDED
@@ -0,0 +1,449 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """RefCOCO dataset."""
17
+
18
+ # Adapted from
19
+ # https://github.com/yz93/LAVT-RIS/blob/main/data/dataset_refer_bert.py
20
+ # pylint: disable=all
21
+ import itertools
22
+ import json
23
+ import os
24
+ import os.path as osp
25
+ import pickle as pickle
26
+ import sys
27
+ import time
28
+ # pylint: disable=g-importing-member
29
+ from matplotlib.collections import PatchCollection
30
+ from matplotlib.patches import Polygon
31
+ from matplotlib.patches import Rectangle
32
+ import matplotlib.pyplot as plt
33
+ import numpy as np
34
+ from PIL import Image
35
+ from pycocotools import mask
36
+ import skimage.io as io
37
+ import torch
38
+ import torch.utils.data as data
39
+ from torchvision import transforms
40
+
41
+
42
+ class REFER:
43
+ """RefCOCO dataset."""
44
+
45
+ def __init__(self, data_root, dataset='refcoco', splitBy='unc', split='val'):
46
+ # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog
47
+ # also provide dataset name and splitBy information
48
+ # e.g., dataset = 'refcoco', splitBy = 'unc'
49
+ print('loading dataset %s into memory...' % dataset)
50
+ if dataset == 'refcocog':
51
+ print('Split by {}!'.format(splitBy))
52
+ self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
53
+ self.DATA_DIR = osp.join(data_root, dataset)
54
+ if dataset in ['refcoco', 'refcoco+', 'refcocog']:
55
+ self.IMAGE_DIR = osp.join(data_root, 'images/mscoco/images/train2014')
56
+ elif dataset == 'refclef':
57
+ self.IMAGE_DIR = osp.join(data_root, 'images/saiapr_tc-12')
58
+ else:
59
+ print('No refer dataset is called [%s]' % dataset)
60
+ sys.exit()
61
+
62
+ # load refs from data/dataset/refs(dataset).json
63
+ tic = time.time()
64
+ ref_file = osp.join(self.DATA_DIR, 'refs(' + splitBy + ').p')
65
+ self.data = {}
66
+ self.data['dataset'] = dataset
67
+ # f = open(ref_file, 'r')
68
+ self.data['refs'] = pickle.load(open(ref_file, 'rb'))
69
+
70
+ # load annotations from data/dataset/instances.json
71
+ instances_file = osp.join(self.DATA_DIR, 'instances.json')
72
+ instances = json.load(open(instances_file, 'r'))
73
+ self.data['images'] = instances['images']
74
+ self.data['annotations'] = instances['annotations']
75
+ self.data['categories'] = instances['categories']
76
+
77
+ # create index
78
+ self.createIndex()
79
+ self.split = split
80
+ print('DONE (t=%.2fs)' % (time.time() - tic))
81
+
82
+ def createIndex(self):
83
+ # create sets of mapping
84
+ # 1) Refs: {ref_id: ref}
85
+ # 2) Anns: {ann_id: ann}
86
+ # 3) Imgs: {image_id: image}
87
+ # 4) Cats: {category_id: category_name}
88
+ # 5) Sents: {sent_id: sent}
89
+ # 6) imgToRefs: {image_id: refs}
90
+ # 7) imgToAnns: {image_id: anns}
91
+ # 8) refToAnn: {ref_id: ann}
92
+ # 9) annToRef: {ann_id: ref}
93
+ # 10) catToRefs: {category_id: refs}
94
+ # 11) sentToRef: {sent_id: ref}
95
+ # 12) sentToTokens: {sent_id: tokens}
96
+ print('creating index...')
97
+ # fetch info from instances
98
+ Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
99
+ for ann in self.data['annotations']:
100
+ Anns[ann['id']] = ann
101
+ imgToAnns[ann['image_id']] = imgToAnns.get(ann['image_id'], []) + [ann]
102
+ for img in self.data['images']:
103
+ Imgs[img['id']] = img
104
+ for cat in self.data['categories']:
105
+ Cats[cat['id']] = cat['name']
106
+
107
+ # fetch info from refs
108
+ Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
109
+ Sents, sentToRef, sentToTokens = {}, {}, {}
110
+ for ref in self.data['refs']:
111
+ # ids
112
+ ref_id = ref['ref_id']
113
+ ann_id = ref['ann_id']
114
+ category_id = ref['category_id']
115
+ image_id = ref['image_id']
116
+
117
+ # add mapping related to ref
118
+ Refs[ref_id] = ref
119
+ imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
120
+ catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
121
+ refToAnn[ref_id] = Anns[ann_id]
122
+ annToRef[ann_id] = ref
123
+
124
+ # add mapping of sent
125
+ for sent in ref['sentences']:
126
+ Sents[sent['sent_id']] = sent
127
+ sentToRef[sent['sent_id']] = ref
128
+ sentToTokens[sent['sent_id']] = sent['tokens']
129
+
130
+ # create class members
131
+ self.Refs = Refs
132
+ self.Anns = Anns
133
+ self.Imgs = Imgs
134
+ self.Cats = Cats
135
+ self.Sents = Sents
136
+ self.imgToRefs = imgToRefs
137
+ self.imgToAnns = imgToAnns
138
+ self.refToAnn = refToAnn
139
+ self.annToRef = annToRef
140
+ self.catToRefs = catToRefs
141
+ self.sentToRef = sentToRef
142
+ self.sentToTokens = sentToTokens
143
+ print('index created.')
144
+
145
+ def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
146
+ image_ids = image_ids if type(image_ids) == list else [image_ids]
147
+ cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
148
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
149
+
150
+ if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
151
+ refs = self.data['refs']
152
+ else:
153
+ if not len(image_ids) == 0:
154
+ refs = [self.imgToRefs[image_id] for image_id in image_ids]
155
+ ref_ids = []
156
+ for img_ref in refs:
157
+ ref_ids.extend([ref['ref_id'] for ref in img_ref])
158
+ return ref_ids
159
+ else:
160
+ refs = self.data['refs']
161
+ if not len(cat_ids) == 0:
162
+ refs = [ref for ref in refs if ref['category_id'] in cat_ids]
163
+ if not len(ref_ids) == 0:
164
+ refs = [ref for ref in refs if ref['ref_id'] in ref_ids]
165
+ if not len(split) == 0:
166
+ if split in ['testA', 'testB', 'testC']:
167
+ # we also consider testAB, testBC, ...
168
+ refs = [ref for ref in refs if split[-1] in ref['split']]
169
+ elif split in ['testAB', 'testBC', 'testAC']:
170
+ # rarely used I guess...
171
+ refs = [ref for ref in refs if ref['split'] == split]
172
+ elif split == 'test':
173
+ refs = [ref for ref in refs if 'test' in ref['split']]
174
+ elif split == 'train' or split == 'val':
175
+ refs = [ref for ref in refs if ref['split'] == split]
176
+ else:
177
+ print('No such split [%s]' % split)
178
+ sys.exit()
179
+ ref_ids = [ref['ref_id'] for ref in refs]
180
+ return ref_ids
181
+
182
+ def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
183
+ image_ids = image_ids if type(image_ids) == list else [image_ids]
184
+ cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
185
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
186
+
187
+ if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
188
+ ann_ids = [ann['id'] for ann in self.data['annotations']]
189
+ else:
190
+ if not len(image_ids) == 0:
191
+ lists = [
192
+ self.imgToAnns[image_id]
193
+ for image_id in image_ids
194
+ if image_id in self.imgToAnns
195
+ ] # list of [anns]
196
+ anns = list(itertools.chain.from_iterable(lists))
197
+ else:
198
+ anns = self.data['annotations']
199
+ if not len(cat_ids) == 0:
200
+ anns = [ann for ann in anns if ann['category_id'] in cat_ids]
201
+ ann_ids = [ann['id'] for ann in anns]
202
+ # if not len(ref_ids) == 0:
203
+ # ids = set(ann_ids).intersection(
204
+ # set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids])
205
+ # )
206
+ return ann_ids
207
+
208
+ def getImgIds(self, ref_ids=[]):
209
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
210
+
211
+ if not len(ref_ids) == 0:
212
+ image_ids = list(
213
+ set([self.Refs[ref_id]['image_id'] for ref_id in ref_ids])
214
+ )
215
+ else:
216
+ image_ids = self.Imgs.keys()
217
+ return image_ids
218
+
219
+ def getCatIds(self):
220
+ return self.Cats.keys()
221
+
222
+ def loadRefs(self, ref_ids=[]):
223
+ if type(ref_ids) == list:
224
+ return [self.Refs[ref_id] for ref_id in ref_ids]
225
+ elif type(ref_ids) == int:
226
+ return [self.Refs[ref_ids]]
227
+
228
+ def loadAnns(self, ann_ids=[]):
229
+ if type(ann_ids) == list:
230
+ return [self.Anns[ann_id] for ann_id in ann_ids]
231
+ elif type(ann_ids) == int or type(ann_ids) == str:
232
+ return [self.Anns[ann_ids]]
233
+
234
+ def loadImgs(self, image_ids=[]):
235
+ if type(image_ids) == list:
236
+ return [self.Imgs[image_id] for image_id in image_ids]
237
+ elif type(image_ids) == int:
238
+ return [self.Imgs[image_ids]]
239
+
240
+ def loadCats(self, cat_ids=[]):
241
+ if type(cat_ids) == list:
242
+ return [self.Cats[cat_id] for cat_id in cat_ids]
243
+ elif type(cat_ids) == int:
244
+ return [self.Cats[cat_ids]]
245
+
246
+ def getRefBox(self, ref_id):
247
+ # ref = self.Refs[ref_id]
248
+ ann = self.refToAnn[ref_id]
249
+ return ann['bbox'] # [x, y, w, h]
250
+
251
+ def showRef(self, ref, seg_box='seg'):
252
+ ax = plt.gca()
253
+ # show image
254
+ image = self.Imgs[ref['image_id']]
255
+ I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
256
+ ax.imshow(I)
257
+ # show refer expression
258
+ for sid, sent in enumerate(ref['sentences']):
259
+ print('%s. %s' % (sid + 1, sent['sent']))
260
+ # show segmentations
261
+ if seg_box == 'seg':
262
+ ann_id = ref['ann_id']
263
+ ann = self.Anns[ann_id]
264
+ polygons = []
265
+ color = []
266
+ c = 'none'
267
+ if type(ann['segmentation'][0]) == list:
268
+ # polygon used for refcoco*
269
+ for seg in ann['segmentation']:
270
+ poly = np.array(seg).reshape((len(seg) // 2, 2))
271
+ polygons.append(Polygon(poly, True, alpha=0.4))
272
+ color.append(c)
273
+ p = PatchCollection(
274
+ polygons,
275
+ facecolors=color,
276
+ edgecolors=(1, 1, 0, 0),
277
+ linewidths=3,
278
+ alpha=1,
279
+ )
280
+ ax.add_collection(p) # thick yellow polygon
281
+ p = PatchCollection(
282
+ polygons,
283
+ facecolors=color,
284
+ edgecolors=(1, 0, 0, 0),
285
+ linewidths=1,
286
+ alpha=1,
287
+ )
288
+ ax.add_collection(p) # thin red polygon
289
+ else:
290
+ # mask used for refclef
291
+ rle = ann['segmentation']
292
+ m = mask.decode(rle)
293
+ img = np.ones((m.shape[0], m.shape[1], 3))
294
+ color_mask = np.array([2.0, 166.0, 101.0]) / 255
295
+ for i in range(3):
296
+ img[:, :, i] = color_mask[i]
297
+ ax.imshow(np.dstack((img, m * 0.5)))
298
+ # show bounding-box
299
+ elif seg_box == 'box':
300
+ # ann_id = ref['ann_id']
301
+ # ann = self.Anns[ann_id]
302
+ bbox = self.getRefBox(ref['ref_id'])
303
+ box_plot = Rectangle(
304
+ (bbox[0], bbox[1]),
305
+ bbox[2],
306
+ bbox[3],
307
+ fill=False,
308
+ edgecolor='green',
309
+ linewidth=3,
310
+ )
311
+ ax.add_patch(box_plot)
312
+
313
+ def getMask(self, ref):
314
+ # return mask, area and mask-center
315
+ ann = self.refToAnn[ref['ref_id']]
316
+ image = self.Imgs[ref['image_id']]
317
+
318
+ if type(ann['segmentation'][0]) == list: # polygon
319
+ rle = mask.frPyObjects(
320
+ ann['segmentation'], image['height'], image['width']
321
+ )
322
+ else:
323
+ rle = ann['segmentation']
324
+
325
+ m = mask.decode(rle)
326
+ # sometimes there are multiple binary maps (corresponding to multiple segments)
327
+ m = np.sum(m, axis=2)
328
+ m = m.astype(np.uint8) # convert to np.uint8
329
+ # compute area
330
+ area = sum(mask.area(rle)) # should be close to ann['area']
331
+ return {'mask': m, 'area': area}
332
+
333
+ def showMask(self, ref):
334
+ M = self.getMask(ref)
335
+ msk = M['mask']
336
+ ax = plt.gca()
337
+ ax.imshow(msk)
338
+
339
+
340
+ class ReferDataset(data.Dataset):
341
+
342
+ def __init__(
343
+ self,
344
+ root,
345
+ dataset='refcoco',
346
+ splitBy='google',
347
+ image_transforms=None,
348
+ target_transforms=None,
349
+ split='train',
350
+ eval_mode=False,
351
+ ):
352
+
353
+ self.classes = []
354
+ self.image_transforms = image_transforms
355
+ self.target_transforms = target_transforms
356
+ self.split = split
357
+ self.refer = REFER(root, dataset=dataset, splitBy=splitBy)
358
+
359
+ ref_ids = self.refer.getRefIds(split=self.split)
360
+ img_ids = self.refer.getImgIds(ref_ids)
361
+
362
+ all_imgs = self.refer.Imgs
363
+ self.imgs = list(all_imgs[i] for i in img_ids)
364
+ self.ref_ids = ref_ids
365
+ # print(len(ref_ids))
366
+ # print(len(self.imgs))
367
+ self.sentence_raw = []
368
+
369
+ self.eval_mode = eval_mode
370
+ # if we are testing on a dataset, test all sentences of an object;
371
+ # o/w, we are validating during training, randomly sample one sentence
372
+ # for efficiency
373
+ for r in ref_ids:
374
+ ref = self.refer.Refs[r]
375
+ ref_sentences = []
376
+ # for i, (el, sent_id) in enumerate(zip(ref['sentences'],
377
+ # ref['sent_ids'])):
378
+ for el in ref['sentences']:
379
+ sentence_raw = el['raw']
380
+ ref_sentences.append(sentence_raw)
381
+ self.sentence_raw.append(ref_sentences)
382
+ # print(len(self.sentence_raw))
383
+
384
+ def get_classes(self):
385
+ return self.classes
386
+
387
+ def __len__(self):
388
+ return len(self.ref_ids)
389
+
390
+ def __getitem__(self, index):
391
+ this_ref_id = self.ref_ids[index]
392
+ this_img_id = self.refer.getImgIds(this_ref_id)
393
+ this_img = self.refer.Imgs[this_img_id[0]]
394
+ # print(this_ref_id, this_img_id)
395
+ # print(len(self.ref_ids))
396
+ img_path = os.path.join(self.refer.IMAGE_DIR, this_img['file_name'])
397
+ img = Image.open(img_path).convert('RGB')
398
+ ref = self.refer.loadRefs(this_ref_id)
399
+ # print("ref",ref)
400
+
401
+ ref_mask = np.array(self.refer.getMask(ref[0])['mask'])
402
+ annot = np.zeros(ref_mask.shape)
403
+ annot[ref_mask == 1] = 1
404
+
405
+ target = Image.fromarray(annot.astype(np.uint8), mode='P')
406
+ # print(np.array(target), np.unique(np.array(target).flatten()))
407
+ if self.image_transforms is not None:
408
+ # resize, from PIL to tensor, and mean and std normalization
409
+ img = self.image_transforms(img)
410
+ # target = self.target_transforms(target)
411
+ target = torch.as_tensor(np.array(target, copy=True))
412
+ # target = target.permute((2, 0, 1))
413
+ sentence = self.sentence_raw[index]
414
+
415
+ return img, img_path, target, sentence
416
+
417
+
418
+ if __name__ == '__main__':
419
+
420
+ def get_transform():
421
+ transform = [
422
+ transforms.Resize((224, 224)),
423
+ transforms.ToTensor(),
424
+ # T.Normalize(mean=[0.485, 0.456, 0.406],
425
+ # std=[0.229, 0.224, 0.225])
426
+ ]
427
+
428
+ return transforms.Compose(transform)
429
+
430
+ transform = get_transform()
431
+ dataset_test = ReferDataset(
432
+ root='/datasets/refseg',
433
+ dataset='refcoco+',
434
+ splitBy='google',
435
+ image_transforms=transform,
436
+ target_transforms=transform,
437
+ split='train',
438
+ eval_mode=False,
439
+ )
440
+ print('loaded')
441
+ test_sampler = torch.utils.data.SequentialSampler(dataset_test)
442
+ data_loader_test = torch.utils.data.DataLoader(
443
+ dataset_test, batch_size=1, sampler=test_sampler, num_workers=1
444
+ )
445
+
446
+ for img, img_path, target, sentence in data_loader_test:
447
+ # print(type(img),type(target))
448
+ print(sentence)
449
+ break
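
The `__main__` block above exercises the `ReferDataset` wrapper; for completeness, here is a hedged sketch of the lower-level `REFER` index it builds (the data root is again an assumption).

```python
# Hedged sketch of the REFER index API; '/datasets/refseg' is an assumed root.
from data.refcoco import REFER

refer = REFER('/datasets/refseg', dataset='refcoco', splitBy='unc', split='val')
ref_ids = refer.getRefIds(split='val')       # all referring expressions in the split
ref = refer.loadRefs(ref_ids[0])[0]          # one annotation record
print([s['raw'] for s in ref['sentences']])  # its raw referring sentences
mask_info = refer.getMask(ref)               # binary mask and pixel area
print(mask_info['mask'].shape, mask_info['area'])
```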
data/voc.py ADDED
@@ -0,0 +1,148 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Pascal VOC dataset."""
17
+
18
+ import numpy as np
19
+ from PIL import Image
20
+ # pylint: disable=g-importing-member
21
+ from torchvision.datasets import VOCSegmentation
22
+
23
+ CLASS2ID = {
24
+ 'Background': 0,
25
+ 'Aero plane': 1,
26
+ 'Bicycle': 2,
27
+ 'Bird': 3,
28
+ 'Boat': 4,
29
+ 'Bottle': 5,
30
+ 'Bus': 6,
31
+ 'Car': 7,
32
+ 'Cat': 8,
33
+ 'Chair': 9,
34
+ 'Cow': 10,
35
+ 'Dining table': 11,
36
+ 'Dog': 12,
37
+ 'Horse': 13,
38
+ 'Motorbike': 14,
39
+ 'Person': 15,
40
+ 'Potted plant': 16,
41
+ 'Sheep': 17,
42
+ 'Sofa': 18,
43
+ 'Train': 19,
44
+ 'Tv/Monitor': 20,
45
+ # ... add more entries as needed
46
+ 'Border': 255,
47
+ }
48
+
49
+
50
+ VOC_CLASSES = [
51
+ 'aeroplane',
52
+ 'bicycle',
53
+ 'bird avian',
54
+ 'boat',
55
+ 'bottle',
56
+ 'bus',
57
+ 'car',
58
+ 'cat',
59
+ 'chair seat',
60
+ 'cow',
61
+ 'diningtable',
62
+ 'dog',
63
+ 'horse',
64
+ 'motorbike',
65
+ 'person with clothes,people,human',
66
+ 'pottedplant',
67
+ 'sheep',
68
+ 'sofa',
69
+ 'train',
70
+ 'tvmonitor screen',
71
+ ]
72
+
73
+
74
+ BACKGROUND_CATEGORY = [
75
+ 'ground',
76
+ 'land',
77
+ 'grass',
78
+ 'tree',
79
+ 'building',
80
+ 'wall',
81
+ 'sky',
82
+ 'lake',
83
+ 'water',
84
+ 'river',
85
+ 'sea',
86
+ 'keyboard',
87
+ 'helmet',
88
+ 'cloud',
89
+ 'house',
90
+ 'mountain',
91
+ 'ocean',
92
+ 'road',
93
+ 'rock',
94
+ 'street',
95
+ 'valley',
96
+ 'bridge',
97
+ 'sign',
98
+ ]
99
+
100
+
101
+ class VOCDataset(VOCSegmentation):
102
+ """Pascal VOC dataset."""
103
+
104
+ def __init__(
105
+ self,
106
+ root='/datasets/jianhaoy/PASCAL/',
107
+ year='2012',
108
+ split='val',
109
+ target_transform=None,
110
+ download=False,
111
+ transform=None,
112
+ ):
113
+ super(VOCDataset, self).__init__(
114
+ root=root,
115
+ image_set=split,
116
+ year=year,
117
+ target_transform=transform,
118
+ download=download,
119
+ transform=transform,
120
+ )
121
+ self.idx_to_class = {val: key for (key, val) in CLASS2ID.items()}
122
+
123
+ def __getitem__(self, index):
124
+ image_path = self.images[index]
125
+ image = Image.open(image_path).convert('RGB')
126
+ target = np.asarray(Image.open(self.masks[index]), dtype=np.int32)
127
+
128
+ _, unique_values = self.process_target(np.array(target))
129
+ classnames = [self.idx_to_class[idx] for idx in unique_values]
130
+
131
+ if self.transforms:
132
+ image = self.transform(image)
133
+
134
+ return image, str(image_path), target, classnames
135
+
136
+ def process_target(self, arr):
137
+ # Map background (0) and border (255) pixels to 0
138
+ arr[(arr == 0) | (arr == 255)] = 0
139
+
140
+ # Find unique values (excluding 0 and 255)
141
+ unique_values = np.unique(arr[(arr != 0) & (arr != 255)])
142
+
143
+ # Create separate masks for each unique value
144
+ masks = [arr == value for value in unique_values]
145
+ masks = [Image.fromarray(arr) for arr in masks]
146
+ masks = [self.target_transform(arr) for arr in masks]
147
+
148
+ return masks, unique_values
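
To make the label handling above concrete, here is a small hedged illustration of how `process_target` and the `idx_to_class` mapping recover the class names present in an annotation; the 2x2 array is dummy data, not a real VOC mask.

```python
# Toy illustration of the VOC label convention used above (dummy data).
import numpy as np
from data.voc import CLASS2ID

idx_to_class = {val: key for (key, val) in CLASS2ID.items()}
# Background (0), a cat (8), a dog (12), and border pixels (255).
arr = np.array([[0, 8], [12, 255]])
arr[(arr == 0) | (arr == 255)] = 0                   # drop background and border
unique_values = np.unique(arr[(arr != 0) & (arr != 255)])
print([idx_to_class[idx] for idx in unique_values])  # ['Cat', 'Dog']
```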
demo.py ADDED
@@ -0,0 +1,227 @@
1
+ """Run a demo of the CaR model on a single image."""
2
+
3
+ import numpy as np
4
+ import os
5
+ import argparse
6
+ from functools import reduce
7
+ import PIL.Image as Image
8
+ import torch
9
+ from modeling.model import CaR
10
+ from utils.utils import Config, load_yaml
11
+ import matplotlib.pyplot as plt
12
+ import colorsys
13
+ from modeling.post_process.post_process import (
14
+ match_masks,
15
+ generate_masks_from_sam,
16
+ )
17
+ from sam.sam import SAMPipeline
18
+ from sam.utils import build_sam_config
19
+ import random
20
+ import time
21
+
22
+
23
+ def generate_distinct_colors(n):
24
+ colors = []
25
+ # generate a random number from 0 to 1
26
+ random_color_bias = random.random()
27
+
28
+ for i in range(n):
29
+ hue = float(i) / n
30
+ hue += random_color_bias
31
+ hue = hue % 1.0
32
+ rgb = colorsys.hsv_to_rgb(hue, 1.0, 1.0)
33
+ # Convert RGB values from [0, 1] range to [0, 255]
34
+ colors.append(tuple(int(val * 255) for val in rgb))
35
+ return colors
36
+
37
+
38
+ def overlap_masks(masks):
39
+ """
40
+ Overlap masks to generate a single mask for visualization.
41
+
42
+ Parameters:
43
+ - masks: list of torch.Tensors of shape (H, W) representing the binary mask
44
+ for each class.
45
+
46
+ Returns:
47
+ - clean_masks: torch.Tensor of shape (N, H, W) of boolean masks with overlaps removed
48
+ """
49
+ overlap_mask = torch.zeros_like(masks[0])
50
+ for mask_idx, mask in enumerate(masks):
51
+ overlap_mask[mask > 0] = mask_idx + 1
52
+
53
+ clean_masks = [
54
+ overlap_mask == mask_idx + 1 for mask_idx in range(len(masks))
55
+ ]
56
+ clean_masks = torch.stack(clean_masks, dim=0)
57
+
58
+ return clean_masks
59
+
60
+
61
+ def visualize_segmentation(
62
+ image, masks, class_names, alpha=0.45, y_list=None, x_list=None
63
+ ):
64
+ """
65
+ Visualize segmentation masks on an image.
66
+
67
+ Parameters:
68
+ - image: np.array of shape (H, W, 3) representing the RGB image
69
+ - masks: list of np.arrays of shape (H, W) representing binary masks
70
+ for each class.
71
+ - class_names: list of strings representing names of each class
72
+ - alpha: float, transparency level of masks on the image
73
+
74
+ Returns:
75
+ - visualization: plt.figure object
76
+ """
77
+ # Create a figure and axis
78
+ fig, ax = plt.subplots(1, figsize=(12, 9))
79
+ # Display the image
80
+ # ax.imshow(image)
81
+ # Generate distinct colors for each mask
82
+ final_mask = np.zeros(
83
+ (masks.shape[1], masks.shape[2], 3), dtype=np.float32
84
+ )
85
+ colors = generate_distinct_colors(len(class_names))
86
+ idx = 0
87
+ for mask, color, class_name in zip(masks, colors, class_names):
88
+ # Overlay the mask
89
+ final_mask += np.dstack([mask * c for c in color])
90
+ # Find a representative point (e.g., centroid) for placing the label
91
+ if y_list is None or x_list is None:
92
+ y, x = np.argwhere(mask).mean(axis=0)
93
+ else:
94
+ y, x = y_list[idx], x_list[idx]
95
+ ax.text(
96
+ x,
97
+ y,
98
+ class_name,
99
+ color="white",
100
+ fontsize=36,
101
+ va="center",
102
+ ha="center",
103
+ bbox=dict(facecolor="black", alpha=0.7, edgecolor="none"),
104
+ )
105
+
106
+ idx += 1
107
+
108
+ final_image = image * (1 - alpha) + final_mask * alpha
109
+ final_image = final_image.astype(np.uint8)
110
+ ax.imshow(final_image)
111
+ # Remove axis ticks and labels
112
+ ax.axis("off")
113
+ return fig
114
+
115
+
116
+ def get_sam_masks(config, image_path, masks, img_sam=None, pipeline=None):
117
+ print("generating sam masks online")
118
+ mask_tensor, mask_list = generate_masks_from_sam(
119
+ image_path,
120
+ save_path="./",
121
+ pipeline=pipeline,
122
+ img_sam=img_sam,
123
+ visualize=False,
124
+ )
125
+ mask_tensor = mask_tensor.to(masks.device)
126
+ # only conduct sam on masks that is not all zero
127
+ attn_map, mask_ids = [], []
128
+ for mask_id, mask in enumerate(masks):
129
+ if torch.sum(mask) > 0:
130
+ attn_map.append(mask.unsqueeze(0))
131
+ mask_ids.append(mask_id)
132
+ matched_masks = [
133
+ match_masks(
134
+ mask_tensor,
135
+ attn,
136
+ mask_list,
137
+ iom_thres=config.car.iom_thres,
138
+ min_pred_threshold=config.sam.min_pred_threshold,
139
+ )
140
+ for attn in attn_map
141
+ ]
142
+ for matched_mask, mask_id in zip(matched_masks, mask_ids):
143
+ sam_masks = np.array([item["segmentation"] for item in matched_mask])
144
+ sam_mask = np.any(sam_masks, axis=0)
145
+ masks[mask_id] = torch.from_numpy(sam_mask).to(masks.device)
146
+ return masks
147
+
148
+
149
+ def load_sam(config, sam_device):
150
+ sam_checkpoint, model_type = build_sam_config(config)
151
+ pipelines = SAMPipeline(
152
+ sam_checkpoint,
153
+ model_type,
154
+ device=sam_device,
155
+ points_per_side=config.sam.points_per_side,
156
+ pred_iou_thresh=config.sam.pred_iou_thresh,
157
+ stability_score_thresh=config.sam.stability_score_thresh,
158
+ box_nms_thresh=config.sam.box_nms_thresh,
159
+ )
160
+ return pipelines
161
+
162
+
163
+ if __name__ == "__main__":
164
+ parser = argparse.ArgumentParser("CaR")
165
+ # default arguments
166
+
167
+ # additional arguments
168
+ parser.add_argument(
169
+ "--output_path", type=str, default="", help="path to save outputs"
170
+ )
171
+ parser.add_argument(
172
+ "--cfg-path",
173
+ default="configs/voc_test.yaml",
174
+ help="path to configuration file.",
175
+ )
176
+ args = parser.parse_args()
177
+
178
+ cfg = Config(**load_yaml(args.cfg_path))
179
+ device = "cuda" if torch.cuda.is_available() else "cpu"
180
+ # device = 'cpu'
181
+ folder_name = reduce(
182
+ lambda x, y: x.replace(" ", "_") + "_" + y, cfg.image_caption
183
+ )
184
+ if len(folder_name) > 20:
185
+ folder_name = folder_name[:20]
186
+
187
+ car_model = CaR(
188
+ cfg, visualize=True, seg_mode=cfg.test.seg_mode, device=device
189
+ )
190
+
191
+ sam_pipeline = load_sam(cfg, device)
192
+
193
+ img = Image.open(cfg.image_path).convert("RGB")
194
+ # import pdb; pdb.set_trace()
195
+ # downscale the image by a factor of 3 if its width is larger than 1000
196
+ if img.size[0] > 1000:
197
+ img = img.resize((img.size[0] // 3, img.size[1] // 3))
198
+
199
+ label_space = cfg.image_caption
200
+ pseudo_masks, scores, _ = car_model(img, label_space)
201
+
202
+
203
+ if not cfg.test.use_pseudo:
204
+ t1 = time.time()
205
+ pseudo_masks = get_sam_masks(
206
+ cfg,
207
+ cfg.image_path,
208
+ pseudo_masks,
209
+ img_sam=np.array(img),
210
+ pipeline=sam_pipeline,
211
+ )
212
+ pseudo_masks = overlap_masks(pseudo_masks)
213
+ t2 = time.time()
214
+ print(f"sam time: {t2 - t1}")
215
+
216
+ # visualize segmentation masks
217
+ demo_fig = visualize_segmentation(
218
+ np.array(img),
219
+ pseudo_masks.detach().cpu().numpy(),
220
+ label_space,
221
+ )
222
+ save_path = f"vis_results/{folder_name}"
223
+ if not os.path.exists(save_path):
224
+ os.makedirs(save_path)
225
+ demo_fig.savefig(os.path.join(save_path, "demo.png"), bbox_inches="tight")
226
+
227
+ print(f"results saved to {save_path}.")
evaluate.py ADDED
@@ -0,0 +1,511 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Evaluate CaR on segmentation benchmarks."""
17
+ # pylint: disable=g-importing-member
18
+ import argparse
19
+ import numpy as np
20
+ import torch
21
+ from torch.utils import tensorboard
22
+ import torch.utils.data
23
+ from torch.utils.data import Subset
24
+ import torchvision.transforms as T
25
+
26
+ # pylint: disable=g-bad-import-order
27
+ from modeling.model.car import CaR
28
+ from sam.utils import build_sam_config
29
+ from utils.utils import Config
30
+ from utils.utils import load_yaml
31
+ from utils.utils import MetricLogger
32
+ from utils.utils import SmoothedValue
33
+ from utils.inference_pipeline import inference_car
34
+ from utils.merge_mask import merge_masks_simple
35
+
36
+ # Datasets
37
+ # pylint: disable=g-multiple-import
38
+ from data.ade import ADE_THING_CLASS, ADE_STUFF_CLASS, ADE_THING_CLASS_ID, ADE_STUFF_CLASS_ID, ADEDataset
39
+ from data.ade847 import ADE_847_THING_CLASS_ID, ADE_847_STUFF_CLASS_ID, ADE_847_THING_CLASS, ADE_847_STUFF_CLASS, ADE847Dataset
40
+ from data.coco import COCO_OBJECT_CLASSES, COCODataset
41
+ from data.context import PASCAL_CONTEXT_STUFF_CLASS_ID, PASCAL_CONTEXT_THING_CLASS_ID, PASCAL_CONTEXT_STUFF_CLASS, PASCAL_CONTEXT_THING_CLASS, CONTEXTDataset
42
+ from data.gres import GReferDataset
43
+ from data.pascal459 import PASCAL_459_THING_CLASS_ID, PASCAL_459_STUFF_CLASS_ID, PASCAL_459_THING_CLASS, PASCAL_459_STUFF_CLASS, Pascal459Dataset
44
+ from data.refcoco import ReferDataset
45
+ from data.voc import VOC_CLASSES, VOCDataset
46
+
47
+
48
+ IMAGE_WIDTH, IMAGE_HEIGHT = 512, 512
49
+
50
+ # set random seed
51
+ torch.manual_seed(0)
52
+ np.random.seed(0)
53
+
54
+
55
+ def get_dataset(cfg, ds_name, split, transform, data_root=None):
56
+ """Get dataset."""
57
+ data_args = dict(root=data_root) if data_root is not None else {}
58
+ if 'refcoco' in ds_name:
59
+ splitby = cfg.test.splitby if hasattr(cfg.test, 'splitby') else 'unc'
60
+ ds = ReferDataset(
61
+ dataset=ds_name,
62
+ splitBy=splitby,
63
+ split=split,
64
+ image_transforms=transform,
65
+ target_transforms=transform,
66
+ eval_mode=True,
67
+ prompts_augment=cfg.test.prompts_augment,
68
+ **data_args,
69
+ )
70
+ elif ds_name == 'gres':
71
+ ds = GReferDataset(split=split, transform=transform, **data_args)
72
+ elif ds_name == 'voc':
73
+ ds = VOCDataset(
74
+ year='2012',
75
+ split=split,
76
+ transform=transform,
77
+ target_transform=transform,
78
+ **data_args,
79
+ )
80
+
81
+ elif ds_name == 'cocostuff':
82
+ ds = COCODataset(transform=transform, **data_args)
83
+
84
+ elif ds_name == 'context':
85
+ ds = CONTEXTDataset(
86
+ year='2010', transform=transform, split=split, **data_args
87
+ )
88
+ elif ds_name == 'ade':
89
+ ds = ADEDataset(split=split, transform=transform, **data_args)
90
+ elif ds_name == 'pascal_459':
91
+ ds = Pascal459Dataset(split=split, transform=transform, **data_args)
92
+ elif ds_name == 'ade_847':
93
+ ds = ADE847Dataset(split=split, transform=transform, **data_args)
94
+ else:
95
+ raise ValueError(f'Dataset {ds_name} not implemented')
96
+ return ds
97
+
98
+
99
+ def get_transform():
100
+ transforms = [
101
+ T.Resize((IMAGE_WIDTH, IMAGE_HEIGHT)),
102
+ T.ToTensor(),
103
+ ]
104
+
105
+ return T.Compose(transforms)
106
+
107
+
108
+ def assign_label(
109
+ all_masks,
110
+ scores,
111
+ stuff_masks=None,
112
+ stuff_scores=None,
113
+ id_mapping=None,
114
+ stuff_id_mapping=None,
115
+ ):
116
+ """Assign labels."""
117
+ label_preds = np.zeros_like(all_masks[0]).astype(np.int32)
118
+ if stuff_masks is not None:
119
+ sorted_idxs = np.argsort(stuff_scores.detach().cpu().numpy())
120
+ stuff_masks = stuff_masks[sorted_idxs]
121
+ stuff_scores = stuff_scores.detach().cpu().numpy()[sorted_idxs]
122
+ for sorted_idx, mask, score in zip(sorted_idxs, stuff_masks, stuff_scores):
123
+ if score > 0:
124
+ # convert mask to boolean
125
+ mask = mask > 0.5
126
+ # assign label
127
+ if stuff_id_mapping is not None:
128
+ label_preds[mask] = stuff_id_mapping[sorted_idx] + 1
129
+ else:
130
+ label_preds[mask] = sorted_idx + 1
131
+ sorted_idxs = np.argsort(scores.detach().cpu().numpy())
132
+ all_masks = all_masks[sorted_idxs]
133
+ scores = scores.detach().cpu().numpy()[sorted_idxs]
134
+ for sorted_idx, mask, score in zip(sorted_idxs, all_masks, scores):
135
+ if score > 0:
136
+ # convert mask to boolean
137
+ mask = mask > 0.5
138
+ # assign label
139
+ if id_mapping is not None:
140
+ label_preds[mask] = id_mapping[sorted_idx] + 1
141
+ else:
142
+ label_preds[mask] = sorted_idx + 1
143
+
144
+ return label_preds
145
+
146
+
147
+ def eval_semantic(
148
+ label_space,
149
+ algo,
150
+ cfg,
151
+ model,
152
+ image_path,
153
+ stuff_label_space=None,
154
+ sam_pipeline=None,
155
+ ):
156
+ """Semantic segmentation evaluation."""
157
+
158
+ if label_space is None:
159
+ raise ValueError(
160
+ 'label_space must be provided for semantic segmentation evaluation'
161
+ )
162
+ if algo == 'car':
163
+ all_masks, scores = inference_car(
164
+ cfg, model, image_path, label_space, sam_pipeline=sam_pipeline
165
+ )
166
+ if stuff_label_space is not None:
167
+ if cfg.test.ds_name == 'context':
168
+ thing_id_mapping = PASCAL_CONTEXT_THING_CLASS_ID
169
+ stuff_id_mapping = PASCAL_CONTEXT_STUFF_CLASS_ID
170
+ elif cfg.test.ds_name == 'ade':
171
+ thing_id_mapping = ADE_THING_CLASS_ID
172
+ stuff_id_mapping = ADE_STUFF_CLASS_ID
173
+ elif cfg.test.ds_name == 'pascal_459':
174
+ thing_id_mapping = PASCAL_459_THING_CLASS_ID
175
+ stuff_id_mapping = PASCAL_459_STUFF_CLASS_ID
176
+ elif cfg.test.ds_name == 'ade_847':
177
+ thing_id_mapping = ADE_847_THING_CLASS_ID
178
+ stuff_id_mapping = ADE_847_STUFF_CLASS_ID
179
+ else:
180
+ raise ValueError(f'Dataset {cfg.test.ds_name} not supported')
181
+
182
+ model.mask_generator.set_bg_cls(label_space)
183
+ model.set_visual_prompt_type(cfg.car.stuff_visual_prompt_type)
184
+ model.set_bg_factor(cfg.car.stuff_bg_factor)
185
+ stuff_masks, stuff_scores = inference_car(
186
+ cfg, model, image_path, stuff_label_space, sam_pipeline=sam_pipeline
187
+ )
188
+ model.mask_generator.set_bg_cls(cfg.car.bg_cls)
189
+ model.set_visual_prompt_type(cfg.car.visual_prompt_type)
190
+ model.set_bg_factor(cfg.car.bg_factor)
191
+ all_masks = all_masks.detach().cpu().numpy()
192
+ stuff_masks = stuff_masks.detach().cpu().numpy()
193
+ label_preds = assign_label(
194
+ all_masks,
195
+ scores,
196
+ stuff_masks=stuff_masks,
197
+ stuff_scores=stuff_scores,
198
+ id_mapping=thing_id_mapping,
199
+ stuff_id_mapping=stuff_id_mapping,
200
+ )
201
+ else:
202
+ all_masks = all_masks.detach().cpu().numpy()
203
+ label_preds = assign_label(all_masks, scores)
204
+ return label_preds.squeeze()
205
+ else:
206
+ raise NotImplementedError(f'algo {algo} not implemented')
207
+
208
+
209
+ def _fast_hist(label_true, label_pred, n_class=21):
210
+ mask = (label_true >= 0) & (label_true < n_class)
211
+ hist = np.bincount(
212
+ n_class * label_true[mask].astype(int) + label_pred[mask],
213
+ minlength=n_class**2,
214
+ ).reshape(n_class, n_class)
215
+ return hist
216
+
217
+
218
+ def semantic_iou(label_trues, label_preds, n_class=21, ignore_background=False):
219
+ """Semantic segmentation IOU."""
220
+
221
+ hist = np.zeros((n_class, n_class))
222
+ for lt, lp in zip(label_trues, label_preds):
223
+ hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
224
+ if ignore_background:
225
+ hist = hist[1:, 1:]
226
+ acc = np.diag(hist).sum() / hist.sum()
227
+ acc_cls = np.diag(hist) / hist.sum(axis=1)
228
+ acc_cls = np.nanmean(acc_cls)
229
+ iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
230
+ valid = hist.sum(axis=1) > 0 # added
231
+ if valid.sum() == 0:
232
+ mean_iu = 0
233
+ else:
234
+ mean_iu = np.nanmean(iu[valid])
235
+ freq = hist.sum(axis=1) / hist.sum()
236
+ fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
237
+ if ignore_background:
238
+ cls_iu = dict(zip(range(1, n_class), iu))
239
+ else:
240
+ cls_iu = dict(zip(range(n_class), iu))
241
+
242
+ return {
243
+ 'Pixel Accuracy': acc,
244
+ 'Mean Accuracy': acc_cls,
245
+ 'Frequency Weighted IoU': fwavacc,
246
+ 'mIoU': mean_iu,
247
+ 'Class IoU': cls_iu,
248
+ }
249
+
250
+
251
+ def evaluate(
252
+ data_loader,
253
+ cfg,
254
+ model,
255
+ test_cfg,
256
+ label_space=None,
257
+ stuff_label_space=None,
258
+ sam_pipeline=None,
259
+ ):
260
+ """Run evaluation."""
261
+
262
+ if (
263
+ test_cfg.ds_name
264
+ not in ['voc', 'cocostuff', 'context', 'ade', 'pascal_459', 'ade_847']
265
+ and test_cfg.seg_mode == 'semantic'
266
+ ):
267
+ raise ValueError((
268
+ 'Semantic segmentation evaluation is only implemented for voc, '
269
+ 'cocostuff, context, ade, pascal_459, and ade_847 datasets'
270
+ ))
271
+
272
+ metric_logger = MetricLogger(delimiter=' ')
273
+ metric_logger.add_meter(
274
+ 'mIoU', SmoothedValue(window_size=1, fmt='{value:.4f} ({global_avg:.4f})')
275
+ )
276
+ # evaluation variables
277
+ cum_i, cum_u = 0, 0
278
+ eval_seg_iou_list = [0.5, 0.6, 0.7, 0.8, 0.9]
279
+ seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
280
+ seg_total = 0
281
+ mean_iou = []
282
+ header = 'Test:'
283
+
284
+ # all_masks = []
285
+ label_preds, label_gts = [], []
286
+ print(len(data_loader))
287
+ cc = 0
288
+ use_tensorboard = False
289
+ if hasattr(cfg.test, 'use_tensorboard'):
290
+ use_tensorboard = cfg.test.use_tensorboard
291
+
292
+ if use_tensorboard:
293
+ writer = tensorboard.SummaryWriter(log_dir=cfg.test.output_path)
294
+ for data in metric_logger.log_every(data_loader, 1, header):
295
+ _, image_paths, target_list, sentences_list = data
296
+ # print(type(target_lis))
297
+
298
+ if not isinstance(target_list, list):
299
+ target_list, sentences_list = [target_list], [sentences_list]
300
+ for target, sentences in zip(target_list, sentences_list):
301
+ image_path = image_paths[0]
302
+ # print(image_path)
303
+ if test_cfg.seg_mode == 'refer':
304
+ all_masks, all_scores = inference_car(
305
+ cfg, model, image_path, sentences, sam_pipeline=sam_pipeline
306
+ )
307
+ # final_mask = merge_masks(all_masks, *target.shape[1:])
308
+ final_mask = merge_masks_simple(
309
+ all_masks, *target.shape[1:], scores=all_scores
310
+ )
311
+ intersection, union, cur_iou = compute_iou(final_mask, target)
312
+ # cur_iou = IoU(final_mask, target, 0)
313
+ metric_logger.update(mIoU=cur_iou)
314
+ mean_iou.append(cur_iou)
315
+ if use_tensorboard:
316
+ writer.add_scalar('Mean IoU', cur_iou, cc)
317
+ cum_i += intersection
318
+ cum_u += union
319
+ for n_eval_iou in range(len(eval_seg_iou_list)):
320
+ eval_seg_iou = eval_seg_iou_list[n_eval_iou]
321
+ seg_correct[n_eval_iou] += cur_iou >= eval_seg_iou
322
+ seg_total += 1
323
+ elif test_cfg.seg_mode == 'semantic':
324
+ # torch.cuda.empty_cache()
325
+ label_pred = eval_semantic(
326
+ label_space,
327
+ test_cfg.algo,
328
+ cfg,
329
+ model,
330
+ image_path,
331
+ stuff_label_space,
332
+ )
333
+ label_gt = target.squeeze().cpu().numpy()
334
+ cur_iou = semantic_iou(
335
+ [label_gt],
336
+ [label_pred],
337
+ n_class=cfg.test.n_class,
338
+ ignore_background=cfg.test.ignore_background,
339
+ )['mIoU']
340
+ metric_logger.update(mIoU=cur_iou)
341
+ label_preds.append(label_pred)
342
+ label_gts.append(label_gt)
343
+
344
+ cc += 1
345
+
346
+ if test_cfg.seg_mode == 'refer':
347
+ mean_iou = np.array(mean_iou)
348
+ m_iou = np.mean(mean_iou)
349
+ if use_tensorboard:
350
+ writer.add_scalar('mIoU', m_iou.item(), len(data_loader))
351
+ print('Final results:')
352
+ print('Mean IoU is %.2f\n' % (m_iou * 100.0))
353
+ results_str = ''
354
+ for n_eval_iou in range(len(eval_seg_iou_list)):
355
+ results_str += ' precision@%s = %.2f\n' % (
356
+ str(eval_seg_iou_list[n_eval_iou]),
357
+ seg_correct[n_eval_iou] * 100.0 / seg_total,
358
+ )
359
+ o_iou = cum_i * 100.0 / cum_u
360
+ results_str += ' overall IoU = %.2f\n' % o_iou
361
+ if use_tensorboard:
362
+ writer.add_scalar('oIoU', o_iou, 0)
363
+ print(results_str)
364
+ elif test_cfg.seg_mode == 'semantic':
365
+ iou_score = semantic_iou(
366
+ label_gts,
367
+ label_preds,
368
+ n_class=cfg.test.n_class,
369
+ ignore_background=cfg.test.ignore_background,
370
+ )
371
+ if use_tensorboard:
372
+ writer.add_scalar('mIoU', iou_score['mIoU'].item(), len(data_loader))
373
+
374
+ print(iou_score)
375
+ if use_tensorboard:
376
+ writer.close()
377
+
378
+
379
+ def compute_iou(pred_seg, gd_seg):
380
+ """Compute IoU."""
381
+ intersection = torch.sum(torch.logical_and(pred_seg, gd_seg))
382
+ union = torch.sum(torch.logical_or(pred_seg, gd_seg))
383
+ iou = intersection * 1.0 / union
384
+ if union == 0:
385
+ iou = 0
386
+ return intersection, union, iou
387
+
388
+
389
+ def list_of_strings(arg):
390
+ return [a.strip() for a in arg.split(',')]
391
+
392
+
393
+ # pylint: disable=redefined-outer-name
394
+ def parse_args():
395
+ """Parse arguments."""
396
+ parser = argparse.ArgumentParser(description='Evaluation')
397
+ parser.add_argument(
398
+ '--cfg-path',
399
+ default='configs/refcoco_test_prompt.yaml',
400
+ help='path to configuration file.',
401
+ )
402
+ parser.add_argument('--index', default=0, type=int, help='split task')
403
+ parser.add_argument('--mask_threshold', default=0.0, type=float)
404
+ parser.add_argument('--confidence_threshold', default=0.0, type=float)
405
+ parser.add_argument('--clipes_threshold', default=0.0, type=float)
406
+ parser.add_argument('--stuff_bg_factor', default=0.0, type=float)
407
+ parser.add_argument('--bg_factor', default=0.0, type=float)
408
+ parser.add_argument('--output_path', default=None, type=str)
409
+ parser.add_argument(
410
+ '--visual_prompt_type', default=None, type=list_of_strings
411
+ )
412
+ parser.add_argument(
413
+ '--stuff_visual_prompt_type', default=None, type=list_of_strings
414
+ )
415
+
416
+ args = parser.parse_args()
417
+
418
+ return args
419
+
420
+
421
+ def main(args):
422
+ cfg = Config(**load_yaml(args.cfg_path))
423
+ if args.mask_threshold > 0:
424
+ cfg.car.mask_threshold = args.mask_threshold
425
+ if args.confidence_threshold > 0:
426
+ cfg.car.confidence_threshold = args.confidence_threshold
427
+ if args.clipes_threshold > 0:
428
+ cfg.car.clipes_threshold = args.clipes_threshold
429
+ if args.bg_factor > 0:
430
+ cfg.car.bg_factor = args.bg_factor
431
+ if args.stuff_bg_factor > 0:
432
+ cfg.car.stuff_bg_factor = args.stuff_bg_factor
433
+ if args.output_path is not None:
434
+ cfg.test.output_path = args.output_path
435
+ if args.visual_prompt_type is not None:
436
+ cfg.car.visual_prompt_type = args.visual_prompt_type
437
+ if args.stuff_visual_prompt_type is not None:
438
+ cfg.car.stuff_visual_prompt_type = args.stuff_visual_prompt_type
439
+
440
+ try:
441
+ data_root = cfg.test.data_root
442
+ except (AttributeError, ValueError):  # tolerate a missing optional data_root entry
443
+ data_root = None
444
+
445
+ dataset_test = get_dataset(
446
+ cfg, cfg.test.ds_name, cfg.test.split, get_transform(), data_root
447
+ )
448
+
449
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
450
+
451
+ stuff_label_space = None
452
+ if cfg.test.ds_name == 'voc':
453
+ label_space = VOC_CLASSES
454
+ elif cfg.test.ds_name == 'cocostuff':
455
+ label_space = COCO_OBJECT_CLASSES
456
+ elif cfg.test.ds_name == 'context':
457
+ # label_space = PASCAL_CONTEXT_CLASSES
458
+ label_space = PASCAL_CONTEXT_THING_CLASS
459
+ stuff_label_space = PASCAL_CONTEXT_STUFF_CLASS
460
+ elif cfg.test.ds_name == 'ade':
461
+ label_space = ADE_THING_CLASS
462
+ stuff_label_space = ADE_STUFF_CLASS
463
+ elif cfg.test.ds_name == 'pascal_459':
464
+ label_space = PASCAL_459_THING_CLASS
465
+ stuff_label_space = PASCAL_459_STUFF_CLASS
466
+ elif cfg.test.ds_name == 'ade_847':
467
+ label_space = ADE_847_THING_CLASS
468
+ stuff_label_space = ADE_847_STUFF_CLASS
469
+ else:
470
+ label_space = None
471
+
472
+ num_chunks, chunk_index = 1, 0
473
+ if hasattr(cfg.test, 'num_chunks'):
474
+ num_chunks = cfg.test.num_chunks
475
+ if hasattr(cfg.test, 'chunk_index'):
476
+ chunk_index = cfg.test.chunk_index
477
+ # Size of each chunk
478
+ chunk_size = len(dataset_test) // num_chunks
479
+ # Choose which chunk to load (0-indexed)
480
+ # Define a subset of the dataset
481
+ subset_indices = range(
482
+ chunk_index * chunk_size, (chunk_index + 1) * chunk_size
483
+ )
484
+ subset_dataset = Subset(dataset_test, indices=subset_indices)
485
+
486
+ data_loader_test = torch.utils.data.DataLoader(
487
+ subset_dataset, batch_size=1, shuffle=False, num_workers=1
488
+ )
489
+
490
+ car_model = CaR(cfg, device=device, seg_mode=cfg.test.seg_mode)
491
+
492
+ car_model = car_model.to(device)
493
+
494
+ if not cfg.test.use_pseudo and cfg.test.sam_mask_root is None:
495
+ print('Using sam online')
496
+ # sam_checkpoint, model_type = build_sam_config(cfg)
497
+ build_sam_config(cfg)
498
+
499
+ evaluate(
500
+ data_loader_test,
501
+ cfg,
502
+ car_model,
503
+ test_cfg=cfg.test,
504
+ label_space=label_space,
505
+ stuff_label_space=stuff_label_space,
506
+ )
507
+
508
+
509
+ if __name__ == '__main__':
510
+ args = parse_args()
511
+ main(args)
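For reference, a minimal sketch of driving the evaluation entry point above programmatically; the config path is the argparse default from `parse_args`, the script name in `argv[0]` is a placeholder, and the two threshold overrides are illustrative values rather than recommended settings.

```
# Hedged usage sketch: equivalent to invoking the script from the command line
# with the same flags.
import sys

sys.argv = [
    "evaluate",                                        # placeholder program name
    "--cfg-path", "configs/refcoco_test_prompt.yaml",  # default from parse_args
    "--mask_threshold", "0.6",                         # illustrative override
    "--confidence_threshold", "0.45",                  # illustrative override
]
args = parse_args()
main(args)
```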
modeling/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
modeling/model/cam.py ADDED
@@ -0,0 +1,222 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Get CAM activation."""
17
+
18
+ import cv2
19
+ import numpy as np
20
+ import torch
21
+
22
+
23
+ _EPSILON = 1e-15
24
+
25
+
26
+ def scale_cam_image(cam, target_size=None):
27
+ """Normalize and rescale cam image."""
28
+ result = []
29
+ for img in cam:
30
+ img = img - np.min(img)
31
+ img = img / (_EPSILON + np.max(img))
32
+ if target_size is not None:
33
+ img = cv2.resize(img, target_size)
34
+ result.append(img)
35
+ result = np.float32(result)
36
+
37
+ return result
38
+
39
+
40
+ class ActivationsAndGradients:
41
+ """Class for extracting activations and registering gradients from targetted intermediate layers."""
42
+
43
+ def __init__(self, model, target_layers, reshape_transform, stride=16):
44
+ self.model = model
45
+ self.gradients = []
46
+ self.activations = []
47
+ self.reshape_transform = reshape_transform
48
+ self.handles = []
49
+ self.stride = stride
50
+ for target_layer in target_layers:
51
+ self.handles.append(
52
+ target_layer.register_forward_hook(self.save_activation)
53
+ )
54
+ # Because of https://github.com/pytorch/pytorch/issues/61519,
55
+ # we don't use backward hook to record gradients.
56
+ self.handles.append(
57
+ target_layer.register_forward_hook(self.save_gradient)
58
+ )
59
+
60
+ # pylint: disable=unused-argument
61
+ # pylint: disable=redefined-builtin
62
+ def save_activation(self, module, input, output):
63
+ """Saves activations from targetted layer."""
64
+ activation = output
65
+
66
+ if self.reshape_transform is not None:
67
+ activation = self.reshape_transform(activation, self.height, self.width)
68
+ self.activations.append(activation.cpu().detach())
69
+
70
+ def save_gradient(self, module, input, output):
71
+ if not hasattr(output, "requires_grad") or not output.requires_grad:
72
+ # Hooks can only be registered on tensors that require grad.
73
+ return
74
+
75
+ # Gradients are computed in reverse order
76
+ def _store_grad(grad):
77
+ if self.reshape_transform is not None:
78
+ grad = self.reshape_transform(grad, self.height, self.width)
79
+ self.gradients = [grad.cpu().detach()] + self.gradients
80
+
81
+ output.register_hook(_store_grad)
82
+
83
+ # pylint: enable=unused-argument
84
+ # pylint: enable=redefined-builtin
85
+
86
+ def __call__(self, x, h, w):
87
+ self.height = h // self.stride
88
+ self.width = w // self.stride
89
+ self.gradients = []
90
+ self.activations = []
91
+ if isinstance(x, tuple) or isinstance(x, list):
92
+ return self.model.forward_last_layer(x[0], x[1])
93
+ else:
94
+ return self.model(x)
95
+
96
+ def release(self):
97
+ for handle in self.handles:
98
+ handle.remove()
99
+
100
+
101
+ # pylint: disable=g-bare-generic
102
+ class CAM:
103
+ """CAM module."""
104
+
105
+ def __init__(
106
+ self,
107
+ model,
108
+ target_layers,
109
+ use_cuda=False,
110
+ reshape_transform=None,
111
+ compute_input_gradient=False,
112
+ stride=16,
113
+ ):
114
+ self.model = model.eval()
115
+ self.target_layers = target_layers
116
+ self.cuda = use_cuda
117
+ self.model = model.cuda() if self.cuda else self.model
118
+ self.reshape_transform = reshape_transform
119
+ self.compute_input_gradient = compute_input_gradient
120
+ self.activations_and_grads = ActivationsAndGradients(
121
+ self.model, target_layers, reshape_transform, stride=stride
122
+ )
123
+
124
+ def get_cam(self, activations, grads):
125
+ weights = np.mean(grads, axis=(2, 3))
126
+ weighted_activations = weights[:, :, None, None] * activations
127
+ cam = weighted_activations.sum(axis=1)
128
+ return cam
129
+
130
+ def forward(
131
+ self,
132
+ input_tensor,
133
+ targets,
134
+ target_size,
135
+ ):
136
+ """CAM forward pass."""
137
+ if self.compute_input_gradient:
138
+ input_tensor = torch.autograd.Variable(input_tensor, requires_grad=True)
139
+
140
+ w, h = self.get_target_width_height(input_tensor)
141
+ outputs = self.activations_and_grads(input_tensor, h, w)
142
+
143
+ self.model.zero_grad()
144
+ if isinstance(input_tensor, (tuple, list)):
145
+ loss = sum(
146
+ [target(output[0]) for target, output in zip(targets, outputs)]
147
+ )
148
+ else:
149
+ loss = sum([target(output) for target, output in zip(targets, outputs)])
150
+ loss.backward(retain_graph=True)
151
+ cam_per_layer = self.compute_cam_per_layer(target_size)
152
+ if isinstance(input_tensor, (tuple, list)):
153
+ return (
154
+ self.aggregate_multi_layers(cam_per_layer),
155
+ outputs[0],
156
+ outputs[1],
157
+ )
158
+ else:
159
+ return self.aggregate_multi_layers(cam_per_layer), outputs
160
+
161
+ def get_target_width_height(self, input_tensor):
162
+ width = None
163
+ height = None
164
+ if isinstance(input_tensor, (tuple, list)):
165
+ width, height = input_tensor[-1], input_tensor[-2]
166
+ return width, height
167
+
168
+ def compute_cam_per_layer(self, target_size):
169
+ """Computes cam per target layer."""
170
+ activations_list = [
171
+ a.cpu().data.numpy() for a in self.activations_and_grads.activations
172
+ ]
173
+ grads_list = [
174
+ g.cpu().data.numpy() for g in self.activations_and_grads.gradients
175
+ ]
176
+
177
+ cam_per_target_layer = []
178
+ # Loop over the saliency image from every layer
179
+ for i in range(len(self.target_layers)):
180
+ layer_activations = None
181
+ layer_grads = None
182
+ if i < len(activations_list):
183
+ layer_activations = activations_list[i]
184
+ if i < len(grads_list):
185
+ layer_grads = grads_list[i]
186
+
187
+ cam = self.get_cam(layer_activations, layer_grads)
188
+ cam = np.maximum(cam, 0).astype(np.float32) # float16->32
189
+ scaled = scale_cam_image(cam, target_size)
190
+ cam_per_target_layer.append(scaled[:, None, :])
191
+
192
+ return cam_per_target_layer
193
+
194
+ def aggregate_multi_layers(self, cam_per_target_layer):
195
+ cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
196
+ cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
197
+ result = np.mean(cam_per_target_layer, axis=1)
198
+ return scale_cam_image(result)
199
+
200
+ def __call__(
201
+ self,
202
+ input_tensor,
203
+ targets=None,
204
+ target_size=None,
205
+ ):
206
+ return self.forward(input_tensor, targets, target_size)
207
+
208
+ def __del__(self):
209
+ self.activations_and_grads.release()
210
+
211
+ def __enter__(self):
212
+ return self
213
+
214
+ def __exit__(self, exc_type, exc_value, exc_tb):
215
+ self.activations_and_grads.release()
216
+ if isinstance(exc_value, IndexError):
217
+ # Handle IndexError here...
218
+ print(
219
+ f"An exception occurred in CAM with block: {exc_type}. "
220
+ f"Message: {exc_value}"
221
+ )
222
+ return True
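As a quick illustration of the Grad-CAM style weighting performed by `CAM.get_cam` and `compute_cam_per_layer` above, here is a small self-contained NumPy sketch; the shapes and values are toy inputs, not taken from the model.

```
import numpy as np

activations = np.random.rand(1, 4, 7, 7).astype(np.float32)  # (N, C, H, W) activations
grads = np.random.rand(1, 4, 7, 7).astype(np.float32)        # (N, C, H, W) gradients

weights = np.mean(grads, axis=(2, 3))                        # per-channel weights, (N, C)
cam = (weights[:, :, None, None] * activations).sum(axis=1)  # weighted sum over channels
cam = np.maximum(cam, 0)                                     # ReLU, as in compute_cam_per_layer
print(cam.shape)                                             # (1, 7, 7)
```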
modeling/model/car.py ADDED
@@ -0,0 +1,318 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Implementation of CaR."""
17
+
18
+ import os
19
+
20
+ import clip
21
+ import numpy as np
22
+ import torch
23
+ from torch import nn
24
+ import torch.nn.functional as F
25
+
26
+ # pylint: disable=g-importing-member
27
+ # pylint: disable=g-bad-import-order
28
+ from modeling.model.clip_wrapper import CLIPWrapper
29
+ from modeling.model.clip_wrapper import forward_clip
30
+ from modeling.model.clipcam import CLIPCAM
31
+ from modeling.model.crf import PostProcess
32
+ from modeling.model.utils import apply_visual_prompts
33
+ from utils.visualize import viz_attn
34
+
35
+
36
+ class CaR(nn.Module):
37
+ """CaR module."""
38
+
39
+ def __init__(
40
+ self,
41
+ cfg,
42
+ device="cpu",
43
+ visualize=False,
44
+ confidence_threshold=0.45,
45
+ save_path="save_path",
46
+ seg_mode="refer",
47
+ semantic_clip_model_name=None,
48
+ semantic_pretrained_data=None,
49
+ semantic_templates=None,
50
+ text_template=None,
51
+ visual_prompt_type="circle",
52
+ clipes_threshold=0.4,
53
+ cam_text_template="a clean origami {}.",
54
+ bg_cls=None,
55
+ iom_thres=0.6,
56
+ min_pred_threshold=0.01,
57
+ bg_factor=1.0,
58
+ mask_threshold=0.5,
59
+ ):
60
+ """CaR model for image segmentation.
61
+
62
+ Args:
63
+ cfg: the config file.
64
+ device: the device to run the model.
65
+ visualize: whether to visualize the intermediate results
66
+ confidence_threshold: the confidence threshold for semantic
67
+ segmentation. If the confidence score is lower than the threshold, the
68
+ mask will be discarded.
69
+ save_path: the path to save the intermediate results
70
+ seg_mode: the segmentation mode, can be 'refer' or 'semantic'
71
+ semantic_clip_model_name: the name of the semantic segmentation model.
72
+ semantic_pretrained_data: the path to the pretrained semantic
73
+ segmentation model.
74
+ semantic_templates: the templates for semantic segmentation.
75
+ text_template: the template for visual prompting.
76
+ visual_prompt_type: the type of visual prompting.
77
+ clipes_threshold: the threshold for CLIPES.
78
+ cam_text_template: the template for CAM.
79
+ bg_cls: background classes.
80
+ iom_thres: IoM threshold.
81
+ min_pred_threshold: Prediction threshold.
82
+ bg_factor: Background factor.
83
+ mask_threshold: Mask threshold.
84
+ """
85
+ super(CaR, self).__init__()
86
+ # CLIP parameters
87
+ self.confidence_threshold = confidence_threshold
88
+ self.device = device
89
+ self.visualize = visualize
90
+ self.save_path = save_path
91
+ self.seg_mode = seg_mode
92
+ self.semantic_clip_model_name = semantic_clip_model_name
93
+ self.semantic_pretrained_data = semantic_pretrained_data
94
+ self.semantic_templates = semantic_templates
95
+ self.text_template = text_template
96
+ self.visual_prompt_type = visual_prompt_type
97
+ self.clipes_threshold = clipes_threshold
98
+ self.cam_text_template = cam_text_template
99
+ self.iom_thres = iom_thres
100
+ self.min_pred_threshold = min_pred_threshold
101
+ self.bg_cls = bg_cls
102
+ self.bg_factor = bg_factor
103
+ self.mask_threshold = mask_threshold
104
+
105
+ if not hasattr(cfg, "clip"):
106
+ raise ValueError("The config file should contain the CLIP parameters.")
107
+
108
+ if not hasattr(cfg, "car"):
109
+ raise ValueError("The config file should contain the car parameters.")
110
+
111
+ if hasattr(cfg, "cam"):
112
+ raise ValueError("cfg.cam is deprecated, please use cfg.car ")
113
+
114
+ for k, v in vars(cfg.clip).items():
115
+ setattr(self, k, v)
116
+
117
+ for k, v in vars(cfg.car).items():
118
+ setattr(self, k, v)
119
+
120
+ if hasattr(cfg, "sam"):
121
+ for k, v in vars(cfg.sam).items():
122
+ setattr(self, k, v)
123
+ if not self.bg_cls:
124
+ self.bg_cls = None
125
+ print(f"The model is running on {self.device}")
126
+ self.clip_model, self.preprocess = clip.load(
127
+ self.clip_model_name, device=self.device
128
+ )
129
+ self.clip_model = CLIPWrapper(self.clip_model)
130
+ self.post_process = PostProcess(device=self.device)
131
+ self.mask_generator = CLIPCAM(
132
+ self.clip_model,
133
+ device=self.device,
134
+ text_template=self.text_template,
135
+ threshold=self.clipes_threshold,
136
+ bg_cls=self.bg_cls,
137
+ )
138
+ self.semantic_clip_model, self.semantic_preprocess = clip.load(
139
+ self.semantic_clip_model_name, device=self.device
140
+ )
141
+ self.semantic_clip_model = CLIPWrapper(self.semantic_clip_model)
142
+
143
+ def get_confidence(self, cam_map, binary_cam_map):
144
+ confidence_map = torch.sum(cam_map * binary_cam_map[None], dim=[2, 3])
145
+ confidence_map = confidence_map / torch.sum(binary_cam_map, dim=[1, 2])
146
+ confidence_score = confidence_map.squeeze()
147
+ return confidence_score
148
+
149
+ def set_visual_prompt_type(self, visual_prompt_type):
150
+ self.visual_prompt_type = visual_prompt_type
151
+
152
+ def set_bg_factor(self, bg_factor):
153
+ self.bg_factor = bg_factor
154
+
155
+ def set_confidence_threshold(self, confidence_threshold):
156
+ self.confidence_threshold = confidence_threshold
157
+
158
+ def set_mask_threshold(self, mask_threshold):
159
+ self.mask_threshold = mask_threshold
160
+
161
+ def apply_visual_prompts(self, image, mask):
162
+ if torch.sum(mask).item() <= 1:
163
+ return image
164
+ image_array = np.array(image)
165
+ img_h = image_array.shape[0]
166
+ img_w = image_array.shape[1]
167
+ mask = (
168
+ F.interpolate(mask[None][None], size=(img_h, img_w), mode="nearest")
169
+ .squeeze()
170
+ .detach()
171
+ .cpu()
172
+ .numpy()
173
+ )
174
+ mask = (mask > self.mask_threshold).astype(np.uint8)
175
+ prompted_image = apply_visual_prompts(
176
+ image_array, mask, self.visual_prompt_type, self.visualize
177
+ )
178
+ return prompted_image
179
+
180
+ def get_mask_confidence(self, prompted_images, prompt_text):
181
+ """Get the confidene for each mask with visual prompting."""
182
+ # get the center, width and height of the mask
183
+ prompted_tensor = torch.stack(
184
+ [self.semantic_preprocess(img) for img in prompted_images], dim=0
185
+ )
186
+ prompted_tensor = prompted_tensor.to(self.device)
187
+ h, w = prompted_tensor.shape[-2:]
188
+ text_prediction = forward_clip(
189
+ self.semantic_clip_model, prompted_tensor, prompt_text, h, w
190
+ )
191
+ return text_prediction
192
+
193
+ def _filter_texts(self, ori_mask_id, sem_scores, prompt_text):
194
+ """Remove false positive masks by score filtering and recall the backbone to get the CAM maps for the filtered texts."""
195
+ if not ori_mask_id:
196
+ max_id = np.argmax(sem_scores)
197
+ ori_mask_id.append(max_id)
198
+ filtered_text = [prompt_text[i] for i in ori_mask_id]
199
+ return filtered_text
200
+
201
+ def _forward_stage(self, ori_img, cam_text, clip_text, semantic_prompt_text):
202
+ mask_proposals = self.get_mask_proposals(ori_img, cam_text)
203
+ num_texts = len(cam_text)
204
+ ori_mask_id = []
205
+ sem_scores = torch.zeros((num_texts,), device=self.device).float()
206
+ prompted_imgs = [
207
+ self.apply_visual_prompts(ori_img, cam_map)
208
+ for cam_map in mask_proposals
209
+ ]
210
+ text_scores = self.get_mask_confidence(prompted_imgs, semantic_prompt_text)
211
+ mask_scores = torch.diagonal(text_scores)
212
+ for mask_idx, mask_score in enumerate(mask_scores):
213
+ # record mask idx
214
+ if mask_score > self.confidence_threshold:
215
+ ori_mask_id.append(mask_idx)
216
+ sem_scores[mask_idx] = mask_score
217
+ sem_scores = sem_scores.cpu().detach().numpy()
218
+ filtered_texts = self._filter_texts(ori_mask_id, sem_scores, clip_text)
219
+ # if isinstance(ori_img, list):
220
+ # ori_img = [ori_img[i] for i in ori_mask_id]
221
+
222
+ all_scores = torch.zeros((num_texts,), device=self.device).float()
223
+ sem_scores = torch.from_numpy(sem_scores).to(self.device)
224
+ for new_id, ori_id in enumerate(ori_mask_id):
225
+ if new_id >= len(mask_proposals):
226
+ # the mask is filtered out.
227
+ continue
228
+ all_scores[ori_id] = sem_scores[ori_id]
229
+ return filtered_texts, all_scores, mask_proposals
230
+
231
+ def _get_save_path(self, text):
232
+ folder_name = "_".join([t.replace(" ", "_") for t in text])
233
+ if len(folder_name) > 20:
234
+ folder_name = folder_name[:20]
235
+ output_path = os.path.join(self.save_path, folder_name)
236
+ sub_output_path = [
237
+ os.path.join(output_path, t.replace(" ", "_")) for t in text
238
+ ]
239
+ return output_path, sub_output_path
240
+
241
+ def get_mask_proposals(self, img, text):
242
+ if self.seg_mode == "refer":
243
+ if isinstance(img, list):
244
+ cam_map_list = [self.mask_generator(i, t)[0] for i, t in zip(img, text)]
245
+ else:
246
+ cam_map_list = [self.mask_generator(img, t)[0] for t in text]
247
+ return torch.cat(cam_map_list, dim=0)
248
+ elif self.seg_mode == "semantic":
249
+ return self.mask_generator(img, text)[0]
250
+ else:
251
+ raise ValueError(
252
+ "Unknown segmentation mode. Only refer and semantic segmentation are"
253
+ " supported."
254
+ )
255
+
256
+ def _forward_car(self, ori_img, text):
257
+ if isinstance(text, str):
258
+ text = [text]
259
+ _, sub_output_path = self._get_save_path(text)
260
+ image_array = np.array(ori_img)
261
+ clip_text = [self.cam_text_template.format(t) for t in text]
262
+ cam_text = text
263
+ init_clip_text = clip_text # keep the initial prompts; clip_text is filtered across refinement iterations.
264
+ semantic_prompt_text = clip_text
265
+ # Apply semantic prompting augmentation.
266
+ if self.semantic_templates is not None:
267
+ semantic_prompt_text = []
268
+ for template in self.semantic_templates:
269
+ templated_text = [template.format(t) for t in text]
270
+ semantic_prompt_text.append(templated_text)
271
+
272
+ num_positive_last = 0
273
+ run = 0
274
+ while True:
275
+ run += 1
276
+ cur_texts, all_scores, mask_proposals = self._forward_stage(
277
+ ori_img, cam_text, clip_text, semantic_prompt_text
278
+ )
279
+ if cur_texts: # if there is no text, skip the refinement
280
+ cam_text = cur_texts
281
+ clip_text = cur_texts
282
+
283
+ num_positive = (all_scores > 0).sum().item()
284
+ if num_positive == num_positive_last:
285
+ # stop the refinement if the number of positive masks
286
+ # does not change.
287
+ break
288
+ num_positive_last = num_positive
289
+ # Apply densecrf for refinement.
290
+ # SAM is optional and is applied outside the model.
291
+ refined_masks = self.post_process(
292
+ ori_img,
293
+ mask_proposals,
294
+ separate=self.seg_mode == "refer",
295
+ bg_factor=self.bg_factor,
296
+ )
297
+ predicted_class_idx = [init_clip_text.index(t) for t in cur_texts]
298
+ if self.visualize:
299
+ _ = [
300
+ viz_attn(
301
+ image_array,
302
+ attn,
303
+ prefix=sub_output_path[aid],
304
+ img_name="semantic_mask",
305
+ )
306
+ for aid, attn in enumerate(refined_masks)
307
+ ]
308
+ final_predicted_masks = torch.zeros(len(text), *refined_masks[0].shape)
309
+ final_all_scores = torch.zeros(len(text))
310
+ for idx, mask, score in zip(predicted_class_idx, refined_masks, all_scores):
311
+ final_predicted_masks[idx] = mask
312
+ final_all_scores[idx] = score
313
+ return final_predicted_masks, final_all_scores
314
+
315
+ def forward(self, im_ori, text):
316
+ # im_ori is the input image (resized to 512 x 512 in evaluation); text is a string or a list of strings.
317
+ pseudo_masks, conf_scores = self._forward_car(im_ori, text)
318
+ return pseudo_masks, conf_scores
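A minimal usage sketch of the `CaR` module defined above; the config path and image path are placeholders, and the config is assumed to contain the `clip` and `car` sections the constructor checks for.

```
import torch
from PIL import Image

from modeling.model.car import CaR
from utils.utils import Config, load_yaml

cfg = Config(**load_yaml("path/to/config.yaml"))  # placeholder config with clip/car sections
device = "cuda" if torch.cuda.is_available() else "cpu"
car_model = CaR(cfg, device=device, seg_mode="refer").to(device)

image = Image.open("example.jpg").convert("RGB")  # placeholder image
masks, scores = car_model(image, ["a dog wearing a red collar"])
print(masks.shape, scores)  # one refined mask and one confidence score per query
```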
modeling/model/clip_wrapper.py ADDED
@@ -0,0 +1,297 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """A wrapper for CLIP model to support forward with a list of text inputs."""
17
+
18
+ # pylint: disable=g-importing-member
19
+ import clip
20
+ import numpy as np
21
+ import torch
22
+ from torch import nn
23
+ import torch.nn.functional as F
24
+
25
+ _CONTEXT_LENGTH = 77
26
+
27
+
28
+ def forward_clip_single(model, image, text, h, w):
29
+ """Forward a single text input.
30
+
31
+ Args:
32
+ model (CLIPWrapper or CLIP): the CLIP model.
33
+ image (torch.Tensor): the image tensor.
34
+ text (List[str]): the text input.
35
+ h (int): the height of the image.
36
+ w (int): the width of the image.
37
+
38
+ Returns:
39
+ torch.Tensor: the logits.
40
+ """
41
+ if isinstance(text, str):
42
+ text = [text]
43
+ text_tokens = clip.tokenize(text).to(image.device)
44
+ text_prediction = model(image, text_tokens, h, w)
45
+ return text_prediction.detach().cpu()
46
+
47
+
48
+ def forward_clip(model, image, text, h, w):
49
+ """Forward a list of text inputs.
50
+
51
+ Args:
52
+ model (CLIPWrapper or CLIP): the CLIP model.
53
+ image (torch.Tensor): the image tensor.
54
+ text (List[str] or List[List[str]]): the text input.
55
+ h (int): the height of the image.
56
+ w (int): the width of the image.
57
+
58
+ Returns:
59
+ torch.Tensor: the logits.
60
+ """
61
+ if isinstance(text[0], list):
62
+ text_prediction = torch.stack(
63
+ [forward_clip_single(model, image, t, h, w) for t in text], dim=0
64
+ )
65
+ text_prediction = torch.sum(text_prediction, dim=0)
66
+ text_prediction = F.softmax(text_prediction.float(), dim=-1)
67
+ else:
68
+ text_prediction = forward_clip_single(model, image, text, h, w)
69
+ return text_prediction.float()
70
+
71
+
72
+ def upsample_position_embedding(embed, new_size):
73
+ """Upsample the pretrained embedding to a higher resolution.
74
+
75
+ Args:
76
+ embed (torch.Tensor): the pretrained embedding.
77
+ new_size (Tuple[int, int]): the new size of the embedding.
78
+
79
+ Returns:
80
+ torch.Tensor: the upsampled embedding.
81
+ """
82
+ # emb size NxD
83
+ first = embed[:1, :]
84
+ embed = embed[1:, :]
85
+ n = embed.size(0)
86
+ d = embed.size(1)
87
+ size = int(np.sqrt(n))
88
+ if size * size != n:
89
+ raise ValueError(f'The size of embed {n} is not a perfect square number.')
90
+ # new_size = size * self.upsample
91
+ embed = embed.permute(1, 0)
92
+ embed = embed.view(1, d, size, size).contiguous()
93
+ embed = F.interpolate(  # F.upsample is deprecated; interpolate is the drop-in replacement
94
+ embed,
95
+ size=new_size,
96
+ mode='bilinear',
97
+ )
98
+ embed = embed.view(d, -1).contiguous()
99
+ embed = embed.permute(1, 0)
100
+ embed = torch.cat([first, embed], 0)
101
+ embed = nn.parameter.Parameter(embed.half())
102
+ return embed
103
+
104
+
105
+ class CustomBlock(nn.Module):
106
+ """A customized attention block."""
107
+
108
+ def __init__(self, block):
109
+ super().__init__()
110
+ for k, v in vars(block).items():
111
+ setattr(self, k, v)
112
+
113
+ def attention(self, x):
114
+ self.attn_mask = (
115
+ self.attn_mask.to(dtype=x.dtype, device=x.device)
116
+ if self.attn_mask is not None
117
+ else None
118
+ )
119
+ self.attn = self.attn.to(dtype=x.dtype, device=x.device)
120
+ # Setting need_weights to True also returns the attention weights
121
+ return self.attn(x, x, x, need_weights=True, attn_mask=self.attn_mask)
122
+
123
+ def forward(self, x):
124
+ # attn_output: (L,N,E), attn_weight: (N,L,L)
125
+ attn_output, attn_weight = self.attention(self.ln_1(x))
126
+ x = x + attn_output
127
+ x = x + self.mlp(self.ln_2(x))
128
+ return x, attn_weight
129
+
130
+
131
+ class CustomTransformer(nn.Module):
132
+ """A customized Transformer to support CAM calculation."""
133
+
134
+ def __init__(self, transformer):
135
+ """Initialize the wrapper.
136
+
137
+ Args:
138
+ transformer (nn.Module): the Transformer to be wrapped.
139
+ """
140
+ super().__init__()
141
+ for k, v in vars(transformer).items():
142
+ setattr(self, k, v)
143
+
144
+ self.resblocks = nn.Sequential(
145
+ *[CustomBlock(block) for block in self.resblocks]
146
+ )
147
+
148
+ def forward(self, x):
149
+ attn_weights = []
150
+ with torch.no_grad():
151
+ layers = self.layers if x.shape[0] == _CONTEXT_LENGTH else self.layers - 1
152
+ for i in range(layers):
153
+ x, attn_weight = self.resblocks[i](x)
154
+ attn_weights.append(attn_weight)
155
+ return x, attn_weights
156
+
157
+
158
+ class CustomVisionTransformer(nn.Module):
159
+ """A customized VisionTransformer to support CAM calculation."""
160
+
161
+ def __init__(self, model):
162
+ """Initialize the wrapper.
163
+
164
+ Args:
165
+ model (VisionTransformer): the VisionTransformer to be wrapped.
166
+ """
167
+ super().__init__()
168
+ for k, v in vars(model).items():
169
+ setattr(self, k, v)
170
+ self.patch_size = self.conv1.kernel_size[0]
171
+ self.transformer = CustomTransformer(self.transformer)
172
+
173
+ def forward(self, x, h, w):
174
+ self.positional_embedding_new = upsample_position_embedding(
175
+ self.positional_embedding, (h // self.patch_size, w // self.patch_size)
176
+ )
177
+ # shape = [*, width, grid, grid]
178
+ x = self.conv1(x)
179
+ # shape = [*, width, grid ** 2]
180
+ x = x.reshape(x.shape[0], x.shape[1], -1)
181
+ # shape = [*, grid ** 2, width]
182
+ x = x.permute(0, 2, 1)
183
+ zeros = torch.zeros(
184
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
185
+ )
186
+ # shape = [*, grid ** 2 + 1, width]
187
+ x = torch.cat([self.class_embedding.to(x.dtype) + zeros, x], dim=1)
188
+ x = x + self.positional_embedding_new.to(x.dtype)
189
+ x = self.ln_pre(x)
190
+ # NLD -> LND
191
+ x = x.permute(1, 0, 2)
192
+ x, attn_weight = self.transformer(x)
193
+ return x, attn_weight
194
+
195
+
196
+ class CLIPWrapper(nn.Module):
197
+ """A wrapper for CLIP to support forward with a list of text inputs."""
198
+
199
+ def __init__(self, clip_model):
200
+ """Initialize the wrapper.
201
+
202
+ Args:
203
+ clip_model (CLIP): the CLIP model to be wrapped.
204
+ """
205
+ super().__init__()
206
+ # copy all attributes from clip_model to self
207
+ for k, v in vars(clip_model).items():
208
+ setattr(self, k, v)
209
+ self.visual = CustomVisionTransformer(self.visual)
210
+ self.transformer = CustomTransformer(self.transformer)
211
+
212
+ @property
213
+ def dtype(self):
214
+ return self.visual.conv1.weight.dtype
215
+
216
+ def encode_image(self, image, h, w):
217
+ return self.visual(image.type(self.dtype), h, w)
218
+
219
+ def encode_text(self, text):
220
+ x = self.token_embedding(text).type(
221
+ self.dtype
222
+ ) # [batch_size, n_ctx, d_model]
223
+
224
+ x = x + self.positional_embedding.type(self.dtype)
225
+ x = x.permute(1, 0, 2) # NLD -> LND
226
+ x, _ = self.transformer(x)
227
+ x = x.permute(1, 0, 2) # LND -> NLD
228
+ x = self.ln_final(x).type(self.dtype)
229
+
230
+ # x.shape = [batch_size, n_ctx, transformer.width]
231
+ # take features from the eot embedding
232
+ # (eot_token is the highest number in each sequence)
233
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
234
+
235
+ return x
236
+
237
+ def pool_visual(self, x, use_cls_token=False):
238
+ if use_cls_token:
239
+ return x[:, 0]
240
+ else:
241
+ return torch.mean(x[:, 1:, :], dim=1)
242
+
243
+ def forward_last_layer(
244
+ self, image_features, text_features, use_cls_token=False, repeat_last=True
245
+ ):
246
+ """Forward the last layer of CLIP.
247
+
248
+ Args:
249
+ image_features (torch.Tensor): the image features.
250
+ text_features (torch.Tensor): the text features.
251
+ use_cls_token (bool, optional): whether to use the CLS token. Defaults
252
+ to False.
253
+ repeat_last (bool, optional): whether to repeat the last layer. Defaults
254
+ to True.
255
+
256
+ Returns:
257
+ torch.Tensor: the logits.
258
+ torch.Tensor: the attention weights.
259
+ """
260
+ if repeat_last:
261
+ x, attention_weight = self.visual.transformer.resblocks[
262
+ self.visual.transformer.layers - 1
263
+ ](image_features)
264
+ else:
265
+ x = image_features
266
+ attention_weight = None
267
+ x = x.permute(1, 0, 2) # LND -> NLD
268
+
269
+ x = self.visual.ln_post(x)
270
+ x = self.pool_visual(x, use_cls_token=use_cls_token)
271
+
272
+ if self.visual.proj is not None:
273
+ x = x @ self.visual.proj
274
+
275
+ image_features = x
276
+
277
+ # normalized features
278
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
279
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
280
+ # cosine similarity as logits
281
+ logit_scale = self.logit_scale.exp()
282
+ logits_per_image = logit_scale * image_features @ text_features.t()
283
+
284
+ # shape = [global_batch_size, global_batch_size]
285
+ logits_per_image = F.softmax(logits_per_image.float(), dim=-1)
286
+
287
+ return logits_per_image, attention_weight
288
+
289
+ def forward(self, image, text, h=224, w=224):
290
+ with torch.no_grad():
291
+ text_features = self.encode_text(text)
292
+ feature_map, _ = self.visual(image.type(self.dtype), h, w)
293
+
294
+ logits_per_image, _ = self.forward_last_layer(
295
+ feature_map, text_features, use_cls_token=True, repeat_last=False
296
+ )
297
+ return logits_per_image
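A minimal sketch of scoring an image against a set of class prompts with `CLIPWrapper` and `forward_clip` above; the CLIP weights name and image path are assumptions, and the two prompt templates are only examples.

```
import clip
import torch
from PIL import Image

from modeling.model.clip_wrapper import CLIPWrapper, forward_clip

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/16", device=device)  # assumed backbone
model = CLIPWrapper(clip_model)

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # placeholder image
h, w = image.shape[-2:]
class_names = ["cat", "dog"]
# A list of template lists triggers the prompt-ensembling branch of forward_clip.
prompts = [[t.format(c) for c in class_names]
           for t in ("a photo of a {}.", "a clean origami {}.")]
probs = forward_clip(model, image, prompts, h, w)
print(probs)  # softmax scores over the class names
```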
modeling/model/clipcam.py ADDED
@@ -0,0 +1,255 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Calculate CAM with CLIP model."""
17
+
18
+ import warnings
19
+
20
+ import clip
21
+ import cv2
22
+ import numpy as np
23
+ import torch
24
+
25
+ # pylint: disable=g-importing-member
26
+ # pylint: disable=g-bad-import-order
27
+ from modeling.model.cam import CAM
28
+ from modeling.model.cam import scale_cam_image
29
+ from modeling.model.utils import img_ms_and_flip
30
+ from modeling.model.utils import reshape_transform
31
+ from modeling.model.utils import scoremap2bbox
32
+
33
+ warnings.filterwarnings("ignore")
34
+
35
+
36
+ class ClipOutputTarget:
37
+
38
+ def __init__(self, category):
39
+ self.category = category
40
+
41
+ def __call__(self, model_output):
42
+ if len(model_output.shape) == 1:
43
+ return model_output[self.category]
44
+ return model_output[:, self.category]
45
+
46
+
47
+ def zeroshot_classifier(classnames, templates, model, device):
48
+ """Zeroshot classifier."""
49
+ with torch.no_grad():
50
+ zeroshot_weights = []
51
+ for classname in classnames:
52
+ if templates is None:
53
+ texts = [classname]
54
+ else:
55
+ # format with class
56
+ texts = [template.format(classname) for template in templates]
57
+ texts = clip.tokenize(texts).to(device) # tokenize
58
+ class_embeddings = model.encode_text(texts) # embed with text encoder
59
+ class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
60
+ class_embedding = class_embeddings.mean(dim=0)
61
+ class_embedding /= class_embedding.norm()
62
+ zeroshot_weights.append(class_embedding)
63
+ zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
64
+ return zeroshot_weights.t()
65
+
66
+
67
+ class CLIPCAM:
68
+ """Generate CAM with CLIP model."""
69
+
70
+ def __init__(
71
+ self,
72
+ clip_model,
73
+ device,
74
+ text_template=None,
75
+ threshold=0.4,
76
+ bg_cls=None,
77
+ ):
78
+ self.device = device
79
+ self.clip_model = clip_model.to(device)
80
+ self.text_template = text_template
81
+ self.threshold = threshold
82
+ self.stride = self.clip_model.visual.patch_size
83
+
84
+ # if self.dataset_name == 'voc' else BACKGROUND_CATEGORY_COCO
85
+ self.bg_cls = bg_cls
86
+ self.bg_text_features = None
87
+ if self.bg_cls is not None:
88
+ self.bg_text_features = zeroshot_classifier(
89
+ self.bg_cls,
90
+ ("a clean origami {}.",),
91
+ self.clip_model,
92
+ self.device,
93
+ ).to(self.device)
94
+ self.target_layers = [self.clip_model.visual.transformer.resblocks[-1].ln_1]
95
+ self.cam = CAM(
96
+ model=self.clip_model,
97
+ target_layers=self.target_layers,
98
+ reshape_transform=reshape_transform,
99
+ use_cuda="cuda" in device,
100
+ stride=self.stride,
101
+ )
102
+
103
+ def set_bg_cls(self, bg_cls):
104
+ # if len(bg_cls) == 0:
105
+ if not bg_cls:
106
+ self.bg_cls = None
107
+ self.bg_text_features = None
108
+ else:
109
+ self.bg_cls = bg_cls
110
+ self.bg_text_features = zeroshot_classifier(
111
+ self.bg_cls,
112
+ ("a clean origami {}.",),
113
+ self.clip_model,
114
+ self.device,
115
+ ).to(self.device)
116
+
117
+ def __call__(self, ori_img, text, scale=1.0):
118
+ """Get CAM masks and features.
119
+
120
+ Args:
121
+ ori_img(Image): image to be searched.
122
+ text (str): text to be searched.
123
+ scale (float): image scale.
124
+ Returns:
125
+ CAM masks and features.
126
+ """
127
+ ori_width = ori_img.size[0]
128
+ ori_height = ori_img.size[1]
129
+ if isinstance(text, str):
130
+ text = [text]
131
+
132
+ # convert image to bgr channel
133
+ ms_imgs = img_ms_and_flip(ori_img, ori_height, ori_width, scales=[scale])
134
+ image = ms_imgs[0]
135
+
136
+ image = image.unsqueeze(0)
137
+ h, w = image.shape[-2], image.shape[-1]
138
+ image = image.to(self.device)
139
+ image_features, attn_weight_list = self.clip_model.encode_image(image, h, w)
140
+
141
+ highres_cam_to_save = []
142
+ refined_cam_to_save = []
143
+ # keys = []
144
+
145
+ # [bg_id_for_each_image[im_idx]].to(device_id)
146
+ bg_features_temp = None
147
+ if self.bg_text_features is not None:
148
+ bg_features_temp = self.bg_text_features.to(self.device)
149
+ fg_features_temp = zeroshot_classifier(
150
+ text, self.text_template, self.clip_model, self.device
151
+ ).to(self.device)
152
+ if bg_features_temp is None:
153
+ text_features_temp = fg_features_temp
154
+ else:
155
+ text_features_temp = torch.cat(
156
+ [fg_features_temp, bg_features_temp], dim=0
157
+ )
158
+ input_tensor = [
159
+ image_features,
160
+ text_features_temp.to(self.device),
161
+ h,
162
+ w,
163
+ ]
164
+
165
+ # for idx, label in enumerate(label_list):
166
+ # keys.append(new_class_names.index(label))
167
+ for idx, _ in enumerate(text):
168
+ targets = [ClipOutputTarget(idx)]
169
+
170
+ # torch.cuda.empty_cache()
171
+ grayscale_cam, _, attn_weight_last = self.cam(
172
+ input_tensor=input_tensor, targets=targets, target_size=None
173
+ ) # (ori_width, ori_height))
174
+
175
+ grayscale_cam = grayscale_cam[0, :]
176
+ if grayscale_cam.max() == 0:
177
+ input_tensor_fg = (
178
+ image_features,
179
+ fg_features_temp.to(self.device),
180
+ h,
181
+ w,
182
+ )
183
+ grayscale_cam, _, attn_weight_last = self.cam(
184
+ input_tensor=input_tensor_fg,
185
+ targets=targets,
186
+ target_size=None,
187
+ )
188
+ grayscale_cam = grayscale_cam[0, :]
189
+
190
+ grayscale_cam_highres = cv2.resize(grayscale_cam, (ori_width, ori_height))
191
+ highres_cam_to_save.append(torch.tensor(grayscale_cam_highres))
192
+
193
+ if idx == 0:
194
+ attn_weight_list.append(attn_weight_last)
195
+ attn_weight = [
196
+ aw[:, 1:, 1:] for aw in attn_weight_list
197
+ ] # (b, hxw, hxw)
198
+ attn_weight = torch.stack(attn_weight, dim=0)[-8:]
199
+ attn_weight = torch.mean(attn_weight, dim=0)
200
+ attn_weight = attn_weight[0].cpu().detach()
201
+ attn_weight = attn_weight.float()
202
+
203
+ box, cnt = scoremap2bbox(
204
+ scoremap=grayscale_cam,
205
+ threshold=self.threshold,
206
+ multi_contour_eval=True,
207
+ )
208
+ aff_mask = torch.zeros((grayscale_cam.shape[0], grayscale_cam.shape[1]))
209
+ for i_ in range(cnt):
210
+ x0_, y0_, x1_, y1_ = box[i_]
211
+ aff_mask[y0_:y1_, x0_:x1_] = 1
212
+
213
+ aff_mask = aff_mask.view(
214
+ 1, grayscale_cam.shape[0] * grayscale_cam.shape[1]
215
+ )
216
+ aff_mat = attn_weight
217
+
218
+ trans_mat = aff_mat / torch.sum(aff_mat, dim=0, keepdim=True)
219
+ trans_mat = trans_mat / torch.sum(trans_mat, dim=1, keepdim=True)
220
+
221
+ for _ in range(2):
222
+ trans_mat = trans_mat / torch.sum(trans_mat, dim=0, keepdim=True)
223
+ trans_mat = trans_mat / torch.sum(trans_mat, dim=1, keepdim=True)
224
+ trans_mat = (trans_mat + trans_mat.transpose(1, 0)) / 2
225
+
226
+ # This is copied from CLIP-ES
227
+ for _ in range(1):
228
+ trans_mat = torch.matmul(trans_mat, trans_mat)
229
+
230
+ trans_mat = trans_mat * aff_mask
231
+
232
+ cam_to_refine = torch.FloatTensor(grayscale_cam)
233
+ cam_to_refine = cam_to_refine.view(-1, 1)
234
+
235
+ # (n,n) * (n,1)->(n,1)
236
+ cam_refined = torch.matmul(trans_mat, cam_to_refine).reshape(
237
+ h // self.stride, w // self.stride
238
+ )
239
+ cam_refined = cam_refined.cpu().numpy().astype(np.float32)
240
+ cam_refined_highres = scale_cam_image(
241
+ [cam_refined], (ori_width, ori_height)
242
+ )[0]
243
+ refined_cam_to_save.append(torch.tensor(cam_refined_highres))
244
+
245
+ # post process the cam map
246
+ # label = process(raw_image, refined_cam, postprocessor)
247
+ # vis_img = vis_mask(np.asarray(raw_image), label, [0, 255, 0])
248
+ # vis_img.save(f'clip_es_crf_{idx}.jpg')
249
+
250
+ # keys = torch.tensor(keys)
251
+ # cam_all_scales.append(torch.stack(cam_to_save,dim=0))
252
+
253
+ cam_masks = torch.stack(refined_cam_to_save, dim=0)
254
+
255
+ return cam_masks.to(self.device), fg_features_temp.to(self.device)
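A minimal sketch of generating CAM-based mask proposals with `CLIPCAM` above, mirroring how `CaR` builds its mask generator; the CLIP weights name and image path are assumptions, and background suppression is disabled for brevity.

```
import clip
import torch
from PIL import Image

from modeling.model.clip_wrapper import CLIPWrapper
from modeling.model.clipcam import CLIPCAM

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, _ = clip.load("ViT-B/16", device=device)  # assumed backbone
cam_generator = CLIPCAM(
    CLIPWrapper(clip_model),
    device=device,
    text_template=("a clean origami {}.",),  # template tuple, as used elsewhere above
    threshold=0.4,
    bg_cls=None,  # no background classes in this sketch
)

image = Image.open("example.jpg").convert("RGB")  # placeholder image
cam_masks, fg_text_features = cam_generator(image, ["cat", "dog"])
print(cam_masks.shape)  # one refined CAM per query, at the original image resolution
```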
modeling/model/crf.py ADDED
@@ -0,0 +1,113 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """DenseCRF."""
17
+
18
+ import numpy as np
19
+ from pydensecrf import densecrf as dcrf
20
+ from pydensecrf import utils
21
+ import torch
22
+ import torch.nn.functional as F
23
+
24
+
25
+ class DenseCRF(object):
26
+ """DenseCRF class."""
27
+
28
+ def __init__(self, iter_max, pos_w, pos_xy_std, bi_w, bi_xy_std, bi_rgb_std):
29
+ self.iter_max = iter_max
30
+ self.pos_w = pos_w
31
+ self.pos_xy_std = pos_xy_std
32
+ self.bi_w = bi_w
33
+ self.bi_xy_std = bi_xy_std
34
+ self.bi_rgb_std = bi_rgb_std
35
+
36
+ def __call__(self, image, probmap):
37
+ c, h, w = probmap.shape
38
+
39
+ u = utils.unary_from_softmax(probmap)
40
+ u = np.ascontiguousarray(u)
41
+
42
+ image = np.ascontiguousarray(image)
43
+
44
+ d = dcrf.DenseCRF2D(w, h, c)
45
+ d.setUnaryEnergy(u)
46
+ d.addPairwiseGaussian(sxy=self.pos_xy_std, compat=self.pos_w)
47
+ d.addPairwiseBilateral(
48
+ sxy=self.bi_xy_std,
49
+ srgb=self.bi_rgb_std,
50
+ rgbim=image,
51
+ compat=self.bi_w,
52
+ )
53
+
54
+ q = d.inference(self.iter_max)
55
+ q = np.array(q).reshape((c, h, w))
56
+
57
+ return q
58
+
59
+
60
+ class PostProcess:
61
+ """Post processing with dense CRF."""
62
+
63
+ def __init__(self, device):
64
+ self.device = device
65
+ self.postprocessor = DenseCRF(
66
+ iter_max=10,
67
+ pos_xy_std=1,
68
+ pos_w=3,
69
+ bi_xy_std=67,
70
+ bi_rgb_std=3,
71
+ bi_w=4,
72
+ )
73
+
74
+ def apply_crf(self, image, cams, bg_factor=1.0):
75
+ """Apply dense CRF."""
76
+ bg_score = np.power(1 - np.max(cams, axis=0, keepdims=True), bg_factor)
77
+ cams = np.concatenate((bg_score, cams), axis=0)
78
+ prob = cams
79
+
80
+ image = image.astype(np.uint8).transpose(1, 2, 0)
81
+ prob = self.postprocessor(image, prob)
82
+
83
+ label = np.argmax(prob, axis=0)
84
+
85
+ label_tensor = torch.from_numpy(label).long()
86
+ refined_mask = F.one_hot(label_tensor).to(device=self.device)
87
+ refined_mask = refined_mask.permute(2, 0, 1)
88
+ refined_mask = refined_mask[1:].float()
89
+ return refined_mask
90
+
91
+ def __call__(self, image, cams, separate=False, bg_factor=1.0):
92
+ mean_bgr = (104.008, 116.669, 122.675)
93
+ # covert Image to numpy array
94
+ image = np.array(image).astype(np.float32)
95
+
96
+ # RGB -> BGR
97
+ image = image[:, :, ::-1]
98
+ # Mean subtraction
99
+ image -= mean_bgr
100
+ # HWC -> CHW
101
+ image = image.transpose(2, 0, 1)
102
+
103
+ if isinstance(cams, torch.Tensor):
104
+ cams = cams.cpu().detach().numpy()
105
+ if separate:
106
+ refined_mask = [
107
+ self.apply_crf(image, cam[None], bg_factor) for cam in cams
108
+ ]
109
+ refined_mask = torch.cat(refined_mask, dim=0)
110
+ else:
111
+ refined_mask = self.apply_crf(image, cams, bg_factor)
112
+
113
+ return refined_mask
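
A minimal sketch of how the `PostProcess` CRF refiner above can be driven; the image path and the random CAM tensor are placeholders for illustration only, and `pydensecrf` must be installed:

```python
# Sketch only: refine a stack of fake CAMs for one image with DenseCRF.
# "example.jpg" and the random CAMs are placeholders, not repo assets.
import torch
from PIL import Image

from modeling.model.crf import PostProcess

post_processor = PostProcess(device="cpu")
image = Image.open("example.jpg").convert("RGB")
width, height = image.size
cams = 0.8 * torch.rand(3, height, width) + 0.1  # fake CAMs kept strictly inside (0, 1)

# Returns a [num_classes, H, W] float tensor of CRF-refined foreground masks.
refined = post_processor(image, cams, separate=False, bg_factor=1.0)
print(refined.shape)
```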
modeling/model/utils.py ADDED
@@ -0,0 +1,245 @@
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """CAM utils."""
17
+
18
+ # pylint: disable=g-importing-member
19
+ import os
20
+
21
+ import cv2
22
+ import numpy as np
23
+ from PIL import Image
24
+ from scipy.ndimage import binary_fill_holes
25
+ import torch
26
+ from torchvision.transforms import Compose
27
+ from torchvision.transforms import Normalize
28
+ from torchvision.transforms import Resize
29
+ from torchvision.transforms import ToTensor
30
+
31
+ # pylint: disable=g-import-not-at-top
32
+ try:
33
+ from torchvision.transforms import InterpolationMode
34
+
35
+ BICUBIC = InterpolationMode.BICUBIC
36
+ except ImportError:
37
+ BICUBIC = Image.BICUBIC
38
+
39
+ _CONTOUR_INDEX = 1 if cv2.__version__.split('.')[0] == '3' else 0
40
+
41
+
42
+ def _convert_image_to_rgb(image):
43
+ return image.convert('RGB')
44
+
45
+
46
+ def _transform_resize(h, w):
47
+ return Compose([
48
+ Resize((h, w), interpolation=BICUBIC),
49
+ _convert_image_to_rgb,
50
+ ToTensor(),
51
+ Normalize(
52
+ (0.48145466, 0.4578275, 0.40821073),
53
+ (0.26862954, 0.26130258, 0.27577711),
54
+ ),
55
+ ])
56
+
57
+
58
+ def img_ms_and_flip(image, ori_height, ori_width, scales=1.0, patch_size=16):
59
+ """Resizes and flips the image."""
60
+ if isinstance(scales, float):
61
+ scales = [scales]
62
+
63
+ all_imgs = []
64
+ for scale in scales:
65
+ preprocess = _transform_resize(
66
+ int(np.ceil(scale * int(ori_height) / patch_size) * patch_size),
67
+ int(np.ceil(scale * int(ori_width) / patch_size) * patch_size),
68
+ )
69
+ image = preprocess(image)
70
+ image_ori = image
71
+ image_flip = torch.flip(image, [-1])
72
+ all_imgs.append(image_ori)
73
+ all_imgs.append(image_flip)
74
+ return all_imgs
75
+
76
+
77
+ def reshape_transform(tensor, height=28, width=28):
78
+ tensor = tensor.permute(1, 0, 2)
79
+ result = tensor[:, 1:, :].reshape(
80
+ tensor.size(0), height, width, tensor.size(2)
81
+ )
82
+
83
+ # Bring the channels to the first dimension, like in CNNs.
84
+ result = result.transpose(2, 3).transpose(1, 2)
85
+ return result
86
+
87
+
88
+ def vis_mask(image, mask, mask_color):
89
+ # switch the height and width of image
90
+ # image = image.transpose(1, 0, 2)
91
+ if mask.shape[0] != image.shape[0] or mask.shape[1] != image.shape[1]:
92
+ mask = cv2.resize(mask, (image.shape[1], image.shape[0]))
93
+ fg = mask > 0.5
94
+ rgb = np.copy(image)
95
+ rgb[fg] = (rgb[fg] * 0.3 + np.array(mask_color) * 0.7).astype(np.uint8)
96
+ return Image.fromarray(rgb)
97
+
98
+
99
+ def scoremap2bbox(scoremap, threshold, multi_contour_eval=False):
100
+ """Get bounding boxes from scoremap."""
101
+ height, width = scoremap.shape
102
+ scoremap_image = np.expand_dims((scoremap * 255).astype(np.uint8), 2)
103
+ while True:
104
+ _, thr_gray_heatmap = cv2.threshold(
105
+ src=scoremap_image,
106
+ thresh=int(threshold * np.max(scoremap_image)),
107
+ maxval=255,
108
+ type=cv2.THRESH_BINARY,
109
+ )
110
+ if thr_gray_heatmap.max() > 0 or threshold <= 0:
111
+ break
112
+ threshold -= 0.1
113
+ contours = cv2.findContours(
114
+ image=thr_gray_heatmap, mode=cv2.RETR_TREE, method=cv2.CHAIN_APPROX_SIMPLE
115
+ )[_CONTOUR_INDEX]
116
+
117
+ # if len(contours) == 0:
118
+ if not contours:
119
+ return np.asarray([[0, 0, 0, 0]]), 1
120
+
121
+ if not multi_contour_eval:
122
+ contours = [max(contours, key=cv2.contourArea)]
123
+
124
+ estimated_boxes = []
125
+ for contour in contours:
126
+ x, y, w, h = cv2.boundingRect(contour)
127
+ x0, y0, x1, y1 = x, y, x + w, y + h
128
+ x1 = min(x1, width - 1)
129
+ y1 = min(y1, height - 1)
130
+ estimated_boxes.append([x0, y0, x1, y1])
131
+
132
+ return np.asarray(estimated_boxes), len(contours)
133
+
134
+
135
+ def mask2chw(arr):
136
+ # Find the row and column indices where the array is 1
137
+ rows, cols = np.where(arr == 1)
138
+ # Calculate center of the mask
139
+ center_y = int(np.mean(rows))
140
+ center_x = int(np.mean(cols))
141
+ # Calculate height and width of the mask
142
+ height = rows.max() - rows.min() + 1
143
+ width = cols.max() - cols.min() + 1
144
+ return (center_y, center_x), height, width
145
+
146
+
147
+ def unpad(image_array, pad=None):
148
+ if pad is not None:
149
+ left, top, width, height = pad
150
+ image_array = image_array[top : top + height, left : left + width, :]
151
+ return image_array
152
+
153
+
154
+ def apply_visual_prompts(
155
+ image_array,
156
+ mask,
157
+ visual_prompt_type=('circle',),
158
+ visualize=False,
159
+ color=(255, 0, 0),
160
+ thickness=1,
161
+ blur_strength=(15, 15),
162
+ ):
163
+ """Applies visual prompts to the image."""
164
+ prompted_image = image_array.copy()
165
+ if 'blur' in visual_prompt_type:
166
+ # blur the part outside the mask
167
+ # Blur the entire image
168
+ blurred = cv2.GaussianBlur(prompted_image.copy(), blur_strength, 0)
169
+ # Get the sharp region using the mask
170
+ sharp_region = cv2.bitwise_and(
171
+ prompted_image.copy(),
172
+ prompted_image.copy(),
173
+ mask=np.clip(mask, 0, 255).astype(np.uint8),
174
+ )
175
+ # Get the blurred region using the inverted mask
176
+ inv_mask = 1 - mask
177
+ blurred_region = (blurred * inv_mask[:, :, None]).astype(np.uint8)
178
+ # Combine the sharp and blurred regions
179
+ prompted_image = cv2.add(sharp_region, blurred_region)
180
+ if 'gray' in visual_prompt_type:
181
+ gray = cv2.cvtColor(prompted_image.copy(), cv2.COLOR_BGR2GRAY)
182
+ # make gray part 3 channel
183
+ gray = np.stack([gray, gray, gray], axis=-1)
184
+ # Get the sharp region using the mask
185
+ color_region = cv2.bitwise_and(
186
+ prompted_image.copy(),
187
+ prompted_image.copy(),
188
+ mask=np.clip(mask, 0, 255).astype(np.uint8),
189
+ )
190
+ # Get the gray region using the inverted mask
191
+ inv_mask = 1 - mask
192
+ gray_region = (gray * inv_mask[:, :, None]).astype(np.uint8)
193
+ # Combine the color and gray regions
194
+ prompted_image = cv2.add(color_region, gray_region)
195
+ if 'black' in visual_prompt_type:
196
+ prompted_image = cv2.bitwise_and(
197
+ prompted_image.copy(),
198
+ prompted_image.copy(),
199
+ mask=np.clip(mask, 0, 255).astype(np.uint8),
200
+ )
201
+ if 'circle' in visual_prompt_type:
202
+ mask_center, mask_height, mask_width = mask2chw(mask)
203
+ center_coordinates = (mask_center[1], mask_center[0])
204
+ axes_length = (mask_width // 2, mask_height // 2)
205
+ prompted_image = cv2.ellipse(
206
+ prompted_image,
207
+ center_coordinates,
208
+ axes_length,
209
+ 0,
210
+ 0,
211
+ 360,
212
+ color,
213
+ thickness,
214
+ )
215
+ if 'rectangle' in visual_prompt_type:
216
+ mask_center, mask_height, mask_width = mask2chw(mask)
217
+ # center_coordinates = (mask_center[1], mask_center[0])
218
+ # axes_length = (mask_width // 2, mask_height // 2)
219
+ start_point = (
220
+ mask_center[1] - mask_width // 2,
221
+ mask_center[0] - mask_height // 2,
222
+ )
223
+ end_point = (
224
+ mask_center[1] + mask_width // 2,
225
+ mask_center[0] + mask_height // 2,
226
+ )
227
+ prompted_image = cv2.rectangle(
228
+ prompted_image, start_point, end_point, color, thickness
229
+ )
230
+ if 'contour' in visual_prompt_type:
231
+ # Find the contours of the mask
232
+ # fill holes for the mask
233
+ mask = binary_fill_holes(mask)
234
+ contours, _ = cv2.findContours(
235
+ mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
236
+ )
237
+ # Draw the contours on the image
238
+ prompted_image = cv2.drawContours(
239
+ prompted_image.copy(), contours, -1, color, thickness
240
+ )
241
+
242
+ if visualize:
243
+ cv2.imwrite(os.path.join('masked_img.png'), prompted_image)
244
+ prompted_image = Image.fromarray(prompted_image.astype(np.uint8))
245
+ return prompted_image
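
As a quick illustration of the prompt-drawing helpers above, the sketch below circles a toy square mask on a blank image (all values are made up for the example):

```python
# Sketch: draw a red ellipse visual prompt around a toy mask.
import numpy as np

from modeling.model.utils import apply_visual_prompts

image = np.full((64, 64, 3), 255, dtype=np.uint8)  # plain white image
mask = np.zeros((64, 64), dtype=np.uint8)
mask[20:40, 24:44] = 1                             # square foreground region

prompted = apply_visual_prompts(
    image, mask, visual_prompt_type=("circle",), color=(255, 0, 0), thickness=2
)
prompted.save("prompted.jpg")  # PIL image with the ellipse drawn on it
```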
modeling/model/utils_test.py ADDED
@@ -0,0 +1,129 @@
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """This file contains the unit tests for the utils.py file."""
17
+
18
+ import numpy as np
19
+ from PIL import Image
20
+ import torch
21
+
22
+ # pylint: disable=g-bad-import-order
23
+ from modeling.model import utils
24
+
25
+
26
+ def test_scoremap2bbox():
27
+ """Test the scoremap2bbox function."""
28
+ scoremap = np.zeros((10, 10))
29
+ scoremap[1:5, 1:5] = 1
30
+ scoremap[5:9, 5:9] = 2
31
+ scoremap[5:9, 1:5] = 3
32
+ scoremap[1:5, 5:9] = 4
33
+ bbox, len_bboxes = utils.scoremap2bbox(scoremap, 0.5)
34
+ assert len_bboxes == 1
35
+ assert bbox[0, 0] == 1
36
+ assert bbox[0, 1] == 1
37
+ assert bbox[0, 2] == 9
38
+ assert bbox[0, 3] == 9
39
+
40
+
41
+ def test_mask2chw():
42
+ """Test the mask2chw function."""
43
+ mask = np.zeros((10, 10))
44
+ mask[1:5, 1:5] = 1
45
+ mask[5:9, 5:9] = 2
46
+ mask[5:9, 1:5] = 3
47
+ mask[1:5, 5:9] = 4
48
+ mask = torch.tensor(mask)
49
+ mask_center, mask_height, mask_width = utils.mask2chw(mask)
50
+ assert len(mask_center) == 2
51
+ assert mask_center[0] == 2
52
+ assert mask_center[1] == 2
53
+ assert mask_height == 4
54
+ assert mask_width == 4
55
+
56
+
57
+ def test_unpad():
58
+ """Test the unpad function."""
59
+ image = np.zeros((10, 10, 1))
60
+ image[1:5, 1:5] = 1
61
+ image[5:9, 5:9] = 2
62
+ image[5:9, 1:5] = 3
63
+ image[1:5, 5:9] = 4
64
+ unpad_image = utils.unpad(image, pad=(1, 1, 8, 8))
65
+ assert len(unpad_image[0]) == 8, 'The width of the image is not 8.'
66
+ assert len(unpad_image[1]) == 8, 'The height of the image is not 8.'
67
+ unpad_image = utils.unpad(image, None)
68
+ assert (unpad_image == image).sum() == 100
69
+
70
+
71
+ def test_apply_visual_prompts():
72
+ """Test the apply_visual_prompts function."""
73
+ image = np.ones((5, 5))
74
+ mask = np.array([
75
+ [0, 0, 0, 0, 0],
76
+ [0, 0, 0, 0, 0],
77
+ [0, 0, 1.0, 0, 0],
78
+ [0, 0, 0, 0, 0],
79
+ [0, 0, 0, 0, 0],
80
+ ])
81
+
82
+ target = np.array([
83
+ [1, 1, 255, 1, 1],
84
+ [1, 255, 1, 255, 1],
85
+ [255, 1, 1, 1, 255],
86
+ [1, 255, 1, 255, 1],
87
+ [1, 1, 255, 1, 1],
88
+ ])
89
+ mask[1:5, 1:5] = 1
90
+ prompted_image = utils.apply_visual_prompts(
91
+ image, mask, visual_prompt_type='circle', thickness=1
92
+ )
93
+ prompted_array = np.array(prompted_image)
94
+ assert (prompted_array == target).sum() == 25
95
+
96
+
97
+ def test_reshape_transform():
98
+ """Test the reshape_transform function."""
99
+ image = torch.zeros((101, 10, 32))
100
+ image = utils.reshape_transform(image, height=10, width=10)
101
+ b, c, h, w = image.shape
102
+ assert b == 10
103
+ assert c == 32
104
+ assert h == 10
105
+ assert w == 10
106
+
107
+
108
+ def test_img_ms_and_flip():
109
+ """Test the img_ms_and_flip function."""
110
+ image = np.zeros((120, 150))
111
+ image[1:5, 1:5] = 1
112
+ image[5:9, 5:9] = 2
113
+ image[5:9, 1:5] = 3
114
+ image[1:5, 5:9] = 4
115
+ image = Image.fromarray(image)
116
+ image = utils.img_ms_and_flip(image, 120, 150, scales=[1.2], patch_size=16)
117
+ image = image[0]
118
+ h, w = image.shape[-2:]
119
+ assert h == int(np.ceil(1.2 * 120 / 16) * 16)
120
+ assert w == int(np.ceil(1.2 * 150 / 16) * 16)
121
+
122
+
123
+ if __name__ == '__main__':
124
+ test_scoremap2bbox()
125
+ test_mask2chw()
126
+ test_unpad()
127
+ test_apply_visual_prompts()
128
+ test_reshape_transform()
129
+ test_img_ms_and_flip()
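
The tests above are plain functions with a `__main__` guard, so they can also be invoked by hand from the repository root (assuming the `modeling` package is importable from there):

```python
# Sketch: run a few of the unit tests above directly.
from modeling.model import utils_test

utils_test.test_scoremap2bbox()
utils_test.test_unpad()
utils_test.test_reshape_transform()
print("utils tests passed")
```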
modeling/post_process/object_discovery.py ADDED
@@ -0,0 +1,355 @@
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Find objects."""
17
+
18
+ # pylint: disable=g-importing-member
19
+ import numpy as np
20
+ import scipy
21
+ from scipy import ndimage
22
+ from scipy.linalg import eigh
23
+ from scipy.ndimage import label
24
+ import torch
25
+ import torch.nn.functional as F
26
+
27
+
28
+ def ncut(
29
+ feats,
30
+ dims,
31
+ scales,
32
+ init_image_size,
33
+ tau=0,
34
+ eps=1e-5,
35
+ no_binary_graph=False,
36
+ ):
37
+ """Implementation of NCut Method.
38
+
39
+ Args:
40
+ feats: the pixel/patch features of an image
41
+ dims: dimension of the map from which the features are used
42
+ scales: from image to map scale
43
+ init_image_size: size of the image
44
+ tau: threshold for graph construction
45
+ eps: graph edge weight
46
+ no_binary_graph: ablation study for using similarity score as graph
47
+ edge weight
48
+ Returns:
49
+ pred (box for the principal object), objects (connected-component labels), mask (seed-component mask), seed (seed patch index), None, and the second-smallest eigenvector reshaped to dims.
50
+ """
51
+ feats = feats[0, 1:, :]
52
+ feats = F.normalize(feats, p=2)
53
+ a = feats @ feats.transpose(1, 0)
54
+ a = a.cpu().numpy()
55
+ if no_binary_graph:
56
+ a[a < tau] = eps
57
+ else:
58
+ a = a > tau
59
+ a = np.where(a.astype(float) == 0, eps, a)
60
+ d_i = np.sum(a, axis=1)
61
+ d = np.diag(d_i)
62
+
63
+ # Compute the second and third smallest eigenvectors
64
+ _, eigenvectors = eigh(d - a, d, subset_by_index=[1, 2])
65
+ eigenvec = np.copy(eigenvectors[:, 0])
66
+
67
+ # Using average point to compute bipartition
68
+ second_smallest_vec = eigenvectors[:, 0]
69
+ avg = np.sum(second_smallest_vec) / len(second_smallest_vec)
70
+ bipartition = second_smallest_vec > avg
71
+
72
+ seed = np.argmax(np.abs(second_smallest_vec))
73
+
74
+ if bipartition[seed] != 1:
75
+ eigenvec = eigenvec * -1
76
+ bipartition = np.logical_not(bipartition)
77
+ bipartition = bipartition.reshape(dims).astype(float)
78
+
79
+ # predict BBox
80
+ # We only extract the principal object BBox
81
+ pred, _, objects, cc = detect_box(
82
+ bipartition,
83
+ seed,
84
+ dims,
85
+ scales=scales,
86
+ initial_im_size=init_image_size[1:],
87
+ )
88
+ mask = np.zeros(dims)
89
+ mask[cc[0], cc[1]] = 1
90
+
91
+ return np.asarray(pred), objects, mask, seed, None, eigenvec.reshape(dims)
92
+
93
+
94
+ def grad_obj_discover_on_attn(attn, gradcam, dims, topk=1, threshold=0.6):
95
+ """Get the gradcam and attn map, then find the seed, then use LOST algorithm to find the potential points.
96
+
97
+ Args:
98
+ attn: attention map from ViT averaged across all heads, shape: [1,
99
+ (1+num_patches), (1+num_patches)].
100
+ gradcam: gradcam map from ViT, shape: [1, 1, H, W].
101
+ dims: (width, height) of the ViT patch grid.
102
+ topk: number of seed patches taken from the highest GradCAM scores.
103
+ threshold: fraction of the attention mass kept when thresholding.
104
+ Returns:
105
+ th_attn: a [1, 1, width, height] binary map of the kept attention.
106
+ """
107
+
108
+ w_featmap, h_featmap = dims
109
+ # nh = attn.shape[1]
110
+ attn = attn.squeeze()
111
+
112
+ seeds = torch.argsort(gradcam.flatten(), descending=True)[:topk]
113
+
114
+ # We keep only the output patch attention
115
+ # Get the attentions corresponding to [CLS] token
116
+ patch_attn = attn[1:, 1:]
117
+ topk_attn = patch_attn[seeds]
118
+ nh = topk_attn.shape[0]
119
+ # attentions = attn[0, :, 0, 1:].reshape(nh, -1)
120
+
121
+ # we keep only a certain percentage of the mass
122
+ val, idx = torch.sort(topk_attn)
123
+ val /= torch.sum(val, dim=1, keepdim=True)
124
+ cumval = torch.cumsum(val, dim=1)
125
+ th_attn = cumval > (1 - threshold)
126
+ idx2 = torch.argsort(idx)
127
+ for h in range(nh):
128
+ th_attn[h] = th_attn[h][idx2[h]]
129
+ th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
130
+ th_attn = th_attn.sum(0)
131
+ th_attn[th_attn > 1] = 1
132
+ return th_attn[None, None]
133
+
134
+
135
+ def grad_obj_discover(feats, gradcam, dims):
136
+ """Using gradient heatmap to find the seed, then use LOST algorithm to find the potential points.
137
+
138
+ Args:
139
+ feats: the pixel/patch features of an image. Shape: [1, HW, C]
140
+ gradcam: the grad cam map
141
+ dims: dimension of the map from which the features are used
142
+
143
+ Returns:
144
+ mask: a [1, 1, *dims] map of the similarity between the seed patch and all patches.
148
+ """
149
+ # Compute the similarity
150
+ a = (feats @ feats.transpose(1, 2)).squeeze()
151
+
152
+ # Compute the inverse degree centrality measure per patch
153
+ # sorted_patches, scores = patch_scoring(a)
154
+
155
+ # Select the initial seed
156
+ # seed = sorted_patches[0]
157
+ seed = gradcam.argmax()
158
+ mask = a[seed]
159
+ mask = mask.view(1, 1, *dims)
160
+
161
+ return mask
162
+
163
+
164
+ def lost(feats, dims, scales, init_image_size, k_patches=100):
165
+ """Implementation of LOST method.
166
+
167
+ Args:
168
+ feats: the pixel/patch features of an image. Shape: [1, C, H, W]
169
+ dims: dimension of the map from which the features are used
170
+ scales: from image to map scale
171
+ init_image_size: size of the image
172
+ k_patches: number of k patches retrieved that are compared to the seed
173
+ at seed expansion.
174
+ Returns:
175
+ mask: the indices (np.where output) of the connected component that contains the seed patch.
179
+ """
180
+ # Compute the similarity
181
+ feats = feats.flatten(2).transpose(1, 2)
182
+ a = (feats @ feats.transpose(1, 2)).squeeze()
183
+
184
+ # Compute the inverse degree centrality measure per patch
185
+ sorted_patches, _ = patch_scoring(a)
186
+
187
+ # Select the initial seed
188
+ seed = sorted_patches[0]
189
+
190
+ # Seed expansion
191
+ potentials = sorted_patches[:k_patches]
192
+ similars = potentials[a[seed, potentials] > 0.0]
193
+ m = torch.sum(a[similars, :], dim=0)
194
+
195
+ # Box extraction
196
+ _, _, _, mask = detect_box(
197
+ m, seed, dims, scales=scales, initial_im_size=init_image_size[1:]
198
+ )
199
+
200
+ return mask
201
+ # return np.asarray(bbox), A, scores, seed
202
+
203
+
204
+ def patch_scoring(m, threshold=0.0):
205
+ """Patch scoring based on the inverse degree."""
206
+ # Cloning important
207
+ a = m.clone()
208
+
209
+ # Zero diagonal
210
+ a.fill_diagonal_(0)
211
+
212
+ # Make sure symmetric and non nul
213
+ a[a < 0] = 0
214
+ # C = A + A.t()
215
+
216
+ # Sort pixels by inverse degree
217
+ cent = -torch.sum(a > threshold, dim=1).type(torch.float32)
218
+ sel = torch.argsort(cent, descending=True)
219
+
220
+ return sel, cent
221
+
222
+
223
+ def detect_box(
224
+ bipartition,
225
+ seed,
226
+ dims,
227
+ initial_im_size=None,
228
+ scales=None,
229
+ principle_object=True,
230
+ ):
231
+ """Extract a box corresponding to the seed patch."""
232
+
233
+ # Among connected components extract from the affinity matrix, select the one
234
+ # corresponding to the seed patch.
235
+
236
+ # w_featmap, h_featmap = dims
237
+ objects, _ = ndimage.label(bipartition)
238
+ cc = objects[np.unravel_index(seed, dims)]
239
+
240
+ if principle_object:
241
+ mask = np.where(objects == cc)
242
+ # Add +1 because excluded max
243
+ ymin, ymax = min(mask[0]), max(mask[0]) + 1
244
+ xmin, xmax = min(mask[1]), max(mask[1]) + 1
245
+ # Rescale to image size
246
+ r_xmin, r_xmax = scales[1] * xmin, scales[1] * xmax
247
+ r_ymin, r_ymax = scales[0] * ymin, scales[0] * ymax
248
+ pred = [r_xmin, r_ymin, r_xmax, r_ymax]
249
+
250
+ # Check not out of image size (used when padding)
251
+ if initial_im_size:
252
+ pred[2] = min(pred[2], initial_im_size[1])
253
+ pred[3] = min(pred[3], initial_im_size[0])
254
+
255
+ # Coordinate predictions for the feature space
256
+ # Axis different then in image space
257
+ pred_feats = [ymin, xmin, ymax, xmax]
258
+
259
+ return pred, pred_feats, objects, mask
260
+ else:
261
+ raise NotImplementedError
262
+
263
+
264
+ # This function is modified from
265
+ # https://github.com/facebookresearch/dino/blob/main/visualize_attention.py
266
+ # Ref: https://github.com/facebookresearch/dino.
267
+ def dino_seg(attn, dims, patch_size, head=0):
268
+ """Extraction of boxes based on the DINO segmentation method proposed in DINO."""
269
+ w_featmap, h_featmap = dims
270
+ nh = attn.shape[1]
271
+ official_th = 0.6
272
+
273
+ # We keep only the output patch attention
274
+ # Get the attentions corresponding to [CLS] token
275
+ attentions = attn[0, :, 0, 1:].reshape(nh, -1)
276
+
277
+ # we keep only a certain percentage of the mass
278
+ val, idx = torch.sort(attentions)
279
+ val /= torch.sum(val, dim=1, keepdim=True)
280
+ cumval = torch.cumsum(val, dim=1)
281
+ th_attn = cumval > (1 - official_th)
282
+ idx2 = torch.argsort(idx)
283
+ for h in range(nh):
284
+ th_attn[h] = th_attn[h][idx2[h]]
285
+ th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
286
+
287
+ # Connected components
288
+ labeled_array, _ = scipy.ndimage.label(th_attn[head].cpu().numpy())
289
+
290
+ # Find the biggest component
291
+ size_components = [
292
+ np.sum(labeled_array == c) for c in range(np.max(labeled_array))
293
+ ]
294
+
295
+ if len(size_components) > 1:
296
+ # Select the biggest component avoiding component 0 corresponding
297
+ # to background
298
+ biggest_component = np.argmax(size_components[1:]) + 1
299
+ else:
300
+ # Cases of a single component
301
+ biggest_component = 0
302
+
303
+ # Mask corresponding to connected component
304
+ mask = np.where(labeled_array == biggest_component)
305
+
306
+ # Add +1 because excluded max
307
+ ymin, ymax = min(mask[0]), max(mask[0]) + 1
308
+ xmin, xmax = min(mask[1]), max(mask[1]) + 1
309
+
310
+ # Rescale to image
311
+ r_xmin, r_xmax = xmin * patch_size, xmax * patch_size
312
+ r_ymin, r_ymax = ymin * patch_size, ymax * patch_size
313
+ pred = [r_xmin, r_ymin, r_xmax, r_ymax]
314
+
315
+ return pred
316
+
317
+
318
+ def get_feats(feat_out, shape):
319
+ # Batch size, Number of heads, Number of tokens
320
+ nb_im, nh, nb_tokens = shape[0:3]
321
+ qkv = (
322
+ feat_out["qkv"]
323
+ .reshape(nb_im, nb_tokens, 3, nh, -1 // nh)
324
+ .permute(2, 0, 3, 1, 4)
325
+ )
326
+ k = qkv[1]
327
+ k = k.transpose(1, 2).reshape(nb_im, nb_tokens, -1)
328
+ return k
329
+
330
+
331
+ def get_instances(masks, return_largest=False):
332
+ return [
333
+ get_instances_single(m[None], return_largest=return_largest)
334
+ for m in masks
335
+ ]
336
+
337
+
338
+ def get_instances_single(mask, return_largest=False):
339
+ """Get the mask of a single instance."""
340
+ labeled_array, _ = label(mask.cpu().numpy())
341
+ instances = np.concatenate(
342
+ [labeled_array == c for c in range(np.max(labeled_array) + 1)], axis=0
343
+ )
344
+ if return_largest:
345
+ size_components = np.sum(instances, axis=(1, 2))
346
+ if len(size_components) > 1:
347
+ # Select the biggest component avoiding component 0 corresponding
348
+ # to background
349
+ biggest_component = np.argmax(size_components[1:]) + 1
350
+ else:
351
+ # Cases of a single component
352
+ biggest_component = 0
353
+ # Mask corresponding to connected component
354
+ return torch.from_numpy(labeled_array == biggest_component).float()
355
+ return torch.from_numpy(instances[1:]).float()
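
`get_instances` above splits each binary mask into connected components, optionally keeping only the largest one. A small sketch with a toy mask containing two blobs:

```python
# Sketch: split a toy mask with two blobs into instance masks.
import torch

from modeling.post_process.object_discovery import get_instances

mask = torch.zeros(1, 16, 16)
mask[0, 2:6, 2:6] = 1      # first blob (16 px)
mask[0, 10:14, 9:15] = 1   # second blob (24 px)

instances = get_instances(mask, return_largest=False)
# One entry per input mask; each entry stacks the connected components.
print(instances[0].shape)  # torch.Size([2, 16, 16])

largest = get_instances(mask, return_largest=True)
print(largest[0].shape)    # torch.Size([1, 16, 16]), keeping only the bigger blob
```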
modeling/post_process/post_process.py ADDED
@@ -0,0 +1,167 @@
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Post processing."""
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+
21
+ # pylint: disable=g-bad-import-order
22
+ # pylint: disable=g-importing-member
23
+ from modeling.post_process.object_discovery import get_instances
24
+ from utils.metrics import IoM
25
+
26
+
27
+ # This should be an abstract interface for generating mask proposals for the input image.
28
+ # For now it is hard-wired to SAM.
29
+ def generate_masks_from_sam(
30
+ image_path, save_path, pipeline, img_sam=None, visualize=True
31
+ ):
32
+ """Generate masks from SAM."""
33
+ masks, _, mask_list = pipeline.segment_automask(
34
+ image_path=image_path,
35
+ visualize=visualize,
36
+ save_path=save_path,
37
+ image=img_sam,
38
+ )
39
+ mask_tensor = torch.from_numpy(masks)
40
+ mask_tensor = mask_tensor.float()
41
+ return mask_tensor, mask_list
42
+
43
+
44
+ def match_masks(
45
+ mask_tensor, attn_map, mask_list, iom_thres=0.0, min_pred_threshold=0.2
46
+ ):
47
+ """Match masks with the attention map according to the IoU.
48
+
49
+ Args:
50
+ mask_tensor: A torch.Tensor for the masks with shape [num_masks, height,
51
+ width].
52
+ attn_map: A torch.Tensor for the attention map with shape [1, 1, height,
53
+ width].
54
+ mask_list: A list of masks with shape [num_masks, height, width]
55
+ iom_thres: A float IoM threshold above which a candidate mask is kept.
56
+ min_pred_threshold: The prediction score threshold.
57
+
58
+ Returns:
59
+ A list of matched_masks with shape [num_masks, height, width],
60
+ len(matched_masks) = number of captions
61
+ """
62
+ predictions = attn_map.squeeze(1).detach()
63
+ iom = IoM(predictions, mask_tensor, min_pred_threshold=min_pred_threshold)
64
+ keep_mask = iom > iom_thres
65
+ # mask_tensor = mask_tensor[keep_mask]
66
+ new_list = []
67
+ for mid, m_dict in enumerate(mask_list):
68
+ if keep_mask[mid]:
69
+ new_list.append(m_dict)
70
+ # if not len(new_list):
71
+ if not new_list:
72
+ max_id = torch.argmax(iom)
73
+ new_list.append(mask_list[max_id])
74
+ return new_list
75
+
76
+
77
+ def post_process_mask(attn_masks, pad=None, min_area_ratio=0.15):
78
+ """Post process attention masks."""
79
+ if pad is not None:
80
+ left, top, width, height = pad
81
+ attn_masks = attn_masks[Ellipsis, top : top + height, left : left + width]
82
+ else:
83
+ height = None
84
+ width = None
85
+ mask_area = attn_masks.sum(dim=(1, 2))
86
+ total_area = mask_area.sum()
87
+ keep_mask = mask_area / total_area > min_area_ratio
88
+ if torch.sum(keep_mask) == 0:
89
+ if keep_mask.shape[0] == 0:
90
+ return torch.zeros(
91
+ (1, height, width), device=attn_masks.device, dtype=attn_masks.dtype
92
+ )
93
+ keep_mask[torch.argmax(mask_area)] = True
94
+ attn_masks = attn_masks[keep_mask]
95
+ return attn_masks
96
+
97
+
98
+ def filter_masks(
99
+ attn_masks,
100
+ pad=None,
101
+ mask_threshold=0.3,
102
+ min_area_ratio=0.15,
103
+ return_largest=False,
104
+ device=None,
105
+ return_instances=False,
106
+ ):
107
+ """Filter attention mask below the threshold."""
108
+ attn_masks[attn_masks < mask_threshold] = 0
109
+ # get_instances will be operated on cpu
110
+ ins_masks = get_instances(attn_masks, return_largest=return_largest)
111
+ ins_masks = [post_process_mask(m, pad, min_area_ratio) for m in ins_masks]
112
+ ins_masks = list(filter(lambda x: x is not None, ins_masks))
113
+ ins_masks = [m.to(device) for m in ins_masks]
114
+ if not return_instances:
115
+ return [torch.any(m, dim=0, keepdim=True).to(m.dtype) for m in ins_masks]
116
+ return ins_masks
117
+
118
+
119
+ def post_process(
120
+ input_array,
121
+ attn_masks,
122
+ pad=None,
123
+ mask_threshold=0.3,
124
+ return_largest=False,
125
+ min_area_ratio=0.15,
126
+ return_instances=False,
127
+ ):
128
+ """post process the input tensor with the attention masks.
129
+
130
+ Args:
131
+ input_array: A np.ndarray input array to be post processed with shape
132
+ [width, height, 3, batch_size]
133
+ attn_masks: A torch.Tensor for the attention masks with shape [1,
134
+ num_texts, width, height]
135
+ pad: A list of padding: [pad_left, pad_top, width, height], where
136
+ pad_left, pad_top and width, height are int values.
137
+ mask_threshold: The threshold to binarize the mask.
138
+ return_largest: If true, return the largest connected component.
139
+ min_area_ratio: Keep the mask if its area is larger than this threshold.
140
+ return_instances: Whether to return instances or not.
141
+
142
+ Returns:
143
+ attn_masks: A list of tensors with shape [num_instances, height, width]
144
+ x num_texts, where len(attn_masks) = num_texts.
145
+ NOTE: the number_instances for each text (class) may vary.
146
+ The output is a binary tensor.
147
+ """
148
+ if len(attn_masks.shape) == 3:
149
+ attn_masks = attn_masks[None]
150
+ img_width, img_height = input_array.shape[:2]
151
+ attn_masks = F.interpolate(
152
+ attn_masks, size=(img_height, img_width), mode='bicubic'
153
+ ).squeeze(0)
154
+ device = attn_masks.device
155
+ output_masks = filter_masks(
156
+ attn_masks,
157
+ pad=pad,
158
+ mask_threshold=mask_threshold,
159
+ min_area_ratio=min_area_ratio,
160
+ return_largest=return_largest,
161
+ device=device,
162
+ return_instances=return_instances,
163
+ )
164
+ if pad is not None:
165
+ left, top, width, height = pad
166
+ input_array = input_array[top : top + height, left : left + width]
167
+ return input_array, output_masks
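
`post_process` above resizes the raw attention masks to image resolution, thresholds them, and splits/filters connected components. A sketch with dummy inputs, where only the shapes matter:

```python
# Sketch: run the mask post-processing on dummy data.
import numpy as np
import torch

from modeling.post_process.post_process import post_process

image = np.zeros((224, 224, 3), dtype=np.uint8)  # stand-in for the input image
attn = torch.rand(1, 2, 56, 56)                  # 2 low-resolution attention maps

image_out, masks = post_process(
    image, attn, mask_threshold=0.3, return_largest=True
)
# `masks` is a list with one binary [1, 224, 224] tensor per text query.
print(len(masks), masks[0].shape)
```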
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
1
+ tensorflow>=2.14.0
2
+ numpy>=1.16.4
3
+ torch>=2.0.0
4
+ torchvision>=0.15.1
sam/__init__.py ADDED
@@ -0,0 +1,19 @@
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """SAM(Segment Anything Model)."""
17
+
18
+ from .sam import *
19
+ from .utils import *
sam/sam.py ADDED
@@ -0,0 +1,205 @@
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """A pipeline for segmenting objects using the SAM model."""
17
+
18
+ # Copyright 2024 The Google Research Authors.
19
+ # This file is based on the SAM (Segment Anything) and HQ-SAM.
20
+ #
21
+ # https://github.com/facebookresearch/segment-anything
22
+ # https://github.com/SysCV/sam-hq/tree/main
23
+ #
24
+ # Licensed under the Apache License, Version 2.0 (the "License");
25
+ # you may not use this file except in compliance with the License.
26
+ # You may obtain a copy of the License at
27
+ #
28
+ # http://www.apache.org/licenses/LICENSE-2.0
29
+ #
30
+ # Unless required by applicable law or agreed to in writing, software
31
+ # distributed under the License is distributed on an "AS IS" BASIS,
32
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33
+ # See the License for the specific language governing permissions and
34
+ # limitations under the License.
35
+
36
+
37
+ # pylint: disable=all
38
+ # pylint: disable=g-importing-member
39
+ import os
40
+ import cv2
41
+ import matplotlib.pyplot as plt
42
+ import numpy as np
43
+ from sam.utils import show_anns
44
+ from sam.utils import show_box
45
+ from sam.utils import show_mask
46
+ from sam.utils import show_points
47
+ from segment_anything import sam_model_registry
48
+ from segment_anything import SamAutomaticMaskGenerator
49
+ from segment_anything import SamPredictor
50
+
51
+
52
+ class SAMPipeline:
53
+
54
+ def __init__(
55
+ self,
56
+ checkpoint,
57
+ model_type,
58
+ device="cuda:0",
59
+ points_per_side=32,
60
+ pred_iou_thresh=0.88,
61
+ stability_score_thresh=0.95,
62
+ box_nms_thresh=0.7,
63
+ ):
64
+ self.checkpoint = checkpoint
65
+ self.model_type = model_type
66
+ self.device = device
67
+ self.sam = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
68
+ self.sam.to(device=self.device)
69
+ self.load_mask_generator(
70
+ points_per_side=points_per_side,
71
+ pred_iou_thresh=pred_iou_thresh,
72
+ stability_score_thresh=stability_score_thresh,
73
+ box_nms_thresh=box_nms_thresh,
74
+ )
75
+
76
+ # Default Prompt Args
77
+ self.click_args = {"k": 5, "order": "max", "how_filter": "median"}
78
+ self.box_args = None
79
+
80
+ def load_sam(self):
81
+ print("Loading SAM")
82
+ sam = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
83
+ sam.to(device=self.device)
84
+ self.predictor = SamPredictor(sam)
85
+ print("Loading Done")
86
+
87
+ def load_mask_generator(
88
+ self,
89
+ points_per_side,
90
+ pred_iou_thresh,
91
+ stability_score_thresh,
92
+ box_nms_thresh,
93
+ ):
94
+ print("Loading SAM")
95
+ self.mask_generator = SamAutomaticMaskGenerator(
96
+ model=self.sam,
97
+ points_per_side=points_per_side,
98
+ pred_iou_thresh=pred_iou_thresh,
99
+ stability_score_thresh=stability_score_thresh,
100
+ box_nms_thresh=box_nms_thresh,
101
+ crop_n_layers=0,
102
+ crop_n_points_downscale_factor=1,
103
+ )
104
+ print("Loading Done")
105
+
106
+ # segment single object
107
+ def segment_image_single(
108
+ self,
109
+ image_path,
110
+ input_point=None,
111
+ input_label=None,
112
+ input_box=None,
113
+ input_mask=None,
114
+ multimask_output=True,
115
+ visualize=False,
116
+ save_path=None,
117
+ fname="",
118
+ image=None,
119
+ ):
120
+ if image is None:
121
+ image = cv2.imread(image_path)
122
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
123
+ self.predictor.set_image(image)
124
+ masks, scores, logits = self.predictor.predict(
125
+ point_coords=input_point,
126
+ point_labels=input_label,
127
+ box=input_box,
128
+ mask_input=None,
129
+ multimask_output=multimask_output,
130
+ )
131
+
132
+ if visualize:
133
+ self.visualize(
134
+ image,
135
+ masks,
136
+ scores,
137
+ save_path,
138
+ input_point=input_point,
139
+ input_label=input_label,
140
+ input_box=input_box,
141
+ input_mask=input_mask,
142
+ fname=fname,
143
+ )
144
+
145
+ return masks, scores, logits
146
+
147
+ def segment_automask(
148
+ self,
149
+ image_path,
150
+ visualize=False,
151
+ save_path=None,
152
+ image=None,
153
+ fname="automask.jpg",
154
+ ):
155
+ if image is None:
156
+ image = cv2.imread(image_path)
157
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
158
+
159
+ mask_list, bbox_list = [], []
160
+ masks = self.mask_generator.generate(image)
161
+ mask_list.extend([mask["segmentation"] for mask in masks])
162
+ bbox_list.extend([mask["bbox"] for mask in masks])
163
+
164
+ if visualize:
165
+ self.visualize_automask(image, masks, save_path, fname=fname)
166
+
167
+ masks_arr, bbox_arr = np.array(mask_list), np.array(bbox_list)
168
+ return masks_arr, bbox_arr, masks
169
+
170
+ def visualize_automask(self, image, masks, save_path, fname="mask.jpg"):
171
+ if not os.path.exists(save_path):
172
+ os.makedirs(save_path)
173
+ plt.figure(figsize=(20, 20))
174
+ plt.imshow(image)
175
+ show_anns(masks)
176
+ plt.axis("off")
177
+ plt.savefig(os.path.join(save_path, fname))
178
+
179
+ def visualize(
180
+ self,
181
+ image,
182
+ masks,
183
+ scores,
184
+ save_path,
185
+ input_point=None,
186
+ input_label=None,
187
+ input_box=None,
188
+ input_mask=None,
189
+ fname="",
190
+ ):
191
+ for i, (mask, score) in enumerate(zip(masks, scores)):
192
+ plt.figure(figsize=(10, 10))
193
+ plt.imshow(image)
194
+ show_mask(mask, plt.gca())
195
+ if input_point is not None:
196
+ show_points(input_point, input_label, plt.gca())
197
+ if input_box is not None:
198
+ show_box(input_box, plt.gca())
199
+ if input_mask is not None:
200
+ show_mask(input_mask[0], plt.gca(), True)
201
+ plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
202
+ plt.axis("off")
203
+ plt.savefig(os.path.join(save_path, f"{fname}{i}.jpg"))
204
+
205
+ return input_point, input_label, input_box, input_mask
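
A sketch of driving the `SAMPipeline` automatic mask generator above; the checkpoint path is a placeholder and must point to a locally downloaded SAM weight file:

```python
# Sketch: generate SAM masks for one image (checkpoint path is a placeholder).
from sam.sam import SAMPipeline

pipeline = SAMPipeline(
    checkpoint="sam_vit_h_4b8939.pth",  # assumed local SAM checkpoint
    model_type="vit_h",
    device="cuda:0",
)
masks, boxes, raw = pipeline.segment_automask(
    image_path="example.jpg", visualize=False
)
print(masks.shape)  # [num_masks, H, W] boolean array
```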
sam/utils.py ADDED
@@ -0,0 +1,239 @@
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Copyright 2024 The Google Research Authors.
17
+ # This file is based on the SAM (Segment Anything) and HQ-SAM.
18
+ #
19
+ # https://github.com/facebookresearch/segment-anything
20
+ # https://github.com/SysCV/sam-hq/tree/main
21
+ #
22
+ # Licensed under the Apache License, Version 2.0 (the "License");
23
+ # you may not use this file except in compliance with the License.
24
+ # You may obtain a copy of the License at
25
+ #
26
+ # http://www.apache.org/licenses/LICENSE-2.0
27
+ #
28
+ # Unless required by applicable law or agreed to in writing, software
29
+ # distributed under the License is distributed on an "AS IS" BASIS,
30
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31
+ # See the License for the specific language governing permissions and
32
+ # limitations under the License.
33
+
34
+ """SAM Utilities."""
35
+ # pylint: disable=all
36
+ # pylint: disable=g-importing-member
37
+ import json
38
+ import matplotlib.pyplot as plt
39
+ import numpy as np
40
+ from scipy.spatial.distance import cdist
41
+
42
+
43
+ def show_mask(mask, ax, random_color=False):
44
+ if random_color:
45
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
46
+ else:
47
+ color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
48
+ h, w = mask.shape[-2:]
49
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
50
+ ax.imshow(mask_image)
51
+
52
+
53
+ def show_points(coords, labels, ax, marker_size=375):
54
+ pos_points = coords[labels == 1]
55
+ neg_points = coords[labels == 0]
56
+ ax.scatter(
57
+ pos_points[:, 0],
58
+ pos_points[:, 1],
59
+ color='green',
60
+ marker='*',
61
+ s=marker_size,
62
+ edgecolor='white',
63
+ linewidth=1.25,
64
+ )
65
+ ax.scatter(
66
+ neg_points[:, 0],
67
+ neg_points[:, 1],
68
+ color='red',
69
+ marker='*',
70
+ s=marker_size,
71
+ edgecolor='white',
72
+ linewidth=1.25,
73
+ )
74
+
75
+
76
+ def show_box(box, ax):
77
+ x0, y0, x1, y1 = box
78
+ w, h = x1 - x0, y1 - y0
79
+ ax.add_patch(
80
+ plt.Rectangle(
81
+ (x0, y0), w, h, edgecolor='red', facecolor=(0, 0, 0, 0), lw=2
82
+ )
83
+ )
84
+
85
+
86
+ def show_anns(anns):
87
+ if len(anns) == 0:
88
+ return
89
+ for index, dictionary in enumerate(anns):
90
+ dictionary['id'] = index
91
+
92
+ sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
93
+ ax = plt.gca()
94
+ ax.set_autoscale_on(False)
95
+ # polygons = []
96
+ # color = []
97
+ for ann in sorted_anns:
98
+ m = ann['segmentation']
99
+ img = np.ones((m.shape[0], m.shape[1], 3))
100
+ color_mask = np.random.random((1, 3)).tolist()[0]
101
+ for i in range(3):
102
+ img[:, :, i] = color_mask[i]
103
+ ax.imshow(np.dstack((img, m * 0.35)))
104
+
105
+ # Get the centroid of the mask
106
+ mask_y, mask_x = np.nonzero(m)
107
+ centroid_x, centroid_y = np.mean(mask_x), np.mean(mask_y)
108
+
109
+ # Display the mask ID
110
+ mask_id = ann['id']
111
+ ax.text(
112
+ centroid_x,
113
+ centroid_y,
114
+ str(mask_id),
115
+ color='black',
116
+ fontsize=48,
117
+ weight='bold',
118
+ )
119
+
120
+
121
+ # Turn CAM result to SAM prompt
122
+ def aggregate_RGB_channel(activation_mask, how='max'):
123
+ B, C, H, W = activation_mask.shape
124
+ if how == 'max':
125
+ res_activation_mask = np.amax(activation_mask, axis=1, keepdims=True)
126
+ elif how == 'avr':
127
+ res_activation_mask = np.mean(activation_mask, axis=1, keepdims=True)
128
+ res_activation_mask = res_activation_mask.reshape(B, 1, H * W)
129
+
130
+ res_activation_mask = np.squeeze(res_activation_mask, axis=1)
131
+ return res_activation_mask
132
+
133
+
134
+ def find_k_points(arr, k, order='max', how_filter='median'):
135
+ arr = arr.squeeze(0)
136
+ flat_indices = np.argpartition(arr.flatten(), -k)[-k:]
137
+ unravel_topk_idx = np.unravel_index(flat_indices, arr.shape)
138
+ topk_indices = np.array(unravel_topk_idx).transpose()[:, ::-1]
139
+ # print(topk_indices.shape)
140
+
141
+ if how_filter == 'random':
142
+ random_rows = np.random.choice(
143
+ topk_indices.shape[0], size=int(round(k / 16)), replace=False
144
+ )
145
+ topk_indices = topk_indices[random_rows]
146
+ elif how_filter == 'median':
147
+ distances = cdist(topk_indices, topk_indices)
148
+ distances = np.sum(distances, axis=1)
149
+ median_distance = np.median(distances)
150
+ filtered_idx = [
151
+ i for i in range(len(distances)) if distances[i] < median_distance
152
+ ]
153
+ topk_indices = topk_indices[filtered_idx]
154
+ return topk_indices
155
+
156
+
157
+ def max_sum_submatrix(matrix):
158
+ matrix = np.array(matrix)
159
+ H, W = matrix.shape
160
+ # Preprocess cumulative sums for rows
161
+ matrix[:, 1:] += matrix[:, :-1]
162
+ max_sum = float('-inf')
163
+ max_rect = (0, 0, 0, 0) # (top, left, bottom, right)
164
+
165
+ for left in range(W):
166
+ for right in range(left, W):
167
+ # Apply 1D Kadane's algorithm for the current pair of columns
168
+ column_sum = matrix[:, right] - (matrix[:, left - 1] if left > 0 else 0)
169
+ max_ending_here = max_so_far = column_sum[0]
170
+ start, end = 0, 0
171
+
172
+ for i in range(1, H):
173
+ val = column_sum[i]
174
+ if max_ending_here > 0:
175
+ max_ending_here += val
176
+ else:
177
+ max_ending_here = val
178
+ start = i
179
+
180
+ if max_ending_here > max_so_far:
181
+ max_so_far = max_ending_here
182
+ end = i
183
+
184
+ if max_so_far > max_sum:
185
+ max_sum = max_so_far
186
+ max_rect = (start, left, end, right)
187
+
188
+ return max_sum, max_rect
189
+
190
+
191
+ def CAM2SAMClick(activation_map, k=5, order='max', how_filter='median'):
192
+ # activation_map = aggregate_RGB_channel(activation_map)
193
+ H, W, C = activation_map.shape
194
+ activation_map = activation_map.reshape((1, 1, H, W))
195
+ coords = []
196
+ for nrow in range(activation_map.shape[0]):
197
+ coord = find_k_points(activation_map[nrow], k, order, how_filter)
198
+ coords.append(coord)
199
+ return coords
200
+
201
+
202
+ def CAM2SAMBox(activation_map):
203
+ # print(activation_map.shape)
204
+ # activation_map = aggregate_RGB_channel(activation_map)
205
+ H, W, C = activation_map.shape
206
+ activation_map = activation_map.reshape((1, H, W))
207
+ box_coordinates = []
208
+ for nrow in range(activation_map.shape[0]):
209
+ # print(activation_map[nrow].shape)
210
+ arr = activation_map[nrow]
211
+
212
+ norm_arr = 2 * ((arr - np.min(arr)) / (np.max(arr) - np.min(arr))) - 1
213
+ # print(norm_arr.shape)
214
+ _, box_coordinate = max_sum_submatrix(norm_arr)
215
+ box_coordinates.append(box_coordinate)
216
+ return box_coordinates
217
+
218
+
219
+ # Visualize
220
+ def visualize_attention(arr, filename):
221
+ # Create a figure and axes object
222
+ fig, ax = plt.subplots()
223
+ # Display the array as an image
224
+ im = ax.imshow(arr)
225
+ # Add a colorbar
226
+ ax.figure.colorbar(im, ax=ax)
227
+ # cbar = ax.figure.colorbar(im, ax=ax)
228
+ # Save the figure as a PNG file
229
+ fig.savefig(filename)
230
+
231
+
232
+ # Build config
233
+ def build_sam_config(config_path):
234
+ with open(config_path, 'r') as infile:
235
+ config = json.load(infile)
236
+
237
+ sam_checkpoint = config['model']['sam_checkpoint']
238
+ model_type = config['model']['model_type']
239
+ return sam_checkpoint, model_type
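
`build_sam_config` above expects a small JSON file with the checkpoint path and model type. A sketch of writing and then reading such a config (the field values are placeholders):

```python
# Sketch: create and load a SAM config for build_sam_config (values are placeholders).
import json

from sam.utils import build_sam_config

config = {
    "model": {
        "sam_checkpoint": "checkpoints/sam_vit_h_4b8939.pth",
        "model_type": "vit_h",
    }
}
with open("sam_config.json", "w") as f:
    json.dump(config, f, indent=2)

sam_checkpoint, model_type = build_sam_config("sam_config.json")
print(sam_checkpoint, model_type)
```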
utils/__init__.py ADDED
@@ -0,0 +1,15 @@
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
utils/inference_pipeline.py ADDED
@@ -0,0 +1,83 @@
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """The inference pipeline for the CaR model."""
17
+
18
+ import numpy as np
19
+ from PIL import Image
20
+ import torch
21
+
22
+ # pylint: disable=g-importing-member
23
+ # pylint: disable=g-bad-import-order
24
+ from modeling.post_process.post_process import generate_masks_from_sam
25
+ from modeling.post_process.post_process import match_masks
26
+ from utils.utils import process_sentence
27
+ from utils.metrics import IoU
28
+
29
+ IMAGE_WIDTH = 512
30
+ IMAGE_HEIGHT = 512
31
+
32
+
33
+ def get_sam_masks(
34
+ config, image_path, masks, matching_thresh=0.9, img_sam=None, pipeline=None
35
+ ):
36
+ """Generate SAM masks."""
37
+ print("generating sam masks online")
38
+ mask_tensor, mask_list = generate_masks_from_sam(
39
+ image_path,
40
+ save_path="./",
41
+ pipeline=pipeline,
42
+ img_sam=img_sam,
43
+ visualize=False,
44
+ )
45
+ mask_tensor = mask_tensor.to(masks.device)
46
+ # only run SAM matching on masks that are not all zero
47
+ attn_map, mask_ids = [], []
48
+ for mask_id, mask in enumerate(masks):
49
+ if torch.sum(mask) > 0:
50
+ attn_map.append(mask.unsqueeze(0))
51
+ mask_ids.append(mask_id)
52
+ matched_masks = [
53
+ match_masks(
54
+ mask_tensor,
55
+ attn,
56
+ mask_list,
57
+ iom_thres=config.car.iom_thres,
58
+ min_pred_threshold=config.sam.min_pred_threshold,
59
+ )
60
+ for attn in attn_map
61
+ ]
62
+ for matched_mask, mask_id in zip(matched_masks, mask_ids):
63
+ sam_masks = np.array([item["segmentation"] for item in matched_mask])
64
+ sam_mask = np.any(sam_masks, axis=0)
65
+ cur_mask = masks[mask_id]
66
+ iou = IoU(torch.from_numpy(sam_mask).to(cur_mask.device), cur_mask)
67
+ if iou > matching_thresh:
68
+ masks[mask_id] = torch.from_numpy(sam_mask).to(masks.device)
69
+ return masks
70
+
71
+
72
+ def inference_car(cfg, car_model, image_path, sentences, sam_pipeline=None):
73
+ sentences = [process_sentence(sen, cfg.test.ds_name) for sen in sentences]
74
+ img = Image.open(image_path).convert("RGB")
75
+ if cfg.test.use_pseudo:
76
+ masks, scores = car_model(img, sentences)
77
+ return masks, scores
78
+
79
+ masks, scores = car_model(img, sentences, cfg.car.num_iteration)
80
+ sam_masks = get_sam_masks(
81
+ cfg, image_path, masks, cfg.sam.matching_thresh, pipeline=sam_pipeline
82
+ )
83
+ return sam_masks, scores
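
`inference_car` above ties the pieces together: CaR produces one mask per text query, and `get_sam_masks` swaps a mask for the union of matching SAM proposals when the IoU is high enough. A thin wrapper as a usage sketch; the `cfg`, `car_model`, and `sam_pipeline` objects are assumed to be built elsewhere (see `app.py` and the `configs/` directory):

```python
# Sketch: how inference_car is meant to be called. The cfg, car_model and
# sam_pipeline arguments are assumed to be constructed elsewhere (see app.py);
# this helper only illustrates the call signature.
from utils.inference_pipeline import inference_car


def segment_with_texts(cfg, car_model, sam_pipeline, image_path, texts):
    """Run CaR on one image for the given text queries, then refine with SAM."""
    masks, scores = inference_car(cfg, car_model, image_path, texts, sam_pipeline)
    return masks, scores
```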
utils/merge_mask.py ADDED
@@ -0,0 +1,57 @@
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Mask merging functions for post-processing."""
17
+
18
+ import cv2
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn.functional as F
22
+
23
+
24
+ def merge_masks_simple(
25
+ all_masks, target_h, target_w, threshold=0.5, scores=None
26
+ ):
27
+ """Merge masks."""
28
+ merged_mask = None
29
+ if scores is not None:
30
+ merged_mask = torch.sum(all_masks * scores[:, None, None], dim=0)
31
+ merged_mask /= torch.sum(scores)
32
+ merged_mask = merged_mask.detach().cpu().numpy()
33
+ # resize the mask to the target size
34
+ merged_mask = cv2.resize(merged_mask, (target_w, target_h))
35
+ merged_mask = np.where(merged_mask >= threshold, 1, 0).astype(np.uint8)
36
+ if np.sum(merged_mask) <= 0.05 * (target_h * target_w):
37
+ merged_mask = torch.any(all_masks > 0, dim=0)
38
+ merged_mask = merged_mask.detach().cpu().numpy().astype(np.uint8)
39
+ # resize the mask to the target size
40
+ merged_mask = cv2.resize(merged_mask, (target_w, target_h))
41
+ merged_mask = merged_mask > threshold
42
+ merged_mask = torch.from_numpy(merged_mask).float()
43
+ return merged_mask[None]
44
+
45
+
46
+ def merge_masks(all_masks, target_h, target_w, threshold=0.5):
47
+ all_masks = torch.from_numpy(np.stack(all_masks)).float()
48
+ mask_tensor = F.interpolate(
49
+ all_masks[None], size=(target_h, target_w), mode='bilinear'
50
+ ).squeeze(0)
51
+ bg_mask = threshold * torch.ones((1, target_h, target_w))
52
+ merged_mask = torch.cat([bg_mask, mask_tensor], dim=0)
53
+ mask_idx = torch.argmax(merged_mask, dim=0)
54
+ merged_mask = mask_idx > 0
55
+ if merged_mask.sum() <= 0.05 * (target_h * target_w):
56
+ merged_mask = torch.any(mask_tensor, dim=0)
57
+ return merged_mask.float()[None]
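
`merge_masks_simple` above collapses several per-iteration masks into one binary mask at the target resolution, weighting them by their scores. A toy sketch:

```python
# Sketch: score-weighted merging of two low-resolution masks into one 64x64 mask.
import torch

from utils.merge_mask import merge_masks_simple

masks = torch.zeros(2, 32, 32)
masks[0, 4:20, 4:20] = 1.0
masks[1, 8:24, 8:24] = 1.0
scores = torch.tensor([0.9, 0.4])  # confidence of each mask

merged = merge_masks_simple(masks, target_h=64, target_w=64, threshold=0.5, scores=scores)
print(merged.shape)                # torch.Size([1, 64, 64]), float values in {0., 1.}
```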
utils/metrics.py ADDED
@@ -0,0 +1,75 @@
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Metrics for evaluating the performance of the model."""
17
+
18
+ import torch
19
+
20
+
21
+ def IoU(mask1, mask2, threshold=0.5):
22
+ """Calculate Intersection over Union (IoU) between prediction and GT masks.
23
+
24
+ Args:
25
+ mask1: A torch.Tensor denoting the prediction, shape (N, H, W), where N is
26
+ the number of masks.
27
+ mask2: A torch.Tensor denoting the ground truth, shape (N, H, W), where N
28
+ is the number of masks.
29
+ threshold: The threshold to binarize masks.
30
+ Returns:
31
+ IoU of `mask1` and `mask2`.
32
+ """
33
+ if threshold > 0:
34
+ mask1, mask2 = (mask1 > threshold).to(torch.bool), (mask2 > threshold).to(
35
+ torch.bool
36
+ )
37
+ intersection = torch.sum(mask1 * (mask1 == mask2), dim=[-1, -2]).squeeze()
38
+ union = torch.sum(mask1 + mask2, dim=[-1, -2]).squeeze()
39
+ if union.sum() == 0:
40
+ return 0
41
+ return (intersection.to(torch.float) / union).mean().item()
42
+
43
+
44
+ def IoM(pred, target, min_pred_threshold=0.2):
45
+ """Calculate Intersection over the area of gt Mask and pred Mask (IoM).
46
+
47
+ between prediction and each ground truth masks.
48
+ Precaution:
49
+ this function works for prediction and target that are binary masks,
50
+ where 1 represents the mask and 0 represents the background.
51
+ Args:
52
+ pred: A torch.Tensor denoting the prediction, shape (N, H, W), where N is
53
+ the number of masks.
54
+ target: A torch.Tensor denoting the ground truth, shape (N, H, W), where N
55
+ is the number of masks.
56
+ min_pred_threshold: minimum fraction of the predicted mask's area that must
+ overlap a target mask; smaller overlaps are treated as background.
57
+
58
+ Returns:
59
+ iom: A torch.Tensor denoting the IoM for each target mask, shape (N,).
60
+ """
61
+ # calculate the intersection over all masks
62
+ intersection = torch.einsum("mij,nij->mn", pred.to(target.device), target)
63
+ area_pred = torch.einsum("mij->m", pred)
64
+ area_target = torch.einsum("nij->n", target)
65
+ # we calculate the IoM by dividing the intersection over the minimum area.
66
+ iom_target = torch.einsum("mn,n->mn", intersection, 1 / area_target)
67
+ iom_pred = torch.einsum("mn,m->mn", intersection, 1 / area_pred)
68
+ # if the intersection is smaller than a certain percentage of the area of
69
+ # the pred mask, we consider it as background.
70
+ iom_target[iom_pred < min_pred_threshold] = 0
71
+ # we consider the IoM as the maximum IoM between the pred mask and
72
+ # the target mask.
73
+ iom = torch.max(iom_target, iom_pred)
74
+ iom = iom.max(dim=0)[0]
75
+ return iom
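As a quick sanity check, the two metrics behave as follows on toy masks. This is a sketch only; the tensors are made up, and the import assumes the repository root is on `PYTHONPATH`.

```python
import torch
from utils.metrics import IoU, IoM

pred = torch.zeros(1, 4, 4)
pred[0, :2, :2] = 1        # prediction covers the top-left 2x2 block
target = torch.zeros(1, 4, 4)
target[0, :2, :] = 1       # ground truth covers the top two rows

print(IoU(pred, target))   # intersection 4 / union 8 -> 0.5
print(IoM(pred, target))   # max(4/8, 4/4) -> tensor([1.])
```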
utils/nlp.py ADDED
@@ -0,0 +1,94 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Language processing utilities."""
17
+
18
+ import spacy
19
+
20
+
21
+ def load_spacy_model(model='en_core_web_trf'):
22
+ nlp = spacy.load(model)
23
+ return nlp
24
+
25
+
26
+ def process_sentence(sentence, nlp):
27
+ """Process a sentence."""
28
+ doc = nlp(sentence)
29
+ sentence_for_spacy = []
30
+
31
+ for _, token in enumerate(doc):
32
+ if token.text == ' ':
33
+ continue
34
+ sentence_for_spacy.append(token.text)
35
+
36
+ sentence_for_spacy = ' '.join(sentence_for_spacy)
37
+ noun_phrase, _, _ = extract_noun_phrase(
38
+ sentence_for_spacy, nlp, need_index=True
39
+ )
40
+ return noun_phrase
41
+
42
+
43
+ def extract_noun_phrase(text, nlp, need_index=False):
44
+ """Extract noun phrase from text. nlp is a spacy model.
45
+
46
+ Args:
47
+ text: str, text to be processed.
48
+ nlp: spacy model.
49
+ need_index: bool, whether to return the index of the noun phrase.
50
+
51
+ Returns:
52
+ noun_phrase: str, noun phrase of the text.
53
+ """
54
+ # text = text.lower()
55
+
56
+ doc = nlp(text)
57
+
58
+ chunks = {}
59
+ chunks_index = {}
60
+ for chunk in doc.noun_chunks:
61
+ for i in range(chunk.start, chunk.end):
62
+ chunks[i] = chunk
63
+ chunks_index[i] = (chunk.start, chunk.end)
64
+
65
+ for token in doc:
66
+ if token.head.i == token.i:
67
+ head = token.head
68
+
69
+ if head.i not in chunks:
70
+ children = list(head.children)
71
+ if children and children[0].i in chunks:
72
+ head = children[0]
73
+ else:
74
+ if need_index:
75
+ return text, [], text
76
+ else:
77
+ return text
78
+
79
+ head_noun = head.text
80
+ head_index = chunks_index[head.i]
81
+ head_index = [i for i in range(head_index[0], head_index[1])]
82
+
83
+ sentence_index = [i for i in range(len(doc))]
84
+ not_phrase_index = []
85
+ for i in sentence_index:
86
+ # not_phrase_index.append(i) if i not in head_index else None
87
+ if i not in head_index:
88
+ not_phrase_index.append(i)
89
+
90
+ head = chunks[head.i]
91
+ if need_index:
92
+ return head.text, not_phrase_index, head_noun
93
+ else:
94
+ return head.text
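A brief usage sketch for the spaCy helpers. The example sentence is arbitrary; `en_core_web_sm` is used here only because it is lighter than the default `en_core_web_trf`, and either model must be downloaded separately with `python -m spacy download ...`.

```python
from utils.nlp import load_spacy_model, extract_noun_phrase

nlp = load_spacy_model('en_core_web_sm')
phrase, other_idx, head_noun = extract_noun_phrase(
    'the small dog on the left side of the image', nlp, need_index=True
)
print(phrase)     # e.g. 'the small dog'
print(head_noun)  # e.g. 'dog'
```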
utils/utils.py ADDED
@@ -0,0 +1,277 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Utility functions for the project."""
17
+
18
+ from __future__ import print_function
19
+ # pylint: disable=g-importing-member
20
+ from collections import defaultdict
21
+ from collections import deque
22
+ from copy import deepcopy
23
+ import datetime
24
+ import errno
25
+ import os
26
+ import sys
27
+ import time
28
+ import numpy as np
29
+ from PIL import Image
30
+ import torch
31
+ from torchvision import transforms
32
+ import yaml
33
+
34
+ # pylint: disable=g-bad-import-order
35
+ from data.voc import CLASS2ID
36
+ from data.voc import VOC_CLASSES
37
+
38
+
39
+ _MB = 1024.0 * 1024.0
40
+
41
+ DINO_transform = transforms.Compose([
42
+ transforms.ToTensor(),
43
+ transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
44
+ ])
45
+
46
+
47
+ class Config:
48
+
49
+ def __init__(self, **kwargs):
50
+ for key, value in kwargs.items():
51
+ if isinstance(value, dict):
52
+ setattr(self, key, Config(**value))
53
+ else:
54
+ setattr(self, key, value)
55
+
56
+
57
+ def load_yaml(filename):
58
+ with open(filename) as file:
59
+ try:
60
+ data = yaml.safe_load(file)
61
+ return data
62
+ except yaml.YAMLError as e:
63
+ print(f"Error while loading YAML file: {e}")
64
+
65
+
66
+ def normalize(x, dim=None, eps=1e-15):
67
+ if dim is None:
68
+ return (x - x.min()) / (x.max() - x.min())
69
+ # Normalize to [0, 1].
70
+ numerator = x - x.min(axis=dim, keepdims=True)[0]
71
+ denominator = (
72
+ x.max(axis=dim, keepdims=True)[0]
73
+ - x.min(axis=dim, keepdims=True)[0]
74
+ + eps
75
+ )
76
+ return numerator / denominator
77
+
78
+
79
+ class SmoothedValue(object):
80
+ """Track a series of values and provide access to smoothed values over a window or the global series average."""
81
+
82
+ def __init__(self, window_size=20, fmt=None):
83
+ if fmt is None:
84
+ fmt = "{median:.4f} ({global_avg:.4f})"
85
+ self.deque = deque(maxlen=window_size)
86
+ self.total = 0.0
87
+ self.count = 0
88
+ self.fmt = fmt
89
+
90
+ def update(self, value, n=1):
91
+ self.deque.append(value)
92
+ self.count += n
93
+ self.total += value * n
94
+
95
+ # def synchronize_between_processes(self):
96
+ # """
97
+ # Warning: does not synchronize the deque!
98
+ # """
99
+ # if not is_dist_avail_and_initialized():
100
+ # return
101
+ # t = torch.tensor([self.count, self.total],
102
+ # dtype=torch.float64, device='cuda')
103
+ # dist.barrier()
104
+ # dist.all_reduce(t)
105
+ # t = t.tolist()
106
+ # self.count = int(t[0])
107
+ # self.total = t[1]
108
+
109
+ @property
110
+ def median(self):
111
+ d = torch.tensor(list(self.deque))
112
+ return d.median().item()
113
+
114
+ @property
115
+ def avg(self):
116
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
117
+ return d.mean().item()
118
+
119
+ @property
120
+ def global_avg(self):
121
+ return self.total / self.count
122
+
123
+ @property
124
+ def max(self):
125
+ return max(self.deque)
126
+
127
+ @property
128
+ def value(self):
129
+ return self.deque[-1]
130
+
131
+ def __str__(self):
132
+ return self.fmt.format(
133
+ median=self.median,
134
+ avg=self.avg,
135
+ global_avg=self.global_avg,
136
+ max=self.max,
137
+ value=self.value,
138
+ )
139
+
140
+
141
+ class MetricLogger(object):
142
+ """Log the metrics."""
143
+
144
+ def __init__(self, delimiter="\t"):
145
+ self.meters = defaultdict(SmoothedValue)
146
+ self.delimiter = delimiter
147
+
148
+ def update(self, **kwargs):
149
+ for k, v in kwargs.items():
150
+ if isinstance(v, torch.Tensor):
151
+ v = v.item()
152
+ assert isinstance(v, (float, int))
153
+ self.meters[k].update(v)
154
+
155
+ def __getattr__(self, attr):
156
+ if attr in self.meters:
157
+ return self.meters[attr]
158
+ if attr in self.__dict__:
159
+ return self.__dict__[attr]
160
+ raise AttributeError(
161
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
162
+ )
163
+
164
+ def __str__(self):
165
+ loss_str = []
166
+ for name, meter in self.meters.items():
167
+ loss_str.append("{}: {}".format(name, str(meter)))
168
+ return self.delimiter.join(loss_str)
169
+
170
+ def synchronize_between_processes(self):
171
+ for meter in self.meters.values():
172
+ meter.synchronize_between_processes()
173
+
174
+ def add_meter(self, name, meter):
175
+ self.meters[name] = meter
176
+
177
+ def log_every(self, iterable, print_freq, header=None):
178
+ """Log every `print_freq` times."""
179
+ i = 0
180
+ if not header:
181
+ header = ""
182
+ start_time = time.time()
183
+ end = time.time()
184
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
185
+ data_time = SmoothedValue(fmt="{avg:.4f}")
186
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
187
+ log_msg = self.delimiter.join([
188
+ header,
189
+ "[{0" + space_fmt + "}/{1}]",
190
+ "eta: {eta}",
191
+ "{meters}",
192
+ "time: {time}",
193
+ "data: {data}",
194
+ "max mem: {memory:.0f}",
195
+ ])
196
+ for obj in iterable:
197
+ data_time.update(time.time() - end)
198
+ yield obj
199
+ iter_time.update(time.time() - end)
200
+ if i % print_freq == 0:
201
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
202
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
203
+ print(
204
+ log_msg.format(
205
+ i,
206
+ len(iterable),
207
+ eta=eta_string,
208
+ meters=str(self),
209
+ time=str(iter_time),
210
+ data=str(data_time),
211
+ memory=torch.cuda.max_memory_allocated() / _MB,
212
+ )
213
+ )
214
+ sys.stdout.flush()
215
+
216
+ i += 1
217
+ end = time.time()
218
+ total_time = time.time() - start_time
219
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
220
+ print("{} Total time: {}".format(header, total_time_str))
221
+
222
+
223
+ def mkdir(path):
224
+ try:
225
+ os.makedirs(path)
226
+ except OSError as e:
227
+ if e.errno != errno.EEXIST:
228
+ raise
229
+
230
+
231
+ def pad_to_square(im):
232
+ """Pad the images to square shape."""
233
+ im = deepcopy(im)
234
+ width, height = im.size
235
+ top_pad = (max(width, height) - height) // 2
236
+ bot_pad = max(width, height) - height - top_pad
237
+ left_pad = (max(width, height) - width) // 2
238
+ right_pad = max(width, height) - width - left_pad
239
+
240
+ if len(im.mode) == 3:
241
+ color = (0, 0, 0)
242
+ elif len(im.mode) == 1:
243
+ color = 0
244
+ else:
245
+ raise ValueError(f"Image mode not supported. Image has {im.mode} channels.")
246
+
247
+ return add_margin(im, top_pad, right_pad, bot_pad, left_pad, color=color)
248
+
249
+
250
+ def add_margin(pil_img, top, right, bottom, left, color=(0, 0, 0)):
251
+ """Ref: https://note.nkmk.me/en/python-pillow-add-margin-expand-canvas/."""
252
+ width, height = pil_img.size
253
+ new_width = width + right + left
254
+ new_height = height + top + bottom
255
+ result = Image.new(pil_img.mode, (new_width, new_height), color)
256
+ result.paste(pil_img, (left, top))
257
+
258
+ # 1 represents the image, 0 represents the padding
259
+ pad = [left, top, width, height]
260
+ return result, pad
261
+
262
+
263
+ def process_sentence(sentence, ds_name):
264
+ """Dataset specific sentence processing."""
265
+ if "refcoco" in ds_name:
266
+ sentence = sentence[0].lower()
267
+ # get rid of special characters
268
+ sentence = sentence.replace('"', "")
269
+ sentence = sentence.replace("/", "")
270
+ if ds_name == "voc":
271
+ if sentence in list(CLASS2ID.keys()):
272
+ label_id = CLASS2ID[sentence] - 1
273
+ sentence = VOC_CLASSES[label_id]
274
+
275
+ if not isinstance(sentence, str):
276
+ sentence = sentence[0]
277
+ return sentence
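To show how the pieces above fit together, here is a small sketch. The YAML contents and the `model.clip` key are hypothetical, while `configs/demo/pokemon.yaml` is the demo config path used by `app.py`.

```python
from PIL import Image
from utils.utils import Config, load_yaml, pad_to_square

# Any nested dict in the YAML becomes nested attributes on the Config object.
cfg_dict = load_yaml('configs/demo/pokemon.yaml')   # e.g. {'model': {'clip': 'ViT-B/16'}}
cfg = Config(**cfg_dict)
print(cfg.model.clip)                               # hypothetical key, for illustration

# pad_to_square returns the padded image plus [left, top, width, height]
# describing where the original image sits in the square canvas.
img = Image.new('RGB', (640, 480))
square_img, pad = pad_to_square(img)
print(square_img.size, pad)                         # (640, 640) [0, 80, 640, 480]
```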
utils/visualize.py ADDED
@@ -0,0 +1,107 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Visualization functions."""
17
+
18
+ import os
19
+
20
+ import cv2
21
+ import matplotlib.pyplot as plt
22
+ import numpy as np
23
+ from PIL import Image
24
+ import torch
25
+ # pylint: disable=g-importing-member
26
+ from utils.utils import normalize
27
+
28
+ _VIS_HEIGHT = 512
29
+ _VIS_WIDTH = 512
30
+
31
+
32
+ def show_cam_on_image(img, mask):
+ """Overlay a [0, 1] heat map on an RGB image and return a uint8 CAM visualization."""
33
+ if img.shape[1] != mask.shape[1]:
34
+ mask = cv2.resize(mask, (img.shape[1], img.shape[0]))
35
+ heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
36
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
37
+ heatmap = np.float32(heatmap) / 255
38
+ cam = heatmap + np.float32(img)
39
+ cam = cam / np.max(cam)
40
+ cam = np.uint8(255 * cam)
41
+ return cam
42
+
43
+
44
+ def save_img(array, img_name):
45
+ numpy_array = array.astype(np.uint8)
46
+ image = Image.fromarray(numpy_array, mode="RGB")
47
+ image.save(f"{img_name}.png")
48
+
49
+
50
+ def viz_attn(img, attn_map, prefix="vis_results/clipcam_img", img_name="cam"):
51
+ """Visualize attention map."""
52
+ num_masks = 1
53
+ if len(attn_map.shape) == 3:
54
+ num_masks = attn_map.shape[0]
55
+ attn_map = attn_map.float().squeeze(1).detach().cpu().numpy()
56
+ attn_map = normalize(attn_map)
57
+ img = normalize(img)
58
+ if num_masks == 1:
59
+ vis = show_cam_on_image(img, attn_map)
60
+ if not os.path.exists(prefix):
61
+ os.makedirs(prefix)
62
+ save_img(vis, os.path.join(prefix, f"{img_name}"))
63
+ return vis
64
+ for i in range(num_masks):
65
+ vis = show_cam_on_image(img, attn_map[i])
66
+ if not os.path.exists(prefix):
67
+ os.makedirs(prefix)
68
+ save_img(vis, os.path.join(prefix, f"{img_name}_{i}"))
69
+
70
+
71
+ def vis_mask(mask, gt_mask, img, output_dir, fname):
72
+ """Visualize mask."""
73
+ mask_img = torch.zeros((_VIS_WIDTH, _VIS_HEIGHT))
74
+ mask_img[mask[0]] = 1
75
+
76
+ # print(gt_mask.shape, img.size())
77
+ # Assume img and gt_mask are also torch.Tensor with size (512, 512)
78
+ img = img[0].permute(1, 2, 0).numpy()
79
+ gt_mask_img = torch.zeros((_VIS_WIDTH, _VIS_HEIGHT))
80
+ gt_mask_img[gt_mask[0]] = 1
81
+
82
+ _, axs = plt.subplots(
83
+ 1, 3, figsize=(15, 5)
84
+ ) # change the figsize if necessary
85
+
86
+ axs[0].imshow(img) # if image is grayscale, otherwise remove cmap argument
87
+ axs[0].axis("off")
88
+ axs[0].set_title("Original Image")
89
+
90
+ axs[1].imshow(
91
+ mask_img.numpy(), cmap="jet", alpha=0.5
92
+ ) # using alpha for transparency
93
+ axs[1].axis("off")
94
+ axs[1].set_title("Mask")
95
+
96
+ axs[2].imshow(
97
+ gt_mask_img.numpy(), cmap="jet", alpha=0.5
98
+ ) # using alpha for transparency
99
+ axs[2].axis("off")
100
+ axs[2].set_title("Ground Truth Mask")
101
+
102
+ plt.savefig(
103
+ os.path.join(output_dir, f"{fname}.jpg"),
104
+ bbox_inches="tight",
105
+ dpi=300,
106
+ pad_inches=0.0,
107
+ )
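Closing with a rough usage sketch for the CAM overlay helper. The image and attention maps below are random placeholders, and the output directory name is arbitrary.

```python
import numpy as np
import torch
from utils.visualize import viz_attn

# Placeholder inputs: a 224x224 RGB image in [0, 1] and three coarse attention maps.
img = np.random.rand(224, 224, 3).astype(np.float32)
attn = torch.rand(3, 14, 14)   # (num_masks, h, w); resized to the image size internally

# Saves one heat-map overlay per mask, e.g. vis_results/demo/cam_0.png ... cam_2.png.
viz_attn(img, attn, prefix='vis_results/demo', img_name='cam')
```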