Spaces:
Runtime error
Runtime error
Kevin Sun
commited on
Commit
·
6cd90b7
1
Parent(s):
8927fea
init commit
Browse files- CLIP_as_RNN +1 -0
- README.md +13 -12
- app.py +285 -0
- configs/ade_150.yaml +55 -0
- configs/ade_847.yaml +55 -0
- configs/coco.yaml +51 -0
- configs/gres.yaml +38 -0
- configs/pascal_context.yaml +60 -0
- configs/pascal_context_459.yaml +55 -0
- configs/refcoco+.yaml +34 -0
- configs/refcoco.yaml +34 -0
- configs/refcocog.yaml +37 -0
- configs/voc.yaml +63 -0
- data/__init__.py +15 -0
- data/ade.py +544 -0
- data/ade847.py +1827 -0
- data/coco.py +137 -0
- data/context.py +126 -0
- data/gres.py +455 -0
- data/pascal459.py +998 -0
- data/preprocess.py +110 -0
- data/refcoco.py +449 -0
- data/voc.py +148 -0
- demo.py +227 -0
- evaluate.py +511 -0
- modeling/__init__.py +15 -0
- modeling/model/cam.py +222 -0
- modeling/model/car.py +318 -0
- modeling/model/clip_wrapper.py +297 -0
- modeling/model/clipcam.py +255 -0
- modeling/model/crf.py +113 -0
- modeling/model/utils.py +245 -0
- modeling/model/utils_test.py +129 -0
- modeling/post_process/object_discovery.py +355 -0
- modeling/post_process/post_process.py +167 -0
- requirements.txt +4 -0
- sam/__init__.py +19 -0
- sam/sam.py +205 -0
- sam/utils.py +239 -0
- utils/__init__.py +15 -0
- utils/inference_pipeline.py +83 -0
- utils/merge_mask.py +57 -0
- utils/metrics.py +75 -0
- utils/nlp.py +94 -0
- utils/utils.py +277 -0
- utils/visualize.py +107 -0
CLIP_as_RNN
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Subproject commit 2457b49b339498af726408aa6673155de408c0f0
|
README.md
CHANGED
@@ -1,13 +1,14 @@
|
|
1 |
-
|
2 |
-
title: CLIP As RNN
|
3 |
-
emoji: 🏢
|
4 |
-
colorFrom: purple
|
5 |
-
colorTo: purple
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 4.29.0
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
license: apache-2.0
|
11 |
-
---
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CLIP as RNN: Segment Countless Visual Concepts without Training Endeavor (CaR)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
+
This repo holds the implementation code of the paper [CLIP as RNN: Segment Countless Visual Concepts without Training Endeavor (CaR)](https://arxiv.org/abs/2312.07661) by Shuyang Sun, Runjia Li, Philip Torr, Xiuye Gu, and Siyang Li:
|
4 |
+
|
5 |
+
```
|
6 |
+
@article{clip_as_rnn,
|
7 |
+
title={CLIP as RNN: Segment Countless Visual Concepts without Training Endeavor},
|
8 |
+
author={Shuyang Sun and Runjia Li and Philip Torr and Xiuye Gu and Siyang Li},
|
9 |
+
year={2023},
|
10 |
+
eprint={2312.07661},
|
11 |
+
archivePrefix={arXiv},
|
12 |
+
primaryClass={cs.CV}
|
13 |
+
}
|
14 |
+
```
|
app.py
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Run a Gradio demo of the CaR model on a single image."""
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import argparse
|
5 |
+
from functools import reduce
|
6 |
+
import PIL.Image as Image
|
7 |
+
import torch
|
8 |
+
from modeling.model import CaR
|
9 |
+
from utils.utils import Config, load_yaml
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import colorsys
|
12 |
+
from modeling.post_process.post_process import match_masks, generate_masks_from_sam
|
13 |
+
from sam.sam import SAMPipeline
|
14 |
+
from sam.utils import build_sam_config
|
15 |
+
import random
|
16 |
+
import gradio as gr
|
17 |
+
|
18 |
+
# set random seed
|
19 |
+
random.seed(15)
|
20 |
+
np.random.seed(0)
|
21 |
+
torch.manual_seed(0)
|
22 |
+
|
23 |
+
|
24 |
+
CFG_PATH = "configs/demo/pokemon.yaml"
|
25 |
+
|
26 |
+
def generate_distinct_colors(n):
|
27 |
+
colors = []
|
28 |
+
# generate a random number from 0 to 1
|
29 |
+
random_color_bias = random.random()
|
30 |
+
for i in range(n):
|
31 |
+
hue = float(i) / n
|
32 |
+
hue += random_color_bias
|
33 |
+
hue = hue % 1.0
|
34 |
+
rgb = colorsys.hsv_to_rgb(hue, 1.0, 1.0)
|
35 |
+
# Convert RGB values from [0, 1] range to [0, 255]
|
36 |
+
colors.append(tuple(int(val * 255) for val in rgb))
|
37 |
+
return colors
|
38 |
+
|
39 |
+
|
40 |
+
def overlap_masks(masks):
|
41 |
+
"""
|
42 |
+
Overlap masks to generate a single mask for visualization.
|
43 |
+
|
44 |
+
Parameters:
|
45 |
+
- masks: list of np.arrays of shape (H, W) representing binary masks for each class
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
- overlap_mask: list of np.array of shape (H, W) that have no overlaps
|
49 |
+
"""
|
50 |
+
overlap_mask = torch.zeros_like(masks[0])
|
51 |
+
for mask_idx, mask in enumerate(masks):
|
52 |
+
overlap_mask[mask > 0] = mask_idx + 1
|
53 |
+
|
54 |
+
clean_masks = [overlap_mask == mask_idx +
|
55 |
+
1 for mask_idx in range(len(masks))]
|
56 |
+
clean_masks = torch.stack(clean_masks, dim=0)
|
57 |
+
|
58 |
+
return clean_masks
|
59 |
+
|
60 |
+
|
61 |
+
def visualize_segmentation(image,
|
62 |
+
masks,
|
63 |
+
class_names,
|
64 |
+
alpha=0.7,
|
65 |
+
y_list=None,
|
66 |
+
x_list=None):
|
67 |
+
"""
|
68 |
+
Visualize segmentation masks on an image.
|
69 |
+
|
70 |
+
Parameters:
|
71 |
+
- image: np.array of shape (H, W, 3) representing the RGB image
|
72 |
+
- masks: list of np.arrays of shape (H, W) representing binary masks for each class
|
73 |
+
- class_names: list of strings representing names of each class
|
74 |
+
- alpha: float, transparency level of masks on the image
|
75 |
+
|
76 |
+
Returns:
|
77 |
+
- visualization: plt.figure object
|
78 |
+
"""
|
79 |
+
# Create a figure and axis
|
80 |
+
fig, ax = plt.subplots(1, figsize=(12, 9))
|
81 |
+
# Display the image
|
82 |
+
# ax.imshow(image)
|
83 |
+
# Generate distinct colors for each mask
|
84 |
+
final_mask = np.zeros(
|
85 |
+
(masks.shape[1], masks.shape[2], 3), dtype=np.float32)
|
86 |
+
binary_final_mask = np.zeros(
|
87 |
+
(masks.shape[1], masks.shape[2]), dtype=np.float32)
|
88 |
+
colors = generate_distinct_colors(len(class_names))
|
89 |
+
idx = 0
|
90 |
+
for mask, color, class_name in zip(masks, colors, class_names):
|
91 |
+
# Overlay the mask
|
92 |
+
final_mask += np.dstack([mask * c for c in color])
|
93 |
+
binary_final_mask += mask
|
94 |
+
# Find a representative point (e.g., centroid) for placing the label
|
95 |
+
if y_list is None or x_list is None:
|
96 |
+
y, x = np.argwhere(mask).mean(axis=0)
|
97 |
+
else:
|
98 |
+
y, x = y_list[idx], x_list[idx]
|
99 |
+
ax.text(x, y, class_name, color='white',
|
100 |
+
fontsize=22, va='center', ha='center',
|
101 |
+
bbox=dict(facecolor='black', alpha=0.7, edgecolor='none'))
|
102 |
+
idx += 1
|
103 |
+
|
104 |
+
image[binary_final_mask > 0] = image[binary_final_mask > 0] * (1 - alpha)
|
105 |
+
final_image = image + final_mask * alpha
|
106 |
+
final_image = final_image.astype(np.uint8)
|
107 |
+
ax.imshow(final_image)
|
108 |
+
# Remove axis ticks and labels
|
109 |
+
ax.axis('off')
|
110 |
+
return fig
|
111 |
+
|
112 |
+
|
113 |
+
def get_sam_masks(cfg,
|
114 |
+
masks,
|
115 |
+
image_path=None,
|
116 |
+
img_sam=None,
|
117 |
+
pipeline=None):
|
118 |
+
# image_id = image_path.split('/')[-1].split('.')[0]
|
119 |
+
# sam_mask_path = os.path.join(cfg.test.sam_mask_root, f'{image_id}.npz')
|
120 |
+
# if os.path.exists(sam_mask_path):
|
121 |
+
# sam_mask_masks = np.load(sam_mask_path, allow_pickle=True)
|
122 |
+
# mask_tensor = torch.from_numpy(sam_mask_masks['mask_tensor'])
|
123 |
+
# mask_list = sam_mask_path['mask_list']
|
124 |
+
# else:
|
125 |
+
print("generating sam masks online")
|
126 |
+
if img_sam is None and image_path is not None:
|
127 |
+
raise ValueError(
|
128 |
+
'Please provide either the image path or the image numpy array.')
|
129 |
+
|
130 |
+
mask_tensor, mask_list = generate_masks_from_sam(
|
131 |
+
image_path,
|
132 |
+
save_path='./',
|
133 |
+
pipeline=pipeline,
|
134 |
+
img_sam=img_sam,
|
135 |
+
visualize=False,
|
136 |
+
)
|
137 |
+
mask_tensor = mask_tensor.to(masks.device)
|
138 |
+
# only conduct sam on masks that is not all zero
|
139 |
+
attn_map, mask_ids = [], []
|
140 |
+
for mask_id, mask in enumerate(masks):
|
141 |
+
if torch.sum(mask) > 0:
|
142 |
+
attn_map.append(mask.unsqueeze(0))
|
143 |
+
mask_ids.append(mask_id)
|
144 |
+
matched_masks = [match_masks(
|
145 |
+
mask_tensor,
|
146 |
+
attn,
|
147 |
+
mask_list,
|
148 |
+
iom_thres=cfg.car.iom_thres,
|
149 |
+
min_pred_threshold=cfg.sam.min_pred_threshold)
|
150 |
+
for attn in attn_map]
|
151 |
+
for matched_mask, mask_id in zip(matched_masks, mask_ids):
|
152 |
+
sam_masks = np.array([item['segmentation'] for item in matched_mask])
|
153 |
+
sam_mask = np.any(sam_masks, axis=0)
|
154 |
+
masks[mask_id] = torch.from_numpy(sam_mask).to(masks.device)
|
155 |
+
return masks
|
156 |
+
|
157 |
+
|
158 |
+
def load_sam(cfg, device):
|
159 |
+
sam_checkpoint, model_type = build_sam_config(cfg)
|
160 |
+
pipeline = SAMPipeline(
|
161 |
+
sam_checkpoint,
|
162 |
+
model_type,
|
163 |
+
device=device,
|
164 |
+
points_per_side=cfg.sam.points_per_side,
|
165 |
+
pred_iou_thresh=cfg.sam.pred_iou_thresh,
|
166 |
+
stability_score_thresh=cfg.sam.stability_score_thresh,
|
167 |
+
box_nms_thresh=cfg.sam.box_nms_thresh,
|
168 |
+
)
|
169 |
+
return pipeline
|
170 |
+
|
171 |
+
def generate(img,
|
172 |
+
class_names,
|
173 |
+
clip_thresh,
|
174 |
+
mask_thresh,
|
175 |
+
confidence_thresh,
|
176 |
+
post_process,
|
177 |
+
stability_score_thresh,
|
178 |
+
box_nms_thresh,
|
179 |
+
iom_thres,
|
180 |
+
min_pred_threshold):
|
181 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
182 |
+
cfg = Config(**load_yaml(CFG_PATH))
|
183 |
+
cfg.car.clipes_threshold = clip_thresh
|
184 |
+
cfg.car.mask_threshold = mask_thresh
|
185 |
+
cfg.car.confidence_threshold = confidence_thresh
|
186 |
+
cfg.sam.stability_score_thresh = stability_score_thresh
|
187 |
+
cfg.sam.box_nms_thresh = box_nms_thresh
|
188 |
+
cfg.car.iom_thres = iom_thres
|
189 |
+
cfg.sam.min_pred_threshold = min_pred_threshold
|
190 |
+
car_model = CaR(cfg,
|
191 |
+
visualize=True,
|
192 |
+
seg_mode='semantic',
|
193 |
+
device=device)
|
194 |
+
|
195 |
+
|
196 |
+
# resize image by dividing 2 if the size is larger than 1000
|
197 |
+
if img.size[0] > 1000:
|
198 |
+
img = img.resize((img.size[0] // 2, img.size[1] // 2))
|
199 |
+
|
200 |
+
y_list, x_list = None, None
|
201 |
+
class_names = class_names.split(',')
|
202 |
+
sentences = class_names
|
203 |
+
|
204 |
+
# class_names = ['the women chatting', 'the women chatting', 'table', 'fridge', 'cooking pot']
|
205 |
+
|
206 |
+
pseudo_masks, _, _ = car_model(
|
207 |
+
img, sentences, 1)
|
208 |
+
|
209 |
+
if post_process == 'SAM':
|
210 |
+
pipeline = load_sam(cfg, device)
|
211 |
+
pseudo_masks = get_sam_masks(
|
212 |
+
cfg,
|
213 |
+
pseudo_masks,
|
214 |
+
image_path=None,
|
215 |
+
img_sam=np.array(img),
|
216 |
+
pipeline=pipeline)
|
217 |
+
pseudo_masks = overlap_masks(pseudo_masks)
|
218 |
+
|
219 |
+
# visualize segmentation masks
|
220 |
+
demo_fig = visualize_segmentation(np.array(img),
|
221 |
+
pseudo_masks.detach().cpu().numpy(),
|
222 |
+
class_names,
|
223 |
+
y_list=y_list,
|
224 |
+
x_list=x_list)
|
225 |
+
|
226 |
+
# convert the demo figure to an pil image
|
227 |
+
demo_fig.canvas.draw()
|
228 |
+
demo_img = np.array(demo_fig.canvas.renderer._renderer)
|
229 |
+
demo_img = Image.fromarray(demo_img)
|
230 |
+
return demo_img
|
231 |
+
|
232 |
+
|
233 |
+
|
234 |
+
if __name__ == "__main__":
|
235 |
+
parser = argparse.ArgumentParser('car')
|
236 |
+
parser.add_argument("--cfg-path",
|
237 |
+
default='configs/local_car.yaml',
|
238 |
+
help="path to configuration file.")
|
239 |
+
args = parser.parse_args()
|
240 |
+
|
241 |
+
demo = gr.Interface(generate,
|
242 |
+
inputs=[gr.Image(label="upload an image", type="pil"),
|
243 |
+
"text",
|
244 |
+
gr.Slider(label="clip thresh",
|
245 |
+
minimum=0,
|
246 |
+
maximum=1,
|
247 |
+
value=0.4,
|
248 |
+
step=0.1,
|
249 |
+
info="the threshold for clip-es adversarial heatmap clipping"),
|
250 |
+
gr.Slider(label="mask thresh",
|
251 |
+
minimum=0,
|
252 |
+
maximum=1,
|
253 |
+
value=0.6,
|
254 |
+
step=0.1,
|
255 |
+
info="the binariation threshold for the mask to generate visual prompt"),
|
256 |
+
gr.Slider(label="confidence thresh",
|
257 |
+
minimum=0,
|
258 |
+
maximum=1,
|
259 |
+
value=0,
|
260 |
+
step=0.1,
|
261 |
+
info="the threshold for filtering the proposed classes"),
|
262 |
+
gr.Radio(["CRF", "SAM"], label="post process", value="CRF", info="choose the post process method"),
|
263 |
+
gr.Slider(label="stability score thresh for SAM mask proposal \n(only when SAM is chosen for post process)",
|
264 |
+
minimum=0,
|
265 |
+
maximum=1,
|
266 |
+
value=0.95,
|
267 |
+
step=0.1),
|
268 |
+
gr.Slider(label="box nms thresh for SAM mask proposal \n(only when SAM is chosen for post process)", minimum=0, maximum=1, value=0.7, step=0.1),
|
269 |
+
gr.Slider(label="intersection over mask threshold for SAM mask proposal \n(only when SAM is chosen for post process)", minimum=0, maximum=1, value=0.5, step=0.1),
|
270 |
+
gr.Slider(label="minimum prediction threshold for SAM mask proposal \n(only when SAM is chosen for post process)", minimum=0, maximum=1, value=0.03, step=0.01)],
|
271 |
+
outputs="image",
|
272 |
+
title="CLIP as RNN: Segment Countless Visual Concepts without Training Endeavor",
|
273 |
+
description="This is the official demo for CLIP as RNN. Please upload an image and type in the class names (connected by ',' e.g. cat,dog,human) you want to segment. The model will generate the segmentation masks for the input image. You can also adjust the clip thresh, mask thresh and confidence thresh to get better results.",
|
274 |
+
examples=[["demo/pokemon1.jpg", "Charmander,Bulbasaur,Squirtle", 0.6, 0.6, 0, "SAM", 0.95, 0.7, 0.6, 0.01],
|
275 |
+
["demo/batman.jpg", "Batman,Joker,Cat Woman", 0.6, 0.6, 0, "SAM", 0.95, 0.7, 0.6, 0.01],
|
276 |
+
["demo/avengers1.jpg", "Thor,Captain America,Hulk,Iron Man", 0.6, 0.6, 0, "SAM", 0.89, 0.65, 0.5, 0.03],
|
277 |
+
|
278 |
+
])
|
279 |
+
demo.launch(share=True)
|
280 |
+
|
281 |
+
|
282 |
+
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
283 |
+
|
284 |
+
|
285 |
+
stop = 0
|
configs/ade_150.yaml
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
clip:
|
2 |
+
semantic_clip_model_name: 'ViT-L/14'
|
3 |
+
semantic_pretrained_data: 'openai'
|
4 |
+
clip_model_name: "ViT-B/16"
|
5 |
+
pretrained_data: 'openai'
|
6 |
+
|
7 |
+
car:
|
8 |
+
iom_thres: 0.6
|
9 |
+
mask_threshold: 0.6
|
10 |
+
min_area_ratio: 0.2
|
11 |
+
num_iteration: 1
|
12 |
+
confidence_threshold: 0.25
|
13 |
+
clipes_threshold: 0.7
|
14 |
+
bg_factor: 1
|
15 |
+
stuff_bg_factor: 1
|
16 |
+
visual_prompt_type: ['gray', 'blur']
|
17 |
+
stuff_visual_prompt_type: ['gray', 'blur']
|
18 |
+
semantic_templates: ['a clean origami {}.',
|
19 |
+
'a photo of a {}.',
|
20 |
+
'This is a photo of a {}',
|
21 |
+
'There is a {} in the scene',
|
22 |
+
'There is the {} in the scene',
|
23 |
+
'a photo of a {} in the scene',
|
24 |
+
'a photo of a small {}.',
|
25 |
+
'a photo of a medium {}.',
|
26 |
+
'a photo of a large {}.',
|
27 |
+
'This is a photo of a small {}.',
|
28 |
+
'This is a photo of a medium {}.',
|
29 |
+
'This is a photo of a large {}.',
|
30 |
+
'There is a small {} in the scene.',
|
31 |
+
'There is a medium {} in the scene.',
|
32 |
+
'There is a large {} in the scene.']
|
33 |
+
|
34 |
+
bg_cls: ['ground', 'land', 'grass', 'tree', 'building',
|
35 |
+
'wall', 'sky', 'lake', 'water', 'river', 'sea',
|
36 |
+
'railway', 'railroad', 'helmet', 'cloud', 'house',
|
37 |
+
'mountain', 'ocean', 'road', 'rock', 'street',
|
38 |
+
'valley', 'bridge']
|
39 |
+
|
40 |
+
test:
|
41 |
+
algo: "car"
|
42 |
+
ds_name: "ade"
|
43 |
+
seg_mode: "semantic"
|
44 |
+
split: 'validation'
|
45 |
+
data_root: "$YOUR_ADE_DATA_DIR"
|
46 |
+
# You need to extract the sam mask for the ADE dataset if use_pseudo=False
|
47 |
+
sam_mask_root: "$YOUR_SAM_MASK_DIR"
|
48 |
+
output_path: "./outputs/"
|
49 |
+
use_pseudo: True
|
50 |
+
n_class: 151
|
51 |
+
num_chunks: 1
|
52 |
+
chunk_index: 0
|
53 |
+
ignore_background: True
|
54 |
+
|
55 |
+
save_path: "./outputs"
|
configs/ade_847.yaml
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
clip:
|
2 |
+
semantic_clip_model_name: 'ViT-L/14'
|
3 |
+
semantic_pretrained_data: 'openai'
|
4 |
+
clip_model_name: "ViT-B/16"
|
5 |
+
pretrained_data: 'openai'
|
6 |
+
|
7 |
+
car:
|
8 |
+
iom_thres: 0.6
|
9 |
+
mask_threshold: 0.6
|
10 |
+
min_area_ratio: 0.2
|
11 |
+
num_iteration: 1
|
12 |
+
confidence_threshold: 0.25
|
13 |
+
clipes_threshold: 0.7
|
14 |
+
bg_factor: 1
|
15 |
+
stuff_bg_factor: 1
|
16 |
+
visual_prompt_type: ['gray', 'blur']
|
17 |
+
stuff_visual_prompt_type: ['gray', 'blur']
|
18 |
+
semantic_templates: ['a clean origami {}.',
|
19 |
+
'a photo of a {}.',
|
20 |
+
'This is a photo of a {}',
|
21 |
+
'There is a {} in the scene',
|
22 |
+
'There is the {} in the scene',
|
23 |
+
'a photo of a {} in the scene',
|
24 |
+
'a photo of a small {}.',
|
25 |
+
'a photo of a medium {}.',
|
26 |
+
'a photo of a large {}.',
|
27 |
+
'This is a photo of a small {}.',
|
28 |
+
'This is a photo of a medium {}.',
|
29 |
+
'This is a photo of a large {}.',
|
30 |
+
'There is a small {} in the scene.',
|
31 |
+
'There is a medium {} in the scene.',
|
32 |
+
'There is a large {} in the scene.']
|
33 |
+
|
34 |
+
bg_cls: ['ground', 'land', 'grass', 'tree', 'building',
|
35 |
+
'wall', 'sky', 'lake', 'water', 'river', 'sea',
|
36 |
+
'railway', 'railroad', 'helmet', 'cloud', 'house',
|
37 |
+
'mountain', 'ocean', 'road', 'rock', 'street',
|
38 |
+
'valley', 'bridge']
|
39 |
+
|
40 |
+
test:
|
41 |
+
algo: "car"
|
42 |
+
ds_name: "ade_847"
|
43 |
+
seg_mode: "semantic"
|
44 |
+
split: 'validation'
|
45 |
+
data_root: "$YOUR_ADE_DATA_DIR"
|
46 |
+
# You need to extract the sam mask for the ADE dataset if use_pseudo=False
|
47 |
+
sam_mask_root: "$YOUR_SAM_MASK_DIR"
|
48 |
+
output_path: "./outputs/"
|
49 |
+
use_pseudo: True
|
50 |
+
n_class: 847
|
51 |
+
num_chunks: 1
|
52 |
+
chunk_index: 0
|
53 |
+
ignore_background: True
|
54 |
+
|
55 |
+
save_path: "./outputs"
|
configs/coco.yaml
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
clip:
|
2 |
+
semantic_clip_model_name: 'ViT-L/14'
|
3 |
+
semantic_pretrained_data: 'openai'
|
4 |
+
clip_model_name: "ViT-B/16"
|
5 |
+
pretrained_data: 'openai'
|
6 |
+
|
7 |
+
|
8 |
+
car:
|
9 |
+
iom_thres: 0.7
|
10 |
+
mask_threshold: 0.5
|
11 |
+
min_area_ratio: 0.2
|
12 |
+
num_iteration: 1
|
13 |
+
confidence_threshold: 0.3
|
14 |
+
clipes_threshold: 0.5
|
15 |
+
visual_prompt_type: ['blur', 'gray']
|
16 |
+
semantic_templates: ['a clean origami {}.',
|
17 |
+
'a photo of a {}.',
|
18 |
+
'This is a photo of a {}',
|
19 |
+
'There is a {} in the scene',
|
20 |
+
'There is the {} in the scene',
|
21 |
+
'a photo of a {} in the scene',
|
22 |
+
'a photo of a small {}.',
|
23 |
+
'a photo of a medium {}.',
|
24 |
+
'a photo of a large {}.',
|
25 |
+
'This is a photo of a small {}.',
|
26 |
+
'This is a photo of a medium {}.',
|
27 |
+
'This is a photo of a large {}.',
|
28 |
+
'There is a small {} in the scene.',
|
29 |
+
'There is a medium {} in the scene.',
|
30 |
+
'There is a large {} in the scene.']
|
31 |
+
bg_cls: ['ground', 'land', 'grass', 'tree', 'building',
|
32 |
+
'wall', 'sky', 'lake', 'water', 'river', 'sea',
|
33 |
+
'railway', 'railroad', 'helmet', 'cloud', 'house',
|
34 |
+
'mountain', 'ocean', 'road', 'rock', 'street',
|
35 |
+
'valley', 'bridge']
|
36 |
+
|
37 |
+
test:
|
38 |
+
algo: "car"
|
39 |
+
ds_name: "coco"
|
40 |
+
seg_mode: "semantic"
|
41 |
+
data_root: "$YOUR_DATA_DIR"
|
42 |
+
# You need to extract the sam mask for the ADE dataset if use_pseudo=False
|
43 |
+
sam_mask_root: "$YOUR_SAM_MASK_DIR"
|
44 |
+
output_path: "./outputs/"
|
45 |
+
use_pseudo: True
|
46 |
+
split: "val"
|
47 |
+
n_class: 81
|
48 |
+
num_chunks: 1
|
49 |
+
chunk_index: 0
|
50 |
+
|
51 |
+
save_path: "./outputs"
|
configs/gres.yaml
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
clip:
|
2 |
+
semantic_clip_model_name: 'ViT-L/14'
|
3 |
+
semantic_pretrained_data: 'openai'
|
4 |
+
clip_model_name: "ViT-B/16"
|
5 |
+
pretrained_data: 'openai'
|
6 |
+
|
7 |
+
car:
|
8 |
+
iom_thres: 0.5
|
9 |
+
mask_threshold: 0.5
|
10 |
+
confidence_threshold: 0
|
11 |
+
clipes_threshold: 0.3
|
12 |
+
cam_text_template: 'a clean origami {}.'
|
13 |
+
color: [255, 0, 0] # red
|
14 |
+
visual_prompt_type: ['circle']
|
15 |
+
bg_cls: ['ground', 'land', 'grass', 'tree', 'building',
|
16 |
+
'wall', 'sky', 'lake', 'water', 'river', 'sea',
|
17 |
+
'railway', 'railroad', 'helmet', 'cloud', 'house',
|
18 |
+
'mountain', 'ocean', 'road', 'rock', 'street',
|
19 |
+
'valley', 'bridge']
|
20 |
+
|
21 |
+
|
22 |
+
test:
|
23 |
+
algo: "car"
|
24 |
+
ds_name: "gres"
|
25 |
+
split: 'val'
|
26 |
+
seg_mode: "refer"
|
27 |
+
data_root: "$YOUR_ADE_DATA_DIR"
|
28 |
+
output_path: "./outputs/"
|
29 |
+
prompts_augment: False
|
30 |
+
use_pseudo: True
|
31 |
+
use_background: False
|
32 |
+
prompts_prefix: False
|
33 |
+
prompts_augment: False
|
34 |
+
|
35 |
+
sentence_process:
|
36 |
+
mixing_alpha: 0.
|
37 |
+
|
38 |
+
save_path: "./outputs"
|
configs/pascal_context.yaml
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
clip:
|
2 |
+
semantic_clip_model_name: 'ViT-L/14'
|
3 |
+
semantic_pretrained_data: 'openai'
|
4 |
+
clip_model_name: "ViT-B/16"
|
5 |
+
pretrained_data: 'openai'
|
6 |
+
|
7 |
+
|
8 |
+
car:
|
9 |
+
iom_thres: 0.5
|
10 |
+
mask_threshold: 0.6
|
11 |
+
stuff_mask_threshold: 0.6
|
12 |
+
min_area_ratio: 0.2
|
13 |
+
num_iteration: 1
|
14 |
+
confidence_threshold: 0.25
|
15 |
+
clipes_threshold: 0.4
|
16 |
+
bg_factor: 1
|
17 |
+
stuff_bg_factor: 1
|
18 |
+
has_pamr: False
|
19 |
+
visual_prompt_type: ['blur', 'circle']
|
20 |
+
stuff_visual_prompt_type: ['blur', 'gray']
|
21 |
+
semantic_templates: ['a clean origami {}.',
|
22 |
+
'a photo of a {}.',
|
23 |
+
'This is a photo of a {}',
|
24 |
+
'There is a {} in the scene',
|
25 |
+
'There is the {} in the scene',
|
26 |
+
'a photo of a {} in the scene',
|
27 |
+
'a photo of a small {}.',
|
28 |
+
'a photo of a medium {}.',
|
29 |
+
'a photo of a large {}.',
|
30 |
+
'This is a photo of a small {}.',
|
31 |
+
'This is a photo of a medium {}.',
|
32 |
+
'This is a photo of a large {}.',
|
33 |
+
'There is a small {} in the scene.',
|
34 |
+
'There is a medium {} in the scene.',
|
35 |
+
'There is a large {} in the scene.']
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
bg_cls: ['ground', 'land', 'grass', 'tree', 'building',
|
40 |
+
'wall', 'sky', 'lake', 'water', 'river', 'sea',
|
41 |
+
'railway', 'railroad', 'helmet', 'cloud', 'house',
|
42 |
+
'mountain', 'ocean', 'road', 'rock', 'street',
|
43 |
+
'valley', 'bridge']
|
44 |
+
|
45 |
+
|
46 |
+
test:
|
47 |
+
algo: "car"
|
48 |
+
ds_name: "context"
|
49 |
+
seg_mode: "semantic"
|
50 |
+
n_class: 60
|
51 |
+
data_root: "$YOUR_DATA_DIR"
|
52 |
+
output_path: "./outputs/"
|
53 |
+
use_pseudo: True
|
54 |
+
split: "val"
|
55 |
+
num_chunks: 1
|
56 |
+
chunk_index: 0
|
57 |
+
ignore_background: False
|
58 |
+
|
59 |
+
|
60 |
+
save_path: "./outputs"
|
configs/pascal_context_459.yaml
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
clip:
|
2 |
+
semantic_clip_model_name: 'ViT-L/14'
|
3 |
+
semantic_pretrained_data: 'openai'
|
4 |
+
clip_model_name: "ViT-B/16"
|
5 |
+
pretrained_data: 'openai'
|
6 |
+
|
7 |
+
car:
|
8 |
+
iom_thres: 0.6
|
9 |
+
mask_threshold: 0.4
|
10 |
+
min_area_ratio: 0.2
|
11 |
+
num_iteration: 1
|
12 |
+
confidence_threshold: 0.25 # 0.2
|
13 |
+
clipes_threshold: 0.7
|
14 |
+
bg_factor: 1
|
15 |
+
stuff_bg_factor: 1
|
16 |
+
visual_prompt_type: ['gray', 'blur']
|
17 |
+
stuff_visual_prompt_type: ['gray', 'blur']
|
18 |
+
semantic_templates: ['a clean origami {}.',
|
19 |
+
'a photo of a {}.',
|
20 |
+
'This is a photo of a {}',
|
21 |
+
'There is a {} in the scene',
|
22 |
+
'There is the {} in the scene',
|
23 |
+
'a photo of a {} in the scene',
|
24 |
+
'a photo of a small {}.',
|
25 |
+
'a photo of a medium {}.',
|
26 |
+
'a photo of a large {}.',
|
27 |
+
'This is a photo of a small {}.',
|
28 |
+
'This is a photo of a medium {}.',
|
29 |
+
'This is a photo of a large {}.',
|
30 |
+
'There is a small {} in the scene.',
|
31 |
+
'There is a medium {} in the scene.',
|
32 |
+
'There is a large {} in the scene.']
|
33 |
+
|
34 |
+
bg_cls: ['ground', 'land', 'grass', 'tree', 'building',
|
35 |
+
'wall', 'sky', 'lake', 'water', 'river', 'sea',
|
36 |
+
'railway', 'railroad', 'helmet', 'cloud', 'house',
|
37 |
+
'mountain', 'ocean', 'road', 'rock', 'street',
|
38 |
+
'valley', 'bridge']
|
39 |
+
|
40 |
+
test:
|
41 |
+
algo: "car"
|
42 |
+
ds_name: "pascal_459"
|
43 |
+
seg_mode: "semantic"
|
44 |
+
split: 'validation'
|
45 |
+
data_root: "$YOUR_DATA_DIR"
|
46 |
+
# You need to extract the sam mask for the ADE dataset if use_pseudo=False
|
47 |
+
sam_mask_root: "$YOUR_SAM_MASK_DIR"
|
48 |
+
output_path: "./outputs/"
|
49 |
+
use_pseudo: True
|
50 |
+
n_class: 460
|
51 |
+
num_chunks: 1
|
52 |
+
chunk_index: 0
|
53 |
+
ignore_background: True
|
54 |
+
|
55 |
+
save_path: "./outputs"
|
configs/refcoco+.yaml
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
clip:
|
2 |
+
semantic_clip_model_name: 'ViT-B/16'
|
3 |
+
semantic_pretrained_data: 'openai'
|
4 |
+
clip_model_name: "ViT-B/16"
|
5 |
+
pretrained_data: 'openai'
|
6 |
+
|
7 |
+
car:
|
8 |
+
iom_thres: 0.5
|
9 |
+
mask_threshold: 0.2
|
10 |
+
confidence_threshold: 0.1
|
11 |
+
clipes_threshold: 0.5 # refcocog: 0.6
|
12 |
+
color: [255, 0, 0] # red
|
13 |
+
visual_prompt_type: ['circle', 'blur']
|
14 |
+
min_area_ratio: 0.2
|
15 |
+
bg_cls: ['ground', 'land', 'grass', 'tree', 'building',
|
16 |
+
'wall', 'sky', 'lake', 'water', 'river', 'sea',
|
17 |
+
'railway', 'railroad', 'helmet', 'cloud', 'house',
|
18 |
+
'mountain', 'ocean', 'road', 'rock', 'street',
|
19 |
+
'valley', 'bridge']
|
20 |
+
|
21 |
+
test:
|
22 |
+
algo: "car"
|
23 |
+
ds_name: "refcoco+"
|
24 |
+
seg_mode: "refer"
|
25 |
+
split: 'val'
|
26 |
+
data_root: "$YOUR_DATA_DIR"
|
27 |
+
output_path: "./outputs/"
|
28 |
+
prompts_augment: False
|
29 |
+
use_pseudo: True
|
30 |
+
|
31 |
+
sentence_process:
|
32 |
+
mixing_alpha: 0.
|
33 |
+
|
34 |
+
save_path: "./outputs"
|
configs/refcoco.yaml
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
clip:
|
2 |
+
semantic_clip_model_name: 'ViT-B/16'
|
3 |
+
semantic_pretrained_data: 'openai'
|
4 |
+
clip_model_name: "ViT-B/16"
|
5 |
+
pretrained_data: 'openai'
|
6 |
+
|
7 |
+
car:
|
8 |
+
iom_thres: 0.5
|
9 |
+
mask_threshold: 0.5
|
10 |
+
confidence_threshold: 0.3
|
11 |
+
clipes_threshold: 0.5
|
12 |
+
color: [255, 0, 0] # red
|
13 |
+
visual_prompt_type: ['circle']
|
14 |
+
min_area_ratio: 0.2
|
15 |
+
bg_cls: ['ground', 'land', 'grass', 'tree', 'building',
|
16 |
+
'wall', 'sky', 'lake', 'water', 'river', 'sea',
|
17 |
+
'railway', 'railroad', 'helmet', 'cloud', 'house',
|
18 |
+
'mountain', 'ocean', 'road', 'rock', 'street',
|
19 |
+
'valley', 'bridge']
|
20 |
+
|
21 |
+
test:
|
22 |
+
algo: "car"
|
23 |
+
ds_name: "refcoco"
|
24 |
+
seg_mode: "refer"
|
25 |
+
split: 'val'
|
26 |
+
data_root: "$YOUR_DATA_DIR"
|
27 |
+
output_path: "./outputs/"
|
28 |
+
prompts_augment: False
|
29 |
+
use_pseudo: True
|
30 |
+
|
31 |
+
sentence_process:
|
32 |
+
mixing_alpha: 0.
|
33 |
+
|
34 |
+
save_path: "./outputs"
|
configs/refcocog.yaml
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
clip:
|
2 |
+
semantic_clip_model_name: 'ViT-B/16'
|
3 |
+
semantic_pretrained_data: 'openai'
|
4 |
+
clip_model_name: "ViT-B/16"
|
5 |
+
pretrained_data: 'openai'
|
6 |
+
|
7 |
+
car:
|
8 |
+
iom_thres: 0.5
|
9 |
+
mask_threshold: 0.5
|
10 |
+
confidence_threshold: 0.1
|
11 |
+
clipes_threshold: 0.6
|
12 |
+
color: [255, 0, 0] # red
|
13 |
+
visual_prompt_type: ['circle', 'blur']
|
14 |
+
min_area_ratio: 0.2
|
15 |
+
bg_cls: ['ground', 'land', 'grass', 'tree', 'building',
|
16 |
+
'wall', 'sky', 'lake', 'water', 'river', 'sea',
|
17 |
+
'railway', 'railroad', 'helmet', 'cloud', 'house',
|
18 |
+
'mountain', 'ocean', 'road', 'rock', 'street',
|
19 |
+
'valley', 'bridge']
|
20 |
+
|
21 |
+
test:
|
22 |
+
algo: "car"
|
23 |
+
ds_name: "refcoco+"
|
24 |
+
seg_mode: "refer"
|
25 |
+
splitby: 'umd'
|
26 |
+
split: 'val'
|
27 |
+
data_root: "$YOUR_DATA_DIR"
|
28 |
+
output_path: "./outputs/"
|
29 |
+
prompts_augment: False
|
30 |
+
use_pseudo: True
|
31 |
+
|
32 |
+
sentence_process:
|
33 |
+
mixing_alpha: 0.
|
34 |
+
|
35 |
+
save_path: "./outputs"
|
36 |
+
|
37 |
+
|
configs/voc.yaml
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
clip:
|
2 |
+
semantic_clip_model_name: 'ViT-L/14'
|
3 |
+
semantic_pretrained_data: 'openai'
|
4 |
+
clip_model_name: "ViT-B/16"
|
5 |
+
pretrained_data: 'openai'
|
6 |
+
|
7 |
+
car:
|
8 |
+
iom_thres: 0.6
|
9 |
+
mask_threshold: 0.4
|
10 |
+
min_area_ratio: 0.2
|
11 |
+
confidence_threshold: 0.6 # 0.2
|
12 |
+
clipes_threshold: 0.4
|
13 |
+
visualize: False
|
14 |
+
visual_prompt_type: ['circle', 'blur']
|
15 |
+
semantic_templates: ['a clean origami {}.',
|
16 |
+
'a photo of a {}.',
|
17 |
+
'This is a photo of a {}',
|
18 |
+
'There is a {} in the scene',
|
19 |
+
'There is the {} in the scene',
|
20 |
+
'a photo of a {} in the scene',
|
21 |
+
'a photo of a small {}.',
|
22 |
+
'a photo of a medium {}.',
|
23 |
+
'a photo of a large {}.',
|
24 |
+
'This is a photo of a small {}.',
|
25 |
+
'This is a photo of a medium {}.',
|
26 |
+
'This is a photo of a large {}.',
|
27 |
+
'There is a small {} in the scene.',
|
28 |
+
'There is a medium {} in the scene.',
|
29 |
+
'There is a large {} in the scene.']
|
30 |
+
|
31 |
+
bg_cls: ['ground', 'land', 'grass', 'tree', 'building',
|
32 |
+
'wall', 'sky', 'lake', 'water', 'river', 'sea',
|
33 |
+
'railway', 'railroad', 'helmet', 'cloud', 'house',
|
34 |
+
'mountain', 'ocean', 'road', 'rock', 'street',
|
35 |
+
'valley', 'bridge']
|
36 |
+
|
37 |
+
# SAM is activated only if test.use_pseudo is False
|
38 |
+
sam:
|
39 |
+
model_dir: "$YOUR_SAM_MODEL_DIR"
|
40 |
+
sam_checkpoint: "$YOUR_SAM_MODEL_DIR/sam_hq_vit_h.pth"
|
41 |
+
model_type: "vit_h"
|
42 |
+
min_pred_threshold: 0.05
|
43 |
+
points_per_side:
|
44 |
+
pred_iou_thresh: 0.88
|
45 |
+
stability_score_thresh: 0.95
|
46 |
+
box_nms_thresh: 0.7
|
47 |
+
|
48 |
+
test:
|
49 |
+
algo: "car"
|
50 |
+
ds_name: "voc"
|
51 |
+
seg_mode: "semantic"
|
52 |
+
split: 'val'
|
53 |
+
data_root: "$YOUR_DATA_DIR"
|
54 |
+
# You need to extract the sam mask for the ADE dataset if use_pseudo=False
|
55 |
+
sam_mask_root: "$YOUR_SAM_MASK_DIR"
|
56 |
+
output_path: "./outputs/"
|
57 |
+
use_pseudo: True
|
58 |
+
n_class: 21
|
59 |
+
num_chunks: 1
|
60 |
+
chunk_index: 0
|
61 |
+
ignore_background: False
|
62 |
+
|
63 |
+
save_path: "./outputs"
|
data/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
data/ade.py
ADDED
@@ -0,0 +1,544 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""ADE20K dataset."""
|
17 |
+
|
18 |
+
import os
|
19 |
+
import numpy as np
|
20 |
+
from PIL import Image
|
21 |
+
import torch
|
22 |
+
|
23 |
+
|
24 |
+
ADE_CLASSES = [
|
25 |
+
'wall',
|
26 |
+
'building, edifice',
|
27 |
+
'sky',
|
28 |
+
'floor, flooring',
|
29 |
+
'tree',
|
30 |
+
'ceiling',
|
31 |
+
'road, route',
|
32 |
+
'bed',
|
33 |
+
'windowpane, window',
|
34 |
+
'grass',
|
35 |
+
'cabinet',
|
36 |
+
'sidewalk, pavement',
|
37 |
+
'person, individual, someone, somebody, mortal, soul',
|
38 |
+
'earth, ground',
|
39 |
+
'door, double, door',
|
40 |
+
'table',
|
41 |
+
'mountain, mount',
|
42 |
+
'plant, flora, plant, life',
|
43 |
+
'curtain, drape, drapery, mantle, pall',
|
44 |
+
'chair',
|
45 |
+
'car, auto, automobile, machine, motorcar',
|
46 |
+
'water',
|
47 |
+
'painting, picture',
|
48 |
+
'sofa, couch, lounge',
|
49 |
+
'shelf',
|
50 |
+
'house',
|
51 |
+
'sea',
|
52 |
+
'mirror',
|
53 |
+
'rug, carpet, carpeting',
|
54 |
+
'field',
|
55 |
+
'armchair',
|
56 |
+
'seat',
|
57 |
+
'fence, fencing',
|
58 |
+
'desk',
|
59 |
+
'rock, stone',
|
60 |
+
'wardrobe, closet, press',
|
61 |
+
'lamp',
|
62 |
+
'bathtub, bathing, tub, bath, tub',
|
63 |
+
'railing, rail',
|
64 |
+
'cushion',
|
65 |
+
'base, pedestal, stand',
|
66 |
+
'box',
|
67 |
+
'column, pillar',
|
68 |
+
'signboard, sign',
|
69 |
+
'chest, of, drawers, chest, bureau, dresser',
|
70 |
+
'counter',
|
71 |
+
'sand',
|
72 |
+
'sink',
|
73 |
+
'skyscraper',
|
74 |
+
'fireplace, hearth, open, fireplace',
|
75 |
+
'refrigerator, icebox',
|
76 |
+
'grandstand, covered, stand',
|
77 |
+
'path',
|
78 |
+
'stairs, steps',
|
79 |
+
'runway',
|
80 |
+
'case, display, case, showcase, vitrine',
|
81 |
+
'pool, table, billiard, table, snooker, table',
|
82 |
+
'pillow',
|
83 |
+
'screen, door, screen',
|
84 |
+
'stairway, staircase',
|
85 |
+
'river',
|
86 |
+
'bridge, span',
|
87 |
+
'bookcase',
|
88 |
+
'blind, screen',
|
89 |
+
'coffee, table, cocktail, table',
|
90 |
+
'toilet, can, commode, crapper, pot, potty, stool, throne',
|
91 |
+
'flower',
|
92 |
+
'book',
|
93 |
+
'hill',
|
94 |
+
'bench',
|
95 |
+
'countertop',
|
96 |
+
'stove, kitchen, stove, range, kitchen, range, cooking, stove',
|
97 |
+
'palm, palm, tree',
|
98 |
+
'kitchen, island',
|
99 |
+
(
|
100 |
+
'computer, computing, machine, computing, device, data, processor,'
|
101 |
+
' electronic, computer, information, processing, system'
|
102 |
+
),
|
103 |
+
'swivel, chair',
|
104 |
+
'boat',
|
105 |
+
'bar',
|
106 |
+
'arcade, machine',
|
107 |
+
'hovel, hut, hutch, shack, shanty',
|
108 |
+
(
|
109 |
+
'bus, autobus, coach, charabanc, double-decker, jitney, motorbus,'
|
110 |
+
' motorcoach, omnibus, passenger, vehicle'
|
111 |
+
),
|
112 |
+
'towel',
|
113 |
+
'light, light, source',
|
114 |
+
'truck, motortruck',
|
115 |
+
'tower',
|
116 |
+
'chandelier, pendant, pendent',
|
117 |
+
'awning, sunshade, sunblind',
|
118 |
+
'streetlight, street, lamp',
|
119 |
+
'booth, cubicle, stall, kiosk',
|
120 |
+
(
|
121 |
+
'television, television, receiver, television, set, tv, tv, set, idiot,'
|
122 |
+
' box, boob, tube, telly, goggle, box'
|
123 |
+
),
|
124 |
+
'airplane, aeroplane, plane',
|
125 |
+
'dirt, track',
|
126 |
+
'apparel, wearing, apparel, dress, clothes',
|
127 |
+
'pole',
|
128 |
+
'land, ground, soil',
|
129 |
+
'bannister, banister, balustrade, balusters, handrail',
|
130 |
+
'escalator, moving, staircase, moving, stairway',
|
131 |
+
'ottoman, pouf, pouffe, puff, hassock',
|
132 |
+
'bottle',
|
133 |
+
'buffet, counter, sideboard',
|
134 |
+
'poster, posting, placard, notice, bill, card',
|
135 |
+
'stage',
|
136 |
+
'van',
|
137 |
+
'ship',
|
138 |
+
'fountain',
|
139 |
+
'conveyer, belt, conveyor, belt, conveyer, conveyor, transporter',
|
140 |
+
'canopy',
|
141 |
+
'washer, automatic, washer, washing, machine',
|
142 |
+
'plaything, toy',
|
143 |
+
'swimming, pool, swimming, bath, natatorium',
|
144 |
+
'stool',
|
145 |
+
'barrel, cask',
|
146 |
+
'basket, handbasket',
|
147 |
+
'waterfall, falls',
|
148 |
+
'tent, collapsible, shelter',
|
149 |
+
'bag',
|
150 |
+
'minibike, motorbike',
|
151 |
+
'cradle',
|
152 |
+
'oven',
|
153 |
+
'ball',
|
154 |
+
'food, solid, food',
|
155 |
+
'step, stair',
|
156 |
+
'tank, storage, tank',
|
157 |
+
'trade, name, brand, name, brand, marque',
|
158 |
+
'microwave, microwave, oven',
|
159 |
+
'pot, flowerpot',
|
160 |
+
'animal, animate, being, beast, brute, creature, fauna',
|
161 |
+
'bicycle, bike, wheel, cycle',
|
162 |
+
'lake',
|
163 |
+
'dishwasher, dish, washer, dishwashing, machine',
|
164 |
+
'screen, silver, screen, projection, screen',
|
165 |
+
'blanket, cover',
|
166 |
+
'sculpture',
|
167 |
+
'hood, exhaust, hood',
|
168 |
+
'sconce',
|
169 |
+
'vase',
|
170 |
+
'traffic, light, traffic, signal, stoplight',
|
171 |
+
'tray',
|
172 |
+
(
|
173 |
+
'ashcan, trash, can, garbage, can, wastebin, ash, bin, ash-bin, ashbin,'
|
174 |
+
' dustbin, trash, barrel, trash, bin'
|
175 |
+
),
|
176 |
+
'fan',
|
177 |
+
'pier, wharf, wharfage, dock',
|
178 |
+
'crt, screen',
|
179 |
+
'plate',
|
180 |
+
'monitor, monitoring, device',
|
181 |
+
'bulletin, board, notice, board',
|
182 |
+
'shower',
|
183 |
+
'radiator',
|
184 |
+
'glass, drinking, glass',
|
185 |
+
'clock',
|
186 |
+
'flag',
|
187 |
+
]
|
188 |
+
|
189 |
+
|
190 |
+
ADE_STUFF_CLASS = [
|
191 |
+
'wall',
|
192 |
+
'sky',
|
193 |
+
'floor, flooring',
|
194 |
+
'tree',
|
195 |
+
'ceiling',
|
196 |
+
'road, route',
|
197 |
+
'grass',
|
198 |
+
'earth, ground',
|
199 |
+
'mountain, mount',
|
200 |
+
'plant, flora, plant, life',
|
201 |
+
'water',
|
202 |
+
'sea',
|
203 |
+
'field',
|
204 |
+
'sand',
|
205 |
+
'skyscraper',
|
206 |
+
'path',
|
207 |
+
'river',
|
208 |
+
'bridge, span',
|
209 |
+
'flower',
|
210 |
+
'hill',
|
211 |
+
'land, ground, soil',
|
212 |
+
'dirt, track',
|
213 |
+
'apparel, wearing, apparel, dress, clothes',
|
214 |
+
'lake',
|
215 |
+
'waterfall, falls',
|
216 |
+
]
|
217 |
+
|
218 |
+
ADE_THING_CLASS = [
|
219 |
+
'building, edifice',
|
220 |
+
'bed',
|
221 |
+
'windowpane, window',
|
222 |
+
'cabinet',
|
223 |
+
'sidewalk, pavement',
|
224 |
+
'person, individual, someone, somebody, mortal, soul',
|
225 |
+
'door, double, door',
|
226 |
+
'table',
|
227 |
+
'curtain, drape, drapery, mantle, pall',
|
228 |
+
'chair',
|
229 |
+
'car, auto, automobile, machine, motorcar',
|
230 |
+
'painting, picture',
|
231 |
+
'sofa, couch, lounge',
|
232 |
+
'shelf',
|
233 |
+
'house',
|
234 |
+
'mirror',
|
235 |
+
'rug, carpet, carpeting',
|
236 |
+
'armchair',
|
237 |
+
'seat',
|
238 |
+
'fence, fencing',
|
239 |
+
'desk',
|
240 |
+
'rock, stone',
|
241 |
+
'wardrobe, closet, press',
|
242 |
+
'lamp',
|
243 |
+
'bathtub, bathing, tub, bath, tub',
|
244 |
+
'railing, rail',
|
245 |
+
'cushion',
|
246 |
+
'base, pedestal, stand',
|
247 |
+
'box',
|
248 |
+
'column, pillar',
|
249 |
+
'signboard, sign',
|
250 |
+
'chest, of, drawers, chest, bureau, dresser',
|
251 |
+
'counter',
|
252 |
+
'sink',
|
253 |
+
'fireplace, hearth, open, fireplace',
|
254 |
+
'refrigerator, icebox',
|
255 |
+
'grandstand, covered, stand',
|
256 |
+
'stairs, steps',
|
257 |
+
'runway',
|
258 |
+
'case, display, case, showcase, vitrine',
|
259 |
+
'pool, table, billiard, table, snooker, table',
|
260 |
+
'pillow',
|
261 |
+
'screen, door, screen',
|
262 |
+
'stairway, staircase',
|
263 |
+
'bookcase',
|
264 |
+
'blind, screen',
|
265 |
+
'coffee, table, cocktail, table',
|
266 |
+
'toilet, can, commode, crapper, pot, potty, stool, throne',
|
267 |
+
'book',
|
268 |
+
'bench',
|
269 |
+
'countertop',
|
270 |
+
'stove, kitchen, stove, range, kitchen, range, cooking, stove',
|
271 |
+
'palm, palm, tree',
|
272 |
+
'kitchen, island',
|
273 |
+
(
|
274 |
+
'computer, computing, machine, computing, device, data, processor,'
|
275 |
+
' electronic, computer, information, processing, system'
|
276 |
+
),
|
277 |
+
'swivel, chair',
|
278 |
+
'boat',
|
279 |
+
'bar',
|
280 |
+
'arcade, machine',
|
281 |
+
'hovel, hut, hutch, shack, shanty',
|
282 |
+
(
|
283 |
+
'bus, autobus, coach, charabanc, double-decker, jitney, motorbus,'
|
284 |
+
' motorcoach, omnibus, passenger, vehicle'
|
285 |
+
),
|
286 |
+
'towel',
|
287 |
+
'light, light, source',
|
288 |
+
'truck, motortruck',
|
289 |
+
'tower',
|
290 |
+
'chandelier, pendant, pendent',
|
291 |
+
'awning, sunshade, sunblind',
|
292 |
+
'streetlight, street, lamp',
|
293 |
+
'booth, cubicle, stall, kiosk',
|
294 |
+
(
|
295 |
+
'television, television, receiver, television, set, tv, tv, set, idiot,'
|
296 |
+
' box, boob, tube, telly, goggle, box'
|
297 |
+
),
|
298 |
+
'airplane, aeroplane, plane',
|
299 |
+
'pole',
|
300 |
+
'bannister, banister, balustrade, balusters, handrail',
|
301 |
+
'escalator, moving, staircase, moving, stairway',
|
302 |
+
'ottoman, pouf, pouffe, puff, hassock',
|
303 |
+
'bottle',
|
304 |
+
'buffet, counter, sideboard',
|
305 |
+
'poster, posting, placard, notice, bill, card',
|
306 |
+
'stage',
|
307 |
+
'van',
|
308 |
+
'ship',
|
309 |
+
'fountain',
|
310 |
+
'conveyer, belt, conveyor, belt, conveyer, conveyor, transporter',
|
311 |
+
'canopy',
|
312 |
+
'washer, automatic, washer, washing, machine',
|
313 |
+
'plaything, toy',
|
314 |
+
'swimming, pool, swimming, bath, natatorium',
|
315 |
+
'stool',
|
316 |
+
'barrel, cask',
|
317 |
+
'basket, handbasket',
|
318 |
+
'tent, collapsible, shelter',
|
319 |
+
'bag',
|
320 |
+
'minibike, motorbike',
|
321 |
+
'cradle',
|
322 |
+
'oven',
|
323 |
+
'ball',
|
324 |
+
'food, solid, food',
|
325 |
+
'step, stair',
|
326 |
+
'tank, storage, tank',
|
327 |
+
'trade, name, brand, name, brand, marque',
|
328 |
+
'microwave, microwave, oven',
|
329 |
+
'pot, flowerpot',
|
330 |
+
'animal, animate, being, beast, brute, creature, fauna',
|
331 |
+
'bicycle, bike, wheel, cycle',
|
332 |
+
'dishwasher, dish, washer, dishwashing, machine',
|
333 |
+
'screen, silver, screen, projection, screen',
|
334 |
+
'blanket, cover',
|
335 |
+
'sculpture',
|
336 |
+
'hood, exhaust, hood',
|
337 |
+
'sconce',
|
338 |
+
'vase',
|
339 |
+
'traffic, light, traffic, signal, stoplight',
|
340 |
+
'tray',
|
341 |
+
(
|
342 |
+
'ashcan, trash, can, garbage, can, wastebin, ash, bin, ash-bin, ashbin,'
|
343 |
+
' dustbin, trash, barrel, trash, bin'
|
344 |
+
),
|
345 |
+
'fan',
|
346 |
+
'pier, wharf, wharfage, dock',
|
347 |
+
'crt, screen',
|
348 |
+
'plate',
|
349 |
+
'monitor, monitoring, device',
|
350 |
+
'bulletin, board, notice, board',
|
351 |
+
'shower',
|
352 |
+
'radiator',
|
353 |
+
'glass, drinking, glass',
|
354 |
+
'clock',
|
355 |
+
'flag',
|
356 |
+
]
|
357 |
+
|
358 |
+
|
359 |
+
ADE_STUFF_CLASS_ID = [
|
360 |
+
0,
|
361 |
+
2,
|
362 |
+
3,
|
363 |
+
4,
|
364 |
+
5,
|
365 |
+
6,
|
366 |
+
9,
|
367 |
+
13,
|
368 |
+
16,
|
369 |
+
17,
|
370 |
+
21,
|
371 |
+
26,
|
372 |
+
29,
|
373 |
+
46,
|
374 |
+
48,
|
375 |
+
52,
|
376 |
+
60,
|
377 |
+
61,
|
378 |
+
66,
|
379 |
+
68,
|
380 |
+
94,
|
381 |
+
91,
|
382 |
+
92,
|
383 |
+
128,
|
384 |
+
113,
|
385 |
+
]
|
386 |
+
|
387 |
+
ADE_THING_CLASS_ID = [
|
388 |
+
1,
|
389 |
+
7,
|
390 |
+
8,
|
391 |
+
10,
|
392 |
+
11,
|
393 |
+
12,
|
394 |
+
14,
|
395 |
+
15,
|
396 |
+
18,
|
397 |
+
19,
|
398 |
+
20,
|
399 |
+
22,
|
400 |
+
23,
|
401 |
+
24,
|
402 |
+
25,
|
403 |
+
27,
|
404 |
+
28,
|
405 |
+
30,
|
406 |
+
31,
|
407 |
+
32,
|
408 |
+
33,
|
409 |
+
34,
|
410 |
+
35,
|
411 |
+
36,
|
412 |
+
37,
|
413 |
+
38,
|
414 |
+
39,
|
415 |
+
40,
|
416 |
+
41,
|
417 |
+
42,
|
418 |
+
43,
|
419 |
+
44,
|
420 |
+
45,
|
421 |
+
47,
|
422 |
+
49,
|
423 |
+
50,
|
424 |
+
51,
|
425 |
+
53,
|
426 |
+
54,
|
427 |
+
55,
|
428 |
+
56,
|
429 |
+
57,
|
430 |
+
58,
|
431 |
+
59,
|
432 |
+
62,
|
433 |
+
63,
|
434 |
+
64,
|
435 |
+
65,
|
436 |
+
67,
|
437 |
+
69,
|
438 |
+
70,
|
439 |
+
71,
|
440 |
+
72,
|
441 |
+
73,
|
442 |
+
74,
|
443 |
+
75,
|
444 |
+
76,
|
445 |
+
77,
|
446 |
+
78,
|
447 |
+
79,
|
448 |
+
80,
|
449 |
+
81,
|
450 |
+
82,
|
451 |
+
83,
|
452 |
+
84,
|
453 |
+
85,
|
454 |
+
86,
|
455 |
+
87,
|
456 |
+
88,
|
457 |
+
89,
|
458 |
+
90,
|
459 |
+
93,
|
460 |
+
95,
|
461 |
+
96,
|
462 |
+
97,
|
463 |
+
98,
|
464 |
+
99,
|
465 |
+
100,
|
466 |
+
101,
|
467 |
+
102,
|
468 |
+
103,
|
469 |
+
104,
|
470 |
+
105,
|
471 |
+
106,
|
472 |
+
107,
|
473 |
+
108,
|
474 |
+
109,
|
475 |
+
110,
|
476 |
+
111,
|
477 |
+
112,
|
478 |
+
114,
|
479 |
+
115,
|
480 |
+
116,
|
481 |
+
117,
|
482 |
+
118,
|
483 |
+
119,
|
484 |
+
120,
|
485 |
+
121,
|
486 |
+
122,
|
487 |
+
123,
|
488 |
+
124,
|
489 |
+
125,
|
490 |
+
126,
|
491 |
+
127,
|
492 |
+
129,
|
493 |
+
130,
|
494 |
+
131,
|
495 |
+
132,
|
496 |
+
133,
|
497 |
+
134,
|
498 |
+
135,
|
499 |
+
136,
|
500 |
+
137,
|
501 |
+
138,
|
502 |
+
139,
|
503 |
+
140,
|
504 |
+
141,
|
505 |
+
142,
|
506 |
+
143,
|
507 |
+
144,
|
508 |
+
145,
|
509 |
+
146,
|
510 |
+
147,
|
511 |
+
148,
|
512 |
+
149,
|
513 |
+
]
|
514 |
+
|
515 |
+
|
516 |
+
class ADEDataset(torch.utils.data.Dataset):
|
517 |
+
"""ADE dataset."""
|
518 |
+
|
519 |
+
def __init__(self, root, split='validation', transform=None):
|
520 |
+
"""Construct ADE dataset.
|
521 |
+
|
522 |
+
Args:
|
523 |
+
root (string): Root directory where images are downloaded.
|
524 |
+
split (string): The split of the dataset.
|
525 |
+
transform (callable, optional): Optional transform to be applied on a
|
526 |
+
sample.
|
527 |
+
"""
|
528 |
+
self.root = root
|
529 |
+
self.image_dir = os.path.join(root, 'images', split)
|
530 |
+
self.ann_dir = os.path.join(root, 'annotations', split)
|
531 |
+
self.images = os.listdir(self.image_dir)
|
532 |
+
self.transform = transform
|
533 |
+
|
534 |
+
def __getitem__(self, index):
|
535 |
+
img_path = os.path.join(self.image_dir, self.images[index])
|
536 |
+
img = Image.open(img_path).convert('RGB')
|
537 |
+
img = np.asarray(img)
|
538 |
+
idx = self.images[index].split('.')[0]
|
539 |
+
ann_path = os.path.join(self.ann_dir, f'{idx}.png')
|
540 |
+
ann = np.asarray(Image.open(ann_path), dtype=np.int32)
|
541 |
+
return img, img_path, ann, idx
|
542 |
+
|
543 |
+
def __len__(self):
|
544 |
+
return len(self.images)
|
data/ade847.py
ADDED
@@ -0,0 +1,1827 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""ADE-847 dataset."""
|
17 |
+
|
18 |
+
import os
|
19 |
+
import numpy as np
|
20 |
+
from PIL import Image
|
21 |
+
# pylint: disable=g-importing-member
|
22 |
+
from torch.utils.data import Dataset
|
23 |
+
|
24 |
+
|
25 |
+
ADE_847_CLASSES = [
|
26 |
+
'wall',
|
27 |
+
'building, edifice',
|
28 |
+
'sky',
|
29 |
+
'tree',
|
30 |
+
'road, route',
|
31 |
+
'floor, flooring',
|
32 |
+
'ceiling',
|
33 |
+
'bed',
|
34 |
+
'sidewalk, pavement',
|
35 |
+
'earth, ground',
|
36 |
+
'cabinet',
|
37 |
+
'person, individual, someone, somebody, mortal, soul',
|
38 |
+
'grass',
|
39 |
+
'windowpane, window',
|
40 |
+
'car, auto, automobile, machine, motorcar',
|
41 |
+
'mountain, mount',
|
42 |
+
'plant, flora, plant life',
|
43 |
+
'table',
|
44 |
+
'chair',
|
45 |
+
'curtain, drape, drapery, mantle, pall',
|
46 |
+
'door',
|
47 |
+
'sofa, couch, lounge',
|
48 |
+
'sea',
|
49 |
+
'painting, picture',
|
50 |
+
'water',
|
51 |
+
'mirror',
|
52 |
+
'house',
|
53 |
+
'rug, carpet, carpeting',
|
54 |
+
'shelf',
|
55 |
+
'armchair',
|
56 |
+
'fence, fencing',
|
57 |
+
'field',
|
58 |
+
'lamp',
|
59 |
+
'rock, stone',
|
60 |
+
'seat',
|
61 |
+
'river',
|
62 |
+
'desk',
|
63 |
+
'bathtub, bathing tub, bath, tub',
|
64 |
+
'railing, rail',
|
65 |
+
'signboard, sign',
|
66 |
+
'cushion',
|
67 |
+
'path',
|
68 |
+
'work surface',
|
69 |
+
'stairs, steps',
|
70 |
+
'column, pillar',
|
71 |
+
'sink',
|
72 |
+
'wardrobe, closet, press',
|
73 |
+
'snow',
|
74 |
+
'refrigerator, icebox',
|
75 |
+
'base, pedestal, stand',
|
76 |
+
'bridge, span',
|
77 |
+
'blind, screen',
|
78 |
+
'runway',
|
79 |
+
'cliff, drop, drop-off',
|
80 |
+
'sand',
|
81 |
+
'fireplace, hearth, open fireplace',
|
82 |
+
'pillow',
|
83 |
+
'screen door, screen',
|
84 |
+
'toilet, can, commode, crapper, pot, potty, stool, throne',
|
85 |
+
'skyscraper',
|
86 |
+
'grandstand, covered stand',
|
87 |
+
'box',
|
88 |
+
'pool table, billiard table, snooker table',
|
89 |
+
'palm, palm tree',
|
90 |
+
'double door',
|
91 |
+
'coffee table, cocktail table',
|
92 |
+
'counter',
|
93 |
+
'countertop',
|
94 |
+
'chest of drawers, chest, bureau, dresser',
|
95 |
+
'kitchen island',
|
96 |
+
'boat',
|
97 |
+
'waterfall, falls',
|
98 |
+
'stove, kitchen stove, range, kitchen range, cooking stove',
|
99 |
+
'flower',
|
100 |
+
'bookcase',
|
101 |
+
'controls',
|
102 |
+
'book',
|
103 |
+
'stairway, staircase',
|
104 |
+
'streetlight, street lamp',
|
105 |
+
(
|
106 |
+
'computer, computing machine, computing device, data processor,'
|
107 |
+
' electronic computer, information processing system'
|
108 |
+
),
|
109 |
+
(
|
110 |
+
'bus, autobus, coach, charabanc, double-decker, jitney, motorbus,'
|
111 |
+
' motorcoach, omnibus, passenger vehicle'
|
112 |
+
),
|
113 |
+
'swivel chair',
|
114 |
+
'light, light source',
|
115 |
+
'bench',
|
116 |
+
'case, display case, showcase, vitrine',
|
117 |
+
'towel',
|
118 |
+
'fountain',
|
119 |
+
'embankment',
|
120 |
+
(
|
121 |
+
'television receiver, television, television set, tv, tv set, idiot'
|
122 |
+
' box, boob tube, telly, goggle box'
|
123 |
+
),
|
124 |
+
'van',
|
125 |
+
'hill',
|
126 |
+
'awning, sunshade, sunblind',
|
127 |
+
'poster, posting, placard, notice, bill, card',
|
128 |
+
'truck, motortruck',
|
129 |
+
'airplane, aeroplane, plane',
|
130 |
+
'pole',
|
131 |
+
'tower',
|
132 |
+
'court',
|
133 |
+
'ball',
|
134 |
+
'aircraft carrier, carrier, flattop, attack aircraft carrier',
|
135 |
+
'buffet, counter, sideboard',
|
136 |
+
'hovel, hut, hutch, shack, shanty',
|
137 |
+
'apparel, wearing apparel, dress, clothes',
|
138 |
+
'minibike, motorbike',
|
139 |
+
'animal, animate being, beast, brute, creature, fauna',
|
140 |
+
'chandelier, pendant, pendent',
|
141 |
+
'step, stair',
|
142 |
+
'booth, cubicle, stall, kiosk',
|
143 |
+
'bicycle, bike, wheel, cycle',
|
144 |
+
'doorframe, doorcase',
|
145 |
+
'sconce',
|
146 |
+
'pond',
|
147 |
+
'trade name, brand name, brand, marque',
|
148 |
+
'bannister, banister, balustrade, balusters, handrail',
|
149 |
+
'bag',
|
150 |
+
'traffic light, traffic signal, stoplight',
|
151 |
+
'gazebo',
|
152 |
+
'escalator, moving staircase, moving stairway',
|
153 |
+
'land, ground, soil',
|
154 |
+
'board, plank',
|
155 |
+
'arcade machine',
|
156 |
+
'eiderdown, duvet, continental quilt',
|
157 |
+
'bar',
|
158 |
+
'stall, stand, sales booth',
|
159 |
+
'playground',
|
160 |
+
'ship',
|
161 |
+
'ottoman, pouf, pouffe, puff, hassock',
|
162 |
+
(
|
163 |
+
'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin,'
|
164 |
+
' dustbin, trash barrel, trash bin'
|
165 |
+
),
|
166 |
+
'bottle',
|
167 |
+
'cradle',
|
168 |
+
'pot, flowerpot',
|
169 |
+
'conveyer belt, conveyor belt, conveyer, conveyor, transporter',
|
170 |
+
'train, railroad train',
|
171 |
+
'stool',
|
172 |
+
'lake',
|
173 |
+
'tank, storage tank',
|
174 |
+
'ice, water ice',
|
175 |
+
'basket, handbasket',
|
176 |
+
'manhole',
|
177 |
+
'tent, collapsible shelter',
|
178 |
+
'canopy',
|
179 |
+
'microwave, microwave oven',
|
180 |
+
'barrel, cask',
|
181 |
+
'dirt track',
|
182 |
+
'beam',
|
183 |
+
'dishwasher, dish washer, dishwashing machine',
|
184 |
+
'plate',
|
185 |
+
'screen, crt screen',
|
186 |
+
'ruins',
|
187 |
+
'washer, automatic washer, washing machine',
|
188 |
+
'blanket, cover',
|
189 |
+
'plaything, toy',
|
190 |
+
'food, solid food',
|
191 |
+
'screen, silver screen, projection screen',
|
192 |
+
'oven',
|
193 |
+
'stage',
|
194 |
+
'beacon, lighthouse, beacon light, pharos',
|
195 |
+
'umbrella',
|
196 |
+
'sculpture',
|
197 |
+
'aqueduct',
|
198 |
+
'container',
|
199 |
+
'scaffolding, staging',
|
200 |
+
'hood, exhaust hood',
|
201 |
+
'curb, curbing, kerb',
|
202 |
+
'roller coaster',
|
203 |
+
'horse, equus caballus',
|
204 |
+
'catwalk',
|
205 |
+
'glass, drinking glass',
|
206 |
+
'vase',
|
207 |
+
'central reservation',
|
208 |
+
'carousel',
|
209 |
+
'radiator',
|
210 |
+
'closet',
|
211 |
+
'machine',
|
212 |
+
'pier, wharf, wharfage, dock',
|
213 |
+
'fan',
|
214 |
+
'inflatable bounce game',
|
215 |
+
'pitch',
|
216 |
+
'paper',
|
217 |
+
'arcade, colonnade',
|
218 |
+
'hot tub',
|
219 |
+
'helicopter',
|
220 |
+
'tray',
|
221 |
+
'partition, divider',
|
222 |
+
'vineyard',
|
223 |
+
'bowl',
|
224 |
+
'bullring',
|
225 |
+
'flag',
|
226 |
+
'pot',
|
227 |
+
'footbridge, overcrossing, pedestrian bridge',
|
228 |
+
'shower',
|
229 |
+
'bag, traveling bag, travelling bag, grip, suitcase',
|
230 |
+
'bulletin board, notice board',
|
231 |
+
'confessional booth',
|
232 |
+
'trunk, tree trunk, bole',
|
233 |
+
'forest',
|
234 |
+
'elevator door',
|
235 |
+
'laptop, laptop computer',
|
236 |
+
'instrument panel',
|
237 |
+
'bucket, pail',
|
238 |
+
'tapestry, tapis',
|
239 |
+
'platform',
|
240 |
+
'jacket',
|
241 |
+
'gate',
|
242 |
+
'monitor, monitoring device',
|
243 |
+
'telephone booth, phone booth, call box, telephone box, telephone kiosk',
|
244 |
+
'spotlight, spot',
|
245 |
+
'ring',
|
246 |
+
'control panel',
|
247 |
+
'blackboard, chalkboard',
|
248 |
+
'air conditioner, air conditioning',
|
249 |
+
'chest',
|
250 |
+
'clock',
|
251 |
+
'sand dune',
|
252 |
+
'pipe, pipage, piping',
|
253 |
+
'vault',
|
254 |
+
'table football',
|
255 |
+
'cannon',
|
256 |
+
'swimming pool, swimming bath, natatorium',
|
257 |
+
'fluorescent, fluorescent fixture',
|
258 |
+
'statue',
|
259 |
+
'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
|
260 |
+
'exhibitor',
|
261 |
+
'ladder',
|
262 |
+
'carport',
|
263 |
+
'dam',
|
264 |
+
'pulpit',
|
265 |
+
'skylight, fanlight',
|
266 |
+
'water tower',
|
267 |
+
'grill, grille, grillwork',
|
268 |
+
'display board',
|
269 |
+
'pane, pane of glass, window glass',
|
270 |
+
'rubbish, trash, scrap',
|
271 |
+
'ice rink',
|
272 |
+
'fruit',
|
273 |
+
'patio',
|
274 |
+
'vending machine',
|
275 |
+
'telephone, phone, telephone set',
|
276 |
+
'net',
|
277 |
+
'backpack, back pack, knapsack, packsack, rucksack, haversack',
|
278 |
+
'jar',
|
279 |
+
'track',
|
280 |
+
'magazine',
|
281 |
+
'shutter',
|
282 |
+
'roof',
|
283 |
+
'banner, streamer',
|
284 |
+
'landfill',
|
285 |
+
'post',
|
286 |
+
'altarpiece, reredos',
|
287 |
+
'hat, chapeau, lid',
|
288 |
+
'arch, archway',
|
289 |
+
'table game',
|
290 |
+
'bag, handbag, pocketbook, purse',
|
291 |
+
'document, written document, papers',
|
292 |
+
'dome',
|
293 |
+
'pier',
|
294 |
+
'shanties',
|
295 |
+
'forecourt',
|
296 |
+
'crane',
|
297 |
+
'dog, domestic dog, canis familiaris',
|
298 |
+
'piano, pianoforte, forte-piano',
|
299 |
+
'drawing',
|
300 |
+
'cabin',
|
301 |
+
'ad, advertisement, advertizement, advertising, advertizing, advert',
|
302 |
+
'amphitheater, amphitheatre, coliseum',
|
303 |
+
'monument',
|
304 |
+
'henhouse',
|
305 |
+
'cockpit',
|
306 |
+
'heater, warmer',
|
307 |
+
'windmill, aerogenerator, wind generator',
|
308 |
+
'pool',
|
309 |
+
'elevator, lift',
|
310 |
+
'decoration, ornament, ornamentation',
|
311 |
+
'labyrinth',
|
312 |
+
'text, textual matter',
|
313 |
+
'printer',
|
314 |
+
'mezzanine, first balcony',
|
315 |
+
'mattress',
|
316 |
+
'straw',
|
317 |
+
'stalls',
|
318 |
+
'patio, terrace',
|
319 |
+
'billboard, hoarding',
|
320 |
+
'bus stop',
|
321 |
+
'trouser, pant',
|
322 |
+
'console table, console',
|
323 |
+
'rack',
|
324 |
+
'notebook',
|
325 |
+
'shrine',
|
326 |
+
'pantry',
|
327 |
+
'cart',
|
328 |
+
'steam shovel',
|
329 |
+
'porch',
|
330 |
+
'postbox, mailbox, letter box',
|
331 |
+
'figurine, statuette',
|
332 |
+
'recycling bin',
|
333 |
+
'folding screen',
|
334 |
+
'telescope',
|
335 |
+
'deck chair, beach chair',
|
336 |
+
'kennel',
|
337 |
+
'coffee maker',
|
338 |
+
"altar, communion table, lord's table",
|
339 |
+
'fish',
|
340 |
+
'easel',
|
341 |
+
'artificial golf green',
|
342 |
+
'iceberg',
|
343 |
+
'candlestick, candle holder',
|
344 |
+
'shower stall, shower bath',
|
345 |
+
'television stand',
|
346 |
+
(
|
347 |
+
'wall socket, wall plug, electric outlet, electrical outlet, outlet,'
|
348 |
+
' electric receptacle'
|
349 |
+
),
|
350 |
+
'skeleton',
|
351 |
+
'grand piano, grand',
|
352 |
+
'candy, confect',
|
353 |
+
'grille door',
|
354 |
+
'pedestal, plinth, footstall',
|
355 |
+
'jersey, t-shirt, tee shirt',
|
356 |
+
'shoe',
|
357 |
+
'gravestone, headstone, tombstone',
|
358 |
+
'shanty',
|
359 |
+
'structure',
|
360 |
+
'rocking chair, rocker',
|
361 |
+
'bird',
|
362 |
+
'place mat',
|
363 |
+
'tomb',
|
364 |
+
'big top',
|
365 |
+
'gas pump, gasoline pump, petrol pump, island dispenser',
|
366 |
+
'lockers',
|
367 |
+
'cage',
|
368 |
+
'finger',
|
369 |
+
'bleachers',
|
370 |
+
'ferris wheel',
|
371 |
+
'hairdresser chair',
|
372 |
+
'mat',
|
373 |
+
'stands',
|
374 |
+
'aquarium, fish tank, marine museum',
|
375 |
+
'streetcar, tram, tramcar, trolley, trolley car',
|
376 |
+
'napkin, table napkin, serviette',
|
377 |
+
'dummy',
|
378 |
+
'booklet, brochure, folder, leaflet, pamphlet',
|
379 |
+
'sand trap',
|
380 |
+
'shop, store',
|
381 |
+
'table cloth',
|
382 |
+
'service station',
|
383 |
+
'coffin',
|
384 |
+
'drawer',
|
385 |
+
'cages',
|
386 |
+
'slot machine, coin machine',
|
387 |
+
'balcony',
|
388 |
+
'volleyball court',
|
389 |
+
'table tennis',
|
390 |
+
'control table',
|
391 |
+
'shirt',
|
392 |
+
'merchandise, ware, product',
|
393 |
+
'railway',
|
394 |
+
'parterre',
|
395 |
+
'chimney',
|
396 |
+
'can, tin, tin can',
|
397 |
+
'tanks',
|
398 |
+
'fabric, cloth, material, textile',
|
399 |
+
'alga, algae',
|
400 |
+
'system',
|
401 |
+
'map',
|
402 |
+
'greenhouse',
|
403 |
+
'mug',
|
404 |
+
'barbecue',
|
405 |
+
'trailer',
|
406 |
+
'toilet tissue, toilet paper, bathroom tissue',
|
407 |
+
'organ',
|
408 |
+
'dishrag, dishcloth',
|
409 |
+
'island',
|
410 |
+
'keyboard',
|
411 |
+
'trench',
|
412 |
+
'basket, basketball hoop, hoop',
|
413 |
+
'steering wheel, wheel',
|
414 |
+
'pitcher, ewer',
|
415 |
+
'goal',
|
416 |
+
'bread, breadstuff, staff of life',
|
417 |
+
'beds',
|
418 |
+
'wood',
|
419 |
+
'file cabinet',
|
420 |
+
'newspaper, paper',
|
421 |
+
'motorboat',
|
422 |
+
'rope',
|
423 |
+
'guitar',
|
424 |
+
'rubble',
|
425 |
+
'scarf',
|
426 |
+
'barrels',
|
427 |
+
'cap',
|
428 |
+
'leaves',
|
429 |
+
'control tower',
|
430 |
+
'dashboard',
|
431 |
+
'bandstand',
|
432 |
+
'lectern',
|
433 |
+
'switch, electric switch, electrical switch',
|
434 |
+
'baseboard, mopboard, skirting board',
|
435 |
+
'shower room',
|
436 |
+
'smoke',
|
437 |
+
'faucet, spigot',
|
438 |
+
'bulldozer',
|
439 |
+
'saucepan',
|
440 |
+
'shops',
|
441 |
+
'meter',
|
442 |
+
'crevasse',
|
443 |
+
'gear',
|
444 |
+
'candelabrum, candelabra',
|
445 |
+
'sofa bed',
|
446 |
+
'tunnel',
|
447 |
+
'pallet',
|
448 |
+
'wire, conducting wire',
|
449 |
+
'kettle, boiler',
|
450 |
+
'bidet',
|
451 |
+
(
|
452 |
+
'baby buggy, baby carriage, carriage, perambulator, pram, stroller,'
|
453 |
+
' go-cart, pushchair, pusher'
|
454 |
+
),
|
455 |
+
'music stand',
|
456 |
+
'pipe, tube',
|
457 |
+
'cup',
|
458 |
+
'parking meter',
|
459 |
+
'ice hockey rink',
|
460 |
+
'shelter',
|
461 |
+
'weeds',
|
462 |
+
'temple',
|
463 |
+
'patty, cake',
|
464 |
+
'ski slope',
|
465 |
+
'panel',
|
466 |
+
'wallet',
|
467 |
+
'wheel',
|
468 |
+
'towel rack, towel horse',
|
469 |
+
'roundabout',
|
470 |
+
'canister, cannister, tin',
|
471 |
+
'rod',
|
472 |
+
'soap dispenser',
|
473 |
+
'bell',
|
474 |
+
'canvas',
|
475 |
+
'box office, ticket office, ticket booth',
|
476 |
+
'teacup',
|
477 |
+
'trellis',
|
478 |
+
'workbench',
|
479 |
+
'valley, vale',
|
480 |
+
'toaster',
|
481 |
+
'knife',
|
482 |
+
'podium',
|
483 |
+
'ramp',
|
484 |
+
'tumble dryer',
|
485 |
+
'fireplug, fire hydrant, plug',
|
486 |
+
'gym shoe, sneaker, tennis shoe',
|
487 |
+
'lab bench',
|
488 |
+
'equipment',
|
489 |
+
'rocky formation',
|
490 |
+
'plastic',
|
491 |
+
'calendar',
|
492 |
+
'caravan',
|
493 |
+
'check-in-desk',
|
494 |
+
'ticket counter',
|
495 |
+
'brush',
|
496 |
+
'mill',
|
497 |
+
'covered bridge',
|
498 |
+
'bowling alley',
|
499 |
+
'hanger',
|
500 |
+
'excavator',
|
501 |
+
'trestle',
|
502 |
+
'revolving door',
|
503 |
+
'blast furnace',
|
504 |
+
'scale, weighing machine',
|
505 |
+
'projector',
|
506 |
+
'soap',
|
507 |
+
'locker',
|
508 |
+
'tractor',
|
509 |
+
'stretcher',
|
510 |
+
'frame',
|
511 |
+
'grating',
|
512 |
+
'alembic',
|
513 |
+
'candle, taper, wax light',
|
514 |
+
'barrier',
|
515 |
+
'cardboard',
|
516 |
+
'cave',
|
517 |
+
'puddle',
|
518 |
+
'tarp',
|
519 |
+
'price tag',
|
520 |
+
'watchtower',
|
521 |
+
'meters',
|
522 |
+
(
|
523 |
+
'light bulb, lightbulb, bulb, incandescent lamp, electric light,'
|
524 |
+
' electric-light bulb'
|
525 |
+
),
|
526 |
+
'tracks',
|
527 |
+
'hair dryer',
|
528 |
+
'skirt',
|
529 |
+
'viaduct',
|
530 |
+
'paper towel',
|
531 |
+
'coat',
|
532 |
+
'sheet',
|
533 |
+
'fire extinguisher, extinguisher, asphyxiator',
|
534 |
+
'water wheel',
|
535 |
+
'pottery, clayware',
|
536 |
+
'magazine rack',
|
537 |
+
'teapot',
|
538 |
+
'microphone, mike',
|
539 |
+
'support',
|
540 |
+
'forklift',
|
541 |
+
'canyon',
|
542 |
+
'cash register, register',
|
543 |
+
'leaf, leafage, foliage',
|
544 |
+
'remote control, remote',
|
545 |
+
'soap dish',
|
546 |
+
'windshield, windscreen',
|
547 |
+
'cat',
|
548 |
+
'cue, cue stick, pool cue, pool stick',
|
549 |
+
'vent, venthole, vent-hole, blowhole',
|
550 |
+
'videos',
|
551 |
+
'shovel',
|
552 |
+
'eaves',
|
553 |
+
'antenna, aerial, transmitting aerial',
|
554 |
+
'shipyard',
|
555 |
+
'hen, biddy',
|
556 |
+
'traffic cone',
|
557 |
+
'washing machines',
|
558 |
+
'truck crane',
|
559 |
+
'cds',
|
560 |
+
'niche',
|
561 |
+
'scoreboard',
|
562 |
+
'briefcase',
|
563 |
+
'boot',
|
564 |
+
'sweater, jumper',
|
565 |
+
'hay',
|
566 |
+
'pack',
|
567 |
+
'bottle rack',
|
568 |
+
'glacier',
|
569 |
+
'pergola',
|
570 |
+
'building materials',
|
571 |
+
'television camera',
|
572 |
+
'first floor',
|
573 |
+
'rifle',
|
574 |
+
'tennis table',
|
575 |
+
'stadium',
|
576 |
+
'safety belt',
|
577 |
+
'cover',
|
578 |
+
'dish rack',
|
579 |
+
'synthesizer',
|
580 |
+
'pumpkin',
|
581 |
+
'gutter',
|
582 |
+
'fruit stand',
|
583 |
+
'ice floe, floe',
|
584 |
+
'handle, grip, handgrip, hold',
|
585 |
+
'wheelchair',
|
586 |
+
'mousepad, mouse mat',
|
587 |
+
'diploma',
|
588 |
+
'fairground ride',
|
589 |
+
'radio',
|
590 |
+
'hotplate',
|
591 |
+
'junk',
|
592 |
+
'wheelbarrow',
|
593 |
+
'stream',
|
594 |
+
'toll plaza',
|
595 |
+
'punching bag',
|
596 |
+
'trough',
|
597 |
+
'throne',
|
598 |
+
'chair desk',
|
599 |
+
'weighbridge',
|
600 |
+
'extractor fan',
|
601 |
+
'hanging clothes',
|
602 |
+
'dish, dish aerial, dish antenna, saucer',
|
603 |
+
'alarm clock, alarm',
|
604 |
+
'ski lift',
|
605 |
+
'chain',
|
606 |
+
'garage',
|
607 |
+
'mechanical shovel',
|
608 |
+
'wine rack',
|
609 |
+
'tramway',
|
610 |
+
'treadmill',
|
611 |
+
'menu',
|
612 |
+
'block',
|
613 |
+
'well',
|
614 |
+
'witness stand',
|
615 |
+
'branch',
|
616 |
+
'duck',
|
617 |
+
'casserole',
|
618 |
+
'frying pan',
|
619 |
+
'desk organizer',
|
620 |
+
'mast',
|
621 |
+
'spectacles, specs, eyeglasses, glasses',
|
622 |
+
'service elevator',
|
623 |
+
'dollhouse',
|
624 |
+
'hammock',
|
625 |
+
'clothes hanging',
|
626 |
+
'photocopier',
|
627 |
+
'notepad',
|
628 |
+
'golf cart',
|
629 |
+
'footpath',
|
630 |
+
'cross',
|
631 |
+
'baptismal font',
|
632 |
+
'boiler',
|
633 |
+
'skip',
|
634 |
+
'rotisserie',
|
635 |
+
'tables',
|
636 |
+
'water mill',
|
637 |
+
'helmet',
|
638 |
+
'cover curtain',
|
639 |
+
'brick',
|
640 |
+
'table runner',
|
641 |
+
'ashtray',
|
642 |
+
'street box',
|
643 |
+
'stick',
|
644 |
+
'hangers',
|
645 |
+
'cells',
|
646 |
+
'urinal',
|
647 |
+
'centerpiece',
|
648 |
+
'portable fridge',
|
649 |
+
'dvds',
|
650 |
+
'golf club',
|
651 |
+
'skirting board',
|
652 |
+
'water cooler',
|
653 |
+
'clipboard',
|
654 |
+
'camera, photographic camera',
|
655 |
+
'pigeonhole',
|
656 |
+
'chips',
|
657 |
+
'food processor',
|
658 |
+
'post box',
|
659 |
+
'lid',
|
660 |
+
'drum',
|
661 |
+
'blender',
|
662 |
+
'cave entrance',
|
663 |
+
'dental chair',
|
664 |
+
'obelisk',
|
665 |
+
'canoe',
|
666 |
+
'mobile',
|
667 |
+
'monitors',
|
668 |
+
'pool ball',
|
669 |
+
'cue rack',
|
670 |
+
'baggage carts',
|
671 |
+
'shore',
|
672 |
+
'fork',
|
673 |
+
'paper filer',
|
674 |
+
'bicycle rack',
|
675 |
+
'coat rack',
|
676 |
+
'garland',
|
677 |
+
'sports bag',
|
678 |
+
'fish tank',
|
679 |
+
'towel dispenser',
|
680 |
+
'carriage',
|
681 |
+
'brochure',
|
682 |
+
'plaque',
|
683 |
+
'stringer',
|
684 |
+
'iron',
|
685 |
+
'spoon',
|
686 |
+
'flag pole',
|
687 |
+
'toilet brush',
|
688 |
+
'book stand',
|
689 |
+
'water faucet, water tap, tap, hydrant',
|
690 |
+
'ticket office',
|
691 |
+
'broom',
|
692 |
+
'dvd',
|
693 |
+
'ice bucket',
|
694 |
+
'carapace, shell, cuticle, shield',
|
695 |
+
'tureen',
|
696 |
+
'folders',
|
697 |
+
'chess',
|
698 |
+
'root',
|
699 |
+
'sewing machine',
|
700 |
+
'model',
|
701 |
+
'pen',
|
702 |
+
'violin',
|
703 |
+
'sweatshirt',
|
704 |
+
'recycling materials',
|
705 |
+
'mitten',
|
706 |
+
'chopping board, cutting board',
|
707 |
+
'mask',
|
708 |
+
'log',
|
709 |
+
'mouse, computer mouse',
|
710 |
+
'grill',
|
711 |
+
'hole',
|
712 |
+
'target',
|
713 |
+
'trash bag',
|
714 |
+
'chalk',
|
715 |
+
'sticks',
|
716 |
+
'balloon',
|
717 |
+
'score',
|
718 |
+
'hair spray',
|
719 |
+
'roll',
|
720 |
+
'runner',
|
721 |
+
'engine',
|
722 |
+
'inflatable glove',
|
723 |
+
'games',
|
724 |
+
'pallets',
|
725 |
+
'baskets',
|
726 |
+
'coop',
|
727 |
+
'dvd player',
|
728 |
+
'rocking horse',
|
729 |
+
'buckets',
|
730 |
+
'bread rolls',
|
731 |
+
'shawl',
|
732 |
+
'watering can',
|
733 |
+
'spotlights',
|
734 |
+
'post-it',
|
735 |
+
'bowls',
|
736 |
+
'security camera',
|
737 |
+
'runner cloth',
|
738 |
+
'lock',
|
739 |
+
'alarm, warning device, alarm system',
|
740 |
+
'side',
|
741 |
+
'roulette',
|
742 |
+
'bone',
|
743 |
+
'cutlery',
|
744 |
+
'pool balls',
|
745 |
+
'wheels',
|
746 |
+
'spice rack',
|
747 |
+
'plant pots',
|
748 |
+
'towel ring',
|
749 |
+
'bread box',
|
750 |
+
'video',
|
751 |
+
'funfair',
|
752 |
+
'breads',
|
753 |
+
'tripod',
|
754 |
+
'ironing board',
|
755 |
+
'skimmer',
|
756 |
+
'hollow',
|
757 |
+
'scratching post',
|
758 |
+
'tricycle',
|
759 |
+
'file box',
|
760 |
+
'mountain pass',
|
761 |
+
'tombstones',
|
762 |
+
'cooker',
|
763 |
+
'card game, cards',
|
764 |
+
'golf bag',
|
765 |
+
'towel paper',
|
766 |
+
'chaise lounge',
|
767 |
+
'sun',
|
768 |
+
'toilet paper holder',
|
769 |
+
'rake',
|
770 |
+
'key',
|
771 |
+
'umbrella stand',
|
772 |
+
'dartboard',
|
773 |
+
'transformer',
|
774 |
+
'fireplace utensils',
|
775 |
+
'sweatshirts',
|
776 |
+
'cellular telephone, cellular phone, cellphone, cell, mobile phone',
|
777 |
+
'tallboy',
|
778 |
+
'stapler',
|
779 |
+
'sauna',
|
780 |
+
'test tube',
|
781 |
+
'palette',
|
782 |
+
'shopping carts',
|
783 |
+
'tools',
|
784 |
+
'push button, push, button',
|
785 |
+
'star',
|
786 |
+
'roof rack',
|
787 |
+
'barbed wire',
|
788 |
+
'spray',
|
789 |
+
'ear',
|
790 |
+
'sponge',
|
791 |
+
'racket',
|
792 |
+
'tins',
|
793 |
+
'eyeglasses',
|
794 |
+
'file',
|
795 |
+
'scarfs',
|
796 |
+
'sugar bowl',
|
797 |
+
'flip flop',
|
798 |
+
'headstones',
|
799 |
+
'laptop bag',
|
800 |
+
'leash',
|
801 |
+
'climbing frame',
|
802 |
+
'suit hanger',
|
803 |
+
'floor spotlight',
|
804 |
+
'plate rack',
|
805 |
+
'sewer',
|
806 |
+
'hard drive',
|
807 |
+
'sprinkler',
|
808 |
+
'tools box',
|
809 |
+
'necklace',
|
810 |
+
'bulbs',
|
811 |
+
'steel industry',
|
812 |
+
'club',
|
813 |
+
'jack',
|
814 |
+
'door bars',
|
815 |
+
'control panel, instrument panel, control board, board, panel',
|
816 |
+
'hairbrush',
|
817 |
+
'napkin holder',
|
818 |
+
'office',
|
819 |
+
'smoke detector',
|
820 |
+
'utensils',
|
821 |
+
'apron',
|
822 |
+
'scissors',
|
823 |
+
'terminal',
|
824 |
+
'grinder',
|
825 |
+
'entry phone',
|
826 |
+
'newspaper stand',
|
827 |
+
'pepper shaker',
|
828 |
+
'onions',
|
829 |
+
(
|
830 |
+
'central processing unit, cpu, c p u , central processor, processor,'
|
831 |
+
' mainframe'
|
832 |
+
),
|
833 |
+
'tape',
|
834 |
+
'bat',
|
835 |
+
'coaster',
|
836 |
+
'calculator',
|
837 |
+
'potatoes',
|
838 |
+
'luggage rack',
|
839 |
+
'salt',
|
840 |
+
'street number',
|
841 |
+
'viewpoint',
|
842 |
+
'sword',
|
843 |
+
'cd',
|
844 |
+
'rowing machine',
|
845 |
+
'plug',
|
846 |
+
'andiron, firedog, dog, dog-iron',
|
847 |
+
'pepper',
|
848 |
+
'tongs',
|
849 |
+
'bonfire',
|
850 |
+
'dog dish',
|
851 |
+
'belt',
|
852 |
+
'dumbbells',
|
853 |
+
'videocassette recorder, vcr',
|
854 |
+
'hook',
|
855 |
+
'envelopes',
|
856 |
+
'shower faucet',
|
857 |
+
'watch',
|
858 |
+
'padlock',
|
859 |
+
'swimming pool ladder',
|
860 |
+
'spanners',
|
861 |
+
'gravy boat',
|
862 |
+
'notice board',
|
863 |
+
'trash bags',
|
864 |
+
'fire alarm',
|
865 |
+
'ladle',
|
866 |
+
'stethoscope',
|
867 |
+
'rocket',
|
868 |
+
'funnel',
|
869 |
+
'bowling pins',
|
870 |
+
'valve',
|
871 |
+
'thermometer',
|
872 |
+
'cups',
|
873 |
+
'spice jar',
|
874 |
+
'night light',
|
875 |
+
'soaps',
|
876 |
+
'games table',
|
877 |
+
'slotted spoon',
|
878 |
+
'reel',
|
879 |
+
'scourer',
|
880 |
+
'sleeping robe',
|
881 |
+
'desk mat',
|
882 |
+
'dumbbell',
|
883 |
+
'hammer',
|
884 |
+
'tie',
|
885 |
+
'typewriter',
|
886 |
+
'shaker',
|
887 |
+
'cheese dish',
|
888 |
+
'sea star',
|
889 |
+
'racquet',
|
890 |
+
'butane gas cylinder',
|
891 |
+
'paper weight',
|
892 |
+
'shaving brush',
|
893 |
+
'sunglasses',
|
894 |
+
'gear shift',
|
895 |
+
'towel rail',
|
896 |
+
'adding machine, totalizer, totaliser',
|
897 |
+
]
|
898 |
+
|
899 |
+
ADE_847_CLASS_ID = list(range(847))
|
900 |
+
|
901 |
+
ADE_847_STUFF_CLASS = [
|
902 |
+
'wall',
|
903 |
+
'sky',
|
904 |
+
'tree',
|
905 |
+
'road, route',
|
906 |
+
'floor, flooring',
|
907 |
+
'sidewalk, pavement',
|
908 |
+
'earth, ground',
|
909 |
+
'grass',
|
910 |
+
'mountain, mount',
|
911 |
+
'plant, flora, plant life',
|
912 |
+
'sea',
|
913 |
+
'water',
|
914 |
+
'rock, stone',
|
915 |
+
'snow',
|
916 |
+
'sand',
|
917 |
+
'island',
|
918 |
+
'field',
|
919 |
+
'forest',
|
920 |
+
'land, ground, soil',
|
921 |
+
'lake',
|
922 |
+
'ice, water ice',
|
923 |
+
'cliff, drop, drop-off',
|
924 |
+
'dirt track',
|
925 |
+
'hill',
|
926 |
+
'valley, vale',
|
927 |
+
'stream',
|
928 |
+
'shore',
|
929 |
+
'pond',
|
930 |
+
'iceberg',
|
931 |
+
]
|
932 |
+
|
933 |
+
ADE_847_THING_CLASS = [
|
934 |
+
'building, edifice',
|
935 |
+
'ceiling',
|
936 |
+
'bed',
|
937 |
+
'cabinet',
|
938 |
+
'person, individual, someone, somebody, mortal, soul',
|
939 |
+
'windowpane, window',
|
940 |
+
'car, auto, automobile, machine, motorcar',
|
941 |
+
'table',
|
942 |
+
'chair',
|
943 |
+
'curtain, drape, drapery, mantle, pall',
|
944 |
+
'door',
|
945 |
+
'sofa, couch, lounge',
|
946 |
+
'painting, picture',
|
947 |
+
'mirror',
|
948 |
+
'house',
|
949 |
+
'rug, carpet, carpeting',
|
950 |
+
'shelf',
|
951 |
+
'armchair',
|
952 |
+
'fence, fencing',
|
953 |
+
'lamp',
|
954 |
+
'seat',
|
955 |
+
'river',
|
956 |
+
'desk',
|
957 |
+
'bathtub, bathing tub, bath, tub',
|
958 |
+
'railing, rail',
|
959 |
+
'signboard, sign',
|
960 |
+
'cushion',
|
961 |
+
'path',
|
962 |
+
'work surface',
|
963 |
+
'stairs, steps',
|
964 |
+
'column, pillar',
|
965 |
+
'sink',
|
966 |
+
'wardrobe, closet, press',
|
967 |
+
'refrigerator, icebox',
|
968 |
+
'base, pedestal, stand',
|
969 |
+
'bridge, span',
|
970 |
+
'blind, screen',
|
971 |
+
'runway',
|
972 |
+
'fireplace, hearth, open fireplace',
|
973 |
+
'pillow',
|
974 |
+
'screen door, screen',
|
975 |
+
'toilet, can, commode, crapper, pot, potty, stool, throne',
|
976 |
+
'skyscraper',
|
977 |
+
'grandstand, covered stand',
|
978 |
+
'box',
|
979 |
+
'pool table, billiard table, snooker table',
|
980 |
+
'palm, palm tree',
|
981 |
+
'double door',
|
982 |
+
'coffee table, cocktail table',
|
983 |
+
'counter',
|
984 |
+
'countertop',
|
985 |
+
'chest of drawers, chest, bureau, dresser',
|
986 |
+
'kitchen island',
|
987 |
+
'boat',
|
988 |
+
'waterfall, falls',
|
989 |
+
'stove, kitchen stove, range, kitchen range, cooking stove',
|
990 |
+
'flower',
|
991 |
+
'bookcase',
|
992 |
+
'controls',
|
993 |
+
'book',
|
994 |
+
'stairway, staircase',
|
995 |
+
'streetlight, street lamp',
|
996 |
+
(
|
997 |
+
'computer, computing machine, computing device, data processor,'
|
998 |
+
' electronic computer, information processing system'
|
999 |
+
),
|
1000 |
+
(
|
1001 |
+
'bus, autobus, coach, charabanc, double-decker, jitney, motorbus,'
|
1002 |
+
' motorcoach, omnibus, passenger vehicle'
|
1003 |
+
),
|
1004 |
+
'swivel chair',
|
1005 |
+
'light, light source',
|
1006 |
+
'bench',
|
1007 |
+
'case, display case, showcase, vitrine',
|
1008 |
+
'towel',
|
1009 |
+
'fountain',
|
1010 |
+
'embankment',
|
1011 |
+
(
|
1012 |
+
'television receiver, television, television set, tv, tv set, idiot'
|
1013 |
+
' box, boob tube, telly, goggle box'
|
1014 |
+
),
|
1015 |
+
'van',
|
1016 |
+
'awning, sunshade, sunblind',
|
1017 |
+
'poster, posting, placard, notice, bill, card',
|
1018 |
+
'truck, motortruck',
|
1019 |
+
'airplane, aeroplane, plane',
|
1020 |
+
'pole',
|
1021 |
+
'tower',
|
1022 |
+
'court',
|
1023 |
+
'ball',
|
1024 |
+
'aircraft carrier, carrier, flattop, attack aircraft carrier',
|
1025 |
+
'buffet, counter, sideboard',
|
1026 |
+
'hovel, hut, hutch, shack, shanty',
|
1027 |
+
'apparel, wearing apparel, dress, clothes',
|
1028 |
+
'minibike, motorbike',
|
1029 |
+
'animal, animate being, beast, brute, creature, fauna',
|
1030 |
+
'chandelier, pendant, pendent',
|
1031 |
+
'step, stair',
|
1032 |
+
'booth, cubicle, stall, kiosk',
|
1033 |
+
'bicycle, bike, wheel, cycle',
|
1034 |
+
'doorframe, doorcase',
|
1035 |
+
'sconce',
|
1036 |
+
'trade name, brand name, brand, marque',
|
1037 |
+
'bannister, banister, balustrade, balusters, handrail',
|
1038 |
+
'bag',
|
1039 |
+
'traffic light, traffic signal, stoplight',
|
1040 |
+
'gazebo',
|
1041 |
+
'escalator, moving staircase, moving stairway',
|
1042 |
+
'board, plank',
|
1043 |
+
'arcade machine',
|
1044 |
+
'eiderdown, duvet, continental quilt',
|
1045 |
+
'bar',
|
1046 |
+
'stall, stand, sales booth',
|
1047 |
+
'playground',
|
1048 |
+
'ship',
|
1049 |
+
'ottoman, pouf, pouffe, puff, hassock',
|
1050 |
+
(
|
1051 |
+
'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin,'
|
1052 |
+
' dustbin, trash barrel, trash bin'
|
1053 |
+
),
|
1054 |
+
'bottle',
|
1055 |
+
'cradle',
|
1056 |
+
'pot, flowerpot',
|
1057 |
+
'conveyer belt, conveyor belt, conveyer, conveyor, transporter',
|
1058 |
+
'train, railroad train',
|
1059 |
+
'stool',
|
1060 |
+
'tank, storage tank',
|
1061 |
+
'basket, handbasket',
|
1062 |
+
'manhole',
|
1063 |
+
'tent, collapsible shelter',
|
1064 |
+
'canopy',
|
1065 |
+
'microwave, microwave oven',
|
1066 |
+
'barrel, cask',
|
1067 |
+
'beam',
|
1068 |
+
'dishwasher, dish washer, dishwashing machine',
|
1069 |
+
'plate',
|
1070 |
+
'screen, crt screen',
|
1071 |
+
'ruins',
|
1072 |
+
'washer, automatic washer, washing machine',
|
1073 |
+
'blanket, cover',
|
1074 |
+
'plaything, toy',
|
1075 |
+
'food, solid food',
|
1076 |
+
'screen, silver screen, projection screen',
|
1077 |
+
'oven',
|
1078 |
+
'stage',
|
1079 |
+
'beacon, lighthouse, beacon light, pharos',
|
1080 |
+
'umbrella',
|
1081 |
+
'sculpture',
|
1082 |
+
'aqueduct',
|
1083 |
+
'container',
|
1084 |
+
'scaffolding, staging',
|
1085 |
+
'hood, exhaust hood',
|
1086 |
+
'curb, curbing, kerb',
|
1087 |
+
'roller coaster',
|
1088 |
+
'horse, equus caballus',
|
1089 |
+
'catwalk',
|
1090 |
+
'glass, drinking glass',
|
1091 |
+
'vase',
|
1092 |
+
'central reservation',
|
1093 |
+
'carousel',
|
1094 |
+
'radiator',
|
1095 |
+
'closet',
|
1096 |
+
'machine',
|
1097 |
+
'pier, wharf, wharfage, dock',
|
1098 |
+
'fan',
|
1099 |
+
'inflatable bounce game',
|
1100 |
+
'pitch',
|
1101 |
+
'paper',
|
1102 |
+
'arcade, colonnade',
|
1103 |
+
'hot tub',
|
1104 |
+
'helicopter',
|
1105 |
+
'tray',
|
1106 |
+
'partition, divider',
|
1107 |
+
'vineyard',
|
1108 |
+
'bowl',
|
1109 |
+
'bullring',
|
1110 |
+
'flag',
|
1111 |
+
'pot',
|
1112 |
+
'footbridge, overcrossing, pedestrian bridge',
|
1113 |
+
'shower',
|
1114 |
+
'bag, traveling bag, travelling bag, grip, suitcase',
|
1115 |
+
'bulletin board, notice board',
|
1116 |
+
'confessional booth',
|
1117 |
+
'trunk, tree trunk, bole',
|
1118 |
+
'elevator door',
|
1119 |
+
'laptop, laptop computer',
|
1120 |
+
'instrument panel',
|
1121 |
+
'bucket, pail',
|
1122 |
+
'tapestry, tapis',
|
1123 |
+
'platform',
|
1124 |
+
'jacket',
|
1125 |
+
'gate',
|
1126 |
+
'monitor, monitoring device',
|
1127 |
+
'telephone booth, phone booth, call box, telephone box, telephone kiosk',
|
1128 |
+
'spotlight, spot',
|
1129 |
+
'ring',
|
1130 |
+
'control panel',
|
1131 |
+
'blackboard, chalkboard',
|
1132 |
+
'air conditioner, air conditioning',
|
1133 |
+
'chest',
|
1134 |
+
'clock',
|
1135 |
+
'sand dune',
|
1136 |
+
'pipe, pipage, piping',
|
1137 |
+
'vault',
|
1138 |
+
'table football',
|
1139 |
+
'cannon',
|
1140 |
+
'swimming pool, swimming bath, natatorium',
|
1141 |
+
'fluorescent, fluorescent fixture',
|
1142 |
+
'statue',
|
1143 |
+
'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
|
1144 |
+
'exhibitor',
|
1145 |
+
'ladder',
|
1146 |
+
'carport',
|
1147 |
+
'dam',
|
1148 |
+
'pulpit',
|
1149 |
+
'skylight, fanlight',
|
1150 |
+
'water tower',
|
1151 |
+
'grill, grille, grillwork',
|
1152 |
+
'display board',
|
1153 |
+
'pane, pane of glass, window glass',
|
1154 |
+
'rubbish, trash, scrap',
|
1155 |
+
'ice rink',
|
1156 |
+
'fruit',
|
1157 |
+
'patio',
|
1158 |
+
'vending machine',
|
1159 |
+
'telephone, phone, telephone set',
|
1160 |
+
'net',
|
1161 |
+
'backpack, back pack, knapsack, packsack, rucksack, haversack',
|
1162 |
+
'jar',
|
1163 |
+
'track',
|
1164 |
+
'magazine',
|
1165 |
+
'shutter',
|
1166 |
+
'roof',
|
1167 |
+
'banner, streamer',
|
1168 |
+
'landfill',
|
1169 |
+
'post',
|
1170 |
+
'altarpiece, reredos',
|
1171 |
+
'hat, chapeau, lid',
|
1172 |
+
'arch, archway',
|
1173 |
+
'table game',
|
1174 |
+
'bag, handbag, pocketbook, purse',
|
1175 |
+
'document, written document, papers',
|
1176 |
+
'dome',
|
1177 |
+
'pier',
|
1178 |
+
'shanties',
|
1179 |
+
'forecourt',
|
1180 |
+
'crane',
|
1181 |
+
'dog, domestic dog, canis familiaris',
|
1182 |
+
'piano, pianoforte, forte-piano',
|
1183 |
+
'drawing',
|
1184 |
+
'cabin',
|
1185 |
+
'ad, advertisement, advertizement, advertising, advertizing, advert',
|
1186 |
+
'amphitheater, amphitheatre, coliseum',
|
1187 |
+
'monument',
|
1188 |
+
'henhouse',
|
1189 |
+
'cockpit',
|
1190 |
+
'heater, warmer',
|
1191 |
+
'windmill, aerogenerator, wind generator',
|
1192 |
+
'pool',
|
1193 |
+
'elevator, lift',
|
1194 |
+
'decoration, ornament, ornamentation',
|
1195 |
+
'labyrinth',
|
1196 |
+
'text, textual matter',
|
1197 |
+
'printer',
|
1198 |
+
'mezzanine, first balcony',
|
1199 |
+
'mattress',
|
1200 |
+
'straw',
|
1201 |
+
'stalls',
|
1202 |
+
'patio, terrace',
|
1203 |
+
'billboard, hoarding',
|
1204 |
+
'bus stop',
|
1205 |
+
'trouser, pant',
|
1206 |
+
'console table, console',
|
1207 |
+
'rack',
|
1208 |
+
'notebook',
|
1209 |
+
'shrine',
|
1210 |
+
'pantry',
|
1211 |
+
'cart',
|
1212 |
+
'steam shovel',
|
1213 |
+
'porch',
|
1214 |
+
'postbox, mailbox, letter box',
|
1215 |
+
'figurine, statuette',
|
1216 |
+
'recycling bin',
|
1217 |
+
'folding screen',
|
1218 |
+
'telescope',
|
1219 |
+
'deck chair, beach chair',
|
1220 |
+
'kennel',
|
1221 |
+
'coffee maker',
|
1222 |
+
"altar, communion table, lord's table",
|
1223 |
+
'fish',
|
1224 |
+
'easel',
|
1225 |
+
'artificial golf green',
|
1226 |
+
'candlestick, candle holder',
|
1227 |
+
'shower stall, shower bath',
|
1228 |
+
'television stand',
|
1229 |
+
(
|
1230 |
+
'wall socket, wall plug, electric outlet, electrical outlet, outlet,'
|
1231 |
+
' electric receptacle'
|
1232 |
+
),
|
1233 |
+
'skeleton',
|
1234 |
+
'grand piano, grand',
|
1235 |
+
'candy, confect',
|
1236 |
+
'grille door',
|
1237 |
+
'pedestal, plinth, footstall',
|
1238 |
+
'jersey, t-shirt, tee shirt',
|
1239 |
+
'shoe',
|
1240 |
+
'gravestone, headstone, tombstone',
|
1241 |
+
'shanty',
|
1242 |
+
'structure',
|
1243 |
+
'rocking chair, rocker',
|
1244 |
+
'bird',
|
1245 |
+
'place mat',
|
1246 |
+
'tomb',
|
1247 |
+
'big top',
|
1248 |
+
'gas pump, gasoline pump, petrol pump, island dispenser',
|
1249 |
+
'lockers',
|
1250 |
+
'cage',
|
1251 |
+
'finger',
|
1252 |
+
'bleachers',
|
1253 |
+
'ferris wheel',
|
1254 |
+
'hairdresser chair',
|
1255 |
+
'mat',
|
1256 |
+
'stands',
|
1257 |
+
'aquarium, fish tank, marine museum',
|
1258 |
+
'streetcar, tram, tramcar, trolley, trolley car',
|
1259 |
+
'napkin, table napkin, serviette',
|
1260 |
+
'dummy',
|
1261 |
+
'booklet, brochure, folder, leaflet, pamphlet',
|
1262 |
+
'sand trap',
|
1263 |
+
'shop, store',
|
1264 |
+
'table cloth',
|
1265 |
+
'service station',
|
1266 |
+
'coffin',
|
1267 |
+
'drawer',
|
1268 |
+
'cages',
|
1269 |
+
'slot machine, coin machine',
|
1270 |
+
'balcony',
|
1271 |
+
'volleyball court',
|
1272 |
+
'table tennis',
|
1273 |
+
'control table',
|
1274 |
+
'shirt',
|
1275 |
+
'merchandise, ware, product',
|
1276 |
+
'railway',
|
1277 |
+
'parterre',
|
1278 |
+
'chimney',
|
1279 |
+
'can, tin, tin can',
|
1280 |
+
'tanks',
|
1281 |
+
'fabric, cloth, material, textile',
|
1282 |
+
'alga, algae',
|
1283 |
+
'system',
|
1284 |
+
'map',
|
1285 |
+
'greenhouse',
|
1286 |
+
'mug',
|
1287 |
+
'barbecue',
|
1288 |
+
'trailer',
|
1289 |
+
'toilet tissue, toilet paper, bathroom tissue',
|
1290 |
+
'organ',
|
1291 |
+
'dishrag, dishcloth',
|
1292 |
+
'keyboard',
|
1293 |
+
'trench',
|
1294 |
+
'basket, basketball hoop, hoop',
|
1295 |
+
'steering wheel, wheel',
|
1296 |
+
'pitcher, ewer',
|
1297 |
+
'goal',
|
1298 |
+
'bread, breadstuff, staff of life',
|
1299 |
+
'beds',
|
1300 |
+
'wood',
|
1301 |
+
'file cabinet',
|
1302 |
+
'newspaper, paper',
|
1303 |
+
'motorboat',
|
1304 |
+
'rope',
|
1305 |
+
'guitar',
|
1306 |
+
'rubble',
|
1307 |
+
'scarf',
|
1308 |
+
'barrels',
|
1309 |
+
'cap',
|
1310 |
+
'leaves',
|
1311 |
+
'control tower',
|
1312 |
+
'dashboard',
|
1313 |
+
'bandstand',
|
1314 |
+
'lectern',
|
1315 |
+
'switch, electric switch, electrical switch',
|
1316 |
+
'baseboard, mopboard, skirting board',
|
1317 |
+
'shower room',
|
1318 |
+
'smoke',
|
1319 |
+
'faucet, spigot',
|
1320 |
+
'bulldozer',
|
1321 |
+
'saucepan',
|
1322 |
+
'shops',
|
1323 |
+
'meter',
|
1324 |
+
'crevasse',
|
1325 |
+
'gear',
|
1326 |
+
'candelabrum, candelabra',
|
1327 |
+
'sofa bed',
|
1328 |
+
'tunnel',
|
1329 |
+
'pallet',
|
1330 |
+
'wire, conducting wire',
|
1331 |
+
'kettle, boiler',
|
1332 |
+
'bidet',
|
1333 |
+
(
|
1334 |
+
'baby buggy, baby carriage, carriage, perambulator, pram, stroller,'
|
1335 |
+
' go-cart, pushchair, pusher'
|
1336 |
+
),
|
1337 |
+
'music stand',
|
1338 |
+
'pipe, tube',
|
1339 |
+
'cup',
|
1340 |
+
'parking meter',
|
1341 |
+
'ice hockey rink',
|
1342 |
+
'shelter',
|
1343 |
+
'weeds',
|
1344 |
+
'temple',
|
1345 |
+
'patty, cake',
|
1346 |
+
'ski slope',
|
1347 |
+
'panel',
|
1348 |
+
'wallet',
|
1349 |
+
'wheel',
|
1350 |
+
'towel rack, towel horse',
|
1351 |
+
'roundabout',
|
1352 |
+
'canister, cannister, tin',
|
1353 |
+
'rod',
|
1354 |
+
'soap dispenser',
|
1355 |
+
'bell',
|
1356 |
+
'canvas',
|
1357 |
+
'box office, ticket office, ticket booth',
|
1358 |
+
'teacup',
|
1359 |
+
'trellis',
|
1360 |
+
'workbench',
|
1361 |
+
'toaster',
|
1362 |
+
'knife',
|
1363 |
+
'podium',
|
1364 |
+
'ramp',
|
1365 |
+
'tumble dryer',
|
1366 |
+
'fireplug, fire hydrant, plug',
|
1367 |
+
'gym shoe, sneaker, tennis shoe',
|
1368 |
+
'lab bench',
|
1369 |
+
'equipment',
|
1370 |
+
'rocky formation',
|
1371 |
+
'plastic',
|
1372 |
+
'calendar',
|
1373 |
+
'caravan',
|
1374 |
+
'check-in-desk',
|
1375 |
+
'ticket counter',
|
1376 |
+
'brush',
|
1377 |
+
'mill',
|
1378 |
+
'covered bridge',
|
1379 |
+
'bowling alley',
|
1380 |
+
'hanger',
|
1381 |
+
'excavator',
|
1382 |
+
'trestle',
|
1383 |
+
'revolving door',
|
1384 |
+
'blast furnace',
|
1385 |
+
'scale, weighing machine',
|
1386 |
+
'projector',
|
1387 |
+
'soap',
|
1388 |
+
'locker',
|
1389 |
+
'tractor',
|
1390 |
+
'stretcher',
|
1391 |
+
'frame',
|
1392 |
+
'grating',
|
1393 |
+
'alembic',
|
1394 |
+
'candle, taper, wax light',
|
1395 |
+
'barrier',
|
1396 |
+
'cardboard',
|
1397 |
+
'cave',
|
1398 |
+
'puddle',
|
1399 |
+
'tarp',
|
1400 |
+
'price tag',
|
1401 |
+
'watchtower',
|
1402 |
+
'meters',
|
1403 |
+
(
|
1404 |
+
'light bulb, lightbulb, bulb, incandescent lamp, electric light,'
|
1405 |
+
' electric-light bulb'
|
1406 |
+
),
|
1407 |
+
'tracks',
|
1408 |
+
'hair dryer',
|
1409 |
+
'skirt',
|
1410 |
+
'viaduct',
|
1411 |
+
'paper towel',
|
1412 |
+
'coat',
|
1413 |
+
'sheet',
|
1414 |
+
'fire extinguisher, extinguisher, asphyxiator',
|
1415 |
+
'water wheel',
|
1416 |
+
'pottery, clayware',
|
1417 |
+
'magazine rack',
|
1418 |
+
'teapot',
|
1419 |
+
'microphone, mike',
|
1420 |
+
'support',
|
1421 |
+
'forklift',
|
1422 |
+
'canyon',
|
1423 |
+
'cash register, register',
|
1424 |
+
'leaf, leafage, foliage',
|
1425 |
+
'remote control, remote',
|
1426 |
+
'soap dish',
|
1427 |
+
'windshield, windscreen',
|
1428 |
+
'cat',
|
1429 |
+
'cue, cue stick, pool cue, pool stick',
|
1430 |
+
'vent, venthole, vent-hole, blowhole',
|
1431 |
+
'videos',
|
1432 |
+
'shovel',
|
1433 |
+
'eaves',
|
1434 |
+
'antenna, aerial, transmitting aerial',
|
1435 |
+
'shipyard',
|
1436 |
+
'hen, biddy',
|
1437 |
+
'traffic cone',
|
1438 |
+
'washing machines',
|
1439 |
+
'truck crane',
|
1440 |
+
'cds',
|
1441 |
+
'niche',
|
1442 |
+
'scoreboard',
|
1443 |
+
'briefcase',
|
1444 |
+
'boot',
|
1445 |
+
'sweater, jumper',
|
1446 |
+
'hay',
|
1447 |
+
'pack',
|
1448 |
+
'bottle rack',
|
1449 |
+
'glacier',
|
1450 |
+
'pergola',
|
1451 |
+
'building materials',
|
1452 |
+
'television camera',
|
1453 |
+
'first floor',
|
1454 |
+
'rifle',
|
1455 |
+
'tennis table',
|
1456 |
+
'stadium',
|
1457 |
+
'safety belt',
|
1458 |
+
'cover',
|
1459 |
+
'dish rack',
|
1460 |
+
'synthesizer',
|
1461 |
+
'pumpkin',
|
1462 |
+
'gutter',
|
1463 |
+
'fruit stand',
|
1464 |
+
'ice floe, floe',
|
1465 |
+
'handle, grip, handgrip, hold',
|
1466 |
+
'wheelchair',
|
1467 |
+
'mousepad, mouse mat',
|
1468 |
+
'diploma',
|
1469 |
+
'fairground ride',
|
1470 |
+
'radio',
|
1471 |
+
'hotplate',
|
1472 |
+
'junk',
|
1473 |
+
'wheelbarrow',
|
1474 |
+
'toll plaza',
|
1475 |
+
'punching bag',
|
1476 |
+
'trough',
|
1477 |
+
'throne',
|
1478 |
+
'chair desk',
|
1479 |
+
'weighbridge',
|
1480 |
+
'extractor fan',
|
1481 |
+
'hanging clothes',
|
1482 |
+
'dish, dish aerial, dish antenna, saucer',
|
1483 |
+
'alarm clock, alarm',
|
1484 |
+
'ski lift',
|
1485 |
+
'chain',
|
1486 |
+
'garage',
|
1487 |
+
'mechanical shovel',
|
1488 |
+
'wine rack',
|
1489 |
+
'tramway',
|
1490 |
+
'treadmill',
|
1491 |
+
'menu',
|
1492 |
+
'block',
|
1493 |
+
'well',
|
1494 |
+
'witness stand',
|
1495 |
+
'branch',
|
1496 |
+
'duck',
|
1497 |
+
'casserole',
|
1498 |
+
'frying pan',
|
1499 |
+
'desk organizer',
|
1500 |
+
'mast',
|
1501 |
+
'spectacles, specs, eyeglasses, glasses',
|
1502 |
+
'service elevator',
|
1503 |
+
'dollhouse',
|
1504 |
+
'hammock',
|
1505 |
+
'clothes hanging',
|
1506 |
+
'photocopier',
|
1507 |
+
'notepad',
|
1508 |
+
'golf cart',
|
1509 |
+
'footpath',
|
1510 |
+
'cross',
|
1511 |
+
'baptismal font',
|
1512 |
+
'boiler',
|
1513 |
+
'skip',
|
1514 |
+
'rotisserie',
|
1515 |
+
'tables',
|
1516 |
+
'water mill',
|
1517 |
+
'helmet',
|
1518 |
+
'cover curtain',
|
1519 |
+
'brick',
|
1520 |
+
'table runner',
|
1521 |
+
'ashtray',
|
1522 |
+
'street box',
|
1523 |
+
'stick',
|
1524 |
+
'hangers',
|
1525 |
+
'cells',
|
1526 |
+
'urinal',
|
1527 |
+
'centerpiece',
|
1528 |
+
'portable fridge',
|
1529 |
+
'dvds',
|
1530 |
+
'golf club',
|
1531 |
+
'skirting board',
|
1532 |
+
'water cooler',
|
1533 |
+
'clipboard',
|
1534 |
+
'camera, photographic camera',
|
1535 |
+
'pigeonhole',
|
1536 |
+
'chips',
|
1537 |
+
'food processor',
|
1538 |
+
'post box',
|
1539 |
+
'lid',
|
1540 |
+
'drum',
|
1541 |
+
'blender',
|
1542 |
+
'cave entrance',
|
1543 |
+
'dental chair',
|
1544 |
+
'obelisk',
|
1545 |
+
'canoe',
|
1546 |
+
'mobile',
|
1547 |
+
'monitors',
|
1548 |
+
'pool ball',
|
1549 |
+
'cue rack',
|
1550 |
+
'baggage carts',
|
1551 |
+
'fork',
|
1552 |
+
'paper filer',
|
1553 |
+
'bicycle rack',
|
1554 |
+
'coat rack',
|
1555 |
+
'garland',
|
1556 |
+
'sports bag',
|
1557 |
+
'fish tank',
|
1558 |
+
'towel dispenser',
|
1559 |
+
'carriage',
|
1560 |
+
'brochure',
|
1561 |
+
'plaque',
|
1562 |
+
'stringer',
|
1563 |
+
'iron',
|
1564 |
+
'spoon',
|
1565 |
+
'flag pole',
|
1566 |
+
'toilet brush',
|
1567 |
+
'book stand',
|
1568 |
+
'water faucet, water tap, tap, hydrant',
|
1569 |
+
'ticket office',
|
1570 |
+
'broom',
|
1571 |
+
'dvd',
|
1572 |
+
'ice bucket',
|
1573 |
+
'carapace, shell, cuticle, shield',
|
1574 |
+
'tureen',
|
1575 |
+
'folders',
|
1576 |
+
'chess',
|
1577 |
+
'root',
|
1578 |
+
'sewing machine',
|
1579 |
+
'model',
|
1580 |
+
'pen',
|
1581 |
+
'violin',
|
1582 |
+
'sweatshirt',
|
1583 |
+
'recycling materials',
|
1584 |
+
'mitten',
|
1585 |
+
'chopping board, cutting board',
|
1586 |
+
'mask',
|
1587 |
+
'log',
|
1588 |
+
'mouse, computer mouse',
|
1589 |
+
'grill',
|
1590 |
+
'hole',
|
1591 |
+
'target',
|
1592 |
+
'trash bag',
|
1593 |
+
'chalk',
|
1594 |
+
'sticks',
|
1595 |
+
'balloon',
|
1596 |
+
'score',
|
1597 |
+
'hair spray',
|
1598 |
+
'roll',
|
1599 |
+
'runner',
|
1600 |
+
'engine',
|
1601 |
+
'inflatable glove',
|
1602 |
+
'games',
|
1603 |
+
'pallets',
|
1604 |
+
'baskets',
|
1605 |
+
'coop',
|
1606 |
+
'dvd player',
|
1607 |
+
'rocking horse',
|
1608 |
+
'buckets',
|
1609 |
+
'bread rolls',
|
1610 |
+
'shawl',
|
1611 |
+
'watering can',
|
1612 |
+
'spotlights',
|
1613 |
+
'post-it',
|
1614 |
+
'bowls',
|
1615 |
+
'security camera',
|
1616 |
+
'runner cloth',
|
1617 |
+
'lock',
|
1618 |
+
'alarm, warning device, alarm system',
|
1619 |
+
'side',
|
1620 |
+
'roulette',
|
1621 |
+
'bone',
|
1622 |
+
'cutlery',
|
1623 |
+
'pool balls',
|
1624 |
+
'wheels',
|
1625 |
+
'spice rack',
|
1626 |
+
'plant pots',
|
1627 |
+
'towel ring',
|
1628 |
+
'bread box',
|
1629 |
+
'video',
|
1630 |
+
'funfair',
|
1631 |
+
'breads',
|
1632 |
+
'tripod',
|
1633 |
+
'ironing board',
|
1634 |
+
'skimmer',
|
1635 |
+
'hollow',
|
1636 |
+
'scratching post',
|
1637 |
+
'tricycle',
|
1638 |
+
'file box',
|
1639 |
+
'mountain pass',
|
1640 |
+
'tombstones',
|
1641 |
+
'cooker',
|
1642 |
+
'card game, cards',
|
1643 |
+
'golf bag',
|
1644 |
+
'towel paper',
|
1645 |
+
'chaise lounge',
|
1646 |
+
'sun',
|
1647 |
+
'toilet paper holder',
|
1648 |
+
'rake',
|
1649 |
+
'key',
|
1650 |
+
'umbrella stand',
|
1651 |
+
'dartboard',
|
1652 |
+
'transformer',
|
1653 |
+
'fireplace utensils',
|
1654 |
+
'sweatshirts',
|
1655 |
+
'cellular telephone, cellular phone, cellphone, cell, mobile phone',
|
1656 |
+
'tallboy',
|
1657 |
+
'stapler',
|
1658 |
+
'sauna',
|
1659 |
+
'test tube',
|
1660 |
+
'palette',
|
1661 |
+
'shopping carts',
|
1662 |
+
'tools',
|
1663 |
+
'push button, push, button',
|
1664 |
+
'star',
|
1665 |
+
'roof rack',
|
1666 |
+
'barbed wire',
|
1667 |
+
'spray',
|
1668 |
+
'ear',
|
1669 |
+
'sponge',
|
1670 |
+
'racket',
|
1671 |
+
'tins',
|
1672 |
+
'eyeglasses',
|
1673 |
+
'file',
|
1674 |
+
'scarfs',
|
1675 |
+
'sugar bowl',
|
1676 |
+
'flip flop',
|
1677 |
+
'headstones',
|
1678 |
+
'laptop bag',
|
1679 |
+
'leash',
|
1680 |
+
'climbing frame',
|
1681 |
+
'suit hanger',
|
1682 |
+
'floor spotlight',
|
1683 |
+
'plate rack',
|
1684 |
+
'sewer',
|
1685 |
+
'hard drive',
|
1686 |
+
'sprinkler',
|
1687 |
+
'tools box',
|
1688 |
+
'necklace',
|
1689 |
+
'bulbs',
|
1690 |
+
'steel industry',
|
1691 |
+
'club',
|
1692 |
+
'jack',
|
1693 |
+
'door bars',
|
1694 |
+
'control panel, instrument panel, control board, board, panel',
|
1695 |
+
'hairbrush',
|
1696 |
+
'napkin holder',
|
1697 |
+
'office',
|
1698 |
+
'smoke detector',
|
1699 |
+
'utensils',
|
1700 |
+
'apron',
|
1701 |
+
'scissors',
|
1702 |
+
'terminal',
|
1703 |
+
'grinder',
|
1704 |
+
'entry phone',
|
1705 |
+
'newspaper stand',
|
1706 |
+
'pepper shaker',
|
1707 |
+
'onions',
|
1708 |
+
(
|
1709 |
+
'central processing unit, cpu, c p u , central processor, processor,'
|
1710 |
+
' mainframe'
|
1711 |
+
),
|
1712 |
+
'tape',
|
1713 |
+
'bat',
|
1714 |
+
'coaster',
|
1715 |
+
'calculator',
|
1716 |
+
'potatoes',
|
1717 |
+
'luggage rack',
|
1718 |
+
'salt',
|
1719 |
+
'street number',
|
1720 |
+
'viewpoint',
|
1721 |
+
'sword',
|
1722 |
+
'cd',
|
1723 |
+
'rowing machine',
|
1724 |
+
'plug',
|
1725 |
+
'andiron, firedog, dog, dog-iron',
|
1726 |
+
'pepper',
|
1727 |
+
'tongs',
|
1728 |
+
'bonfire',
|
1729 |
+
'dog dish',
|
1730 |
+
'belt',
|
1731 |
+
'dumbbells',
|
1732 |
+
'videocassette recorder, vcr',
|
1733 |
+
'hook',
|
1734 |
+
'envelopes',
|
1735 |
+
'shower faucet',
|
1736 |
+
'watch',
|
1737 |
+
'padlock',
|
1738 |
+
'swimming pool ladder',
|
1739 |
+
'spanners',
|
1740 |
+
'gravy boat',
|
1741 |
+
'notice board',
|
1742 |
+
'trash bags',
|
1743 |
+
'fire alarm',
|
1744 |
+
'ladle',
|
1745 |
+
'stethoscope',
|
1746 |
+
'rocket',
|
1747 |
+
'funnel',
|
1748 |
+
'bowling pins',
|
1749 |
+
'valve',
|
1750 |
+
'thermometer',
|
1751 |
+
'cups',
|
1752 |
+
'spice jar',
|
1753 |
+
'night light',
|
1754 |
+
'soaps',
|
1755 |
+
'games table',
|
1756 |
+
'slotted spoon',
|
1757 |
+
'reel',
|
1758 |
+
'scourer',
|
1759 |
+
'sleeping robe',
|
1760 |
+
'desk mat',
|
1761 |
+
'dumbbell',
|
1762 |
+
'hammer',
|
1763 |
+
'tie',
|
1764 |
+
'typewriter',
|
1765 |
+
'shaker',
|
1766 |
+
'cheese dish',
|
1767 |
+
'sea star',
|
1768 |
+
'racquet',
|
1769 |
+
'butane gas cylinder',
|
1770 |
+
'paper weight',
|
1771 |
+
'shaving brush',
|
1772 |
+
'sunglasses',
|
1773 |
+
'gear shift',
|
1774 |
+
'towel rail',
|
1775 |
+
'adding machine, totalizer, totaliser',
|
1776 |
+
]
|
1777 |
+
|
1778 |
+
ADE_847_STUFF_CLASS_ID = [
|
1779 |
+
0, 2, 3, 4, 5, 8, 9, 12, 15, 16, 22, 24, 33, 47, 54, 368, 31, 195, 118, 134,
|
1780 |
+
136, 53, 143, 90, 435, 546, 624, 111, 304,
|
1781 |
+
]
|
1782 |
+
|
1783 |
+
ADE_847_THING_CLASS_ID = [
|
1784 |
+
i for i in ADE_847_CLASS_ID if i not in ADE_847_STUFF_CLASS_ID
|
1785 |
+
]
|
1786 |
+
|
1787 |
+
|
1788 |
+
class ADE847Dataset(Dataset):
|
1789 |
+
"""ADE847 dataset."""
|
1790 |
+
|
1791 |
+
def __init__(self, root, split='validation', transform=None):
|
1792 |
+
super(ADE847Dataset, self).__init__()
|
1793 |
+
self.root = root
|
1794 |
+
self.split = split
|
1795 |
+
self.transforms = transform
|
1796 |
+
self.image_dir = os.path.join(root, 'images_detectron2', split)
|
1797 |
+
self.mask_dir = os.path.join(root, 'annotations_detectron2', split)
|
1798 |
+
self.images = os.listdir(self.image_dir)
|
1799 |
+
|
1800 |
+
def process_mask(self, mask):
|
1801 |
+
mask = np.array(mask)
|
1802 |
+
mask[mask > 847] = 0
|
1803 |
+
return mask
|
1804 |
+
|
1805 |
+
def __getitem__(self, index):
|
1806 |
+
image_path = os.path.join(self.image_dir, self.images[index])
|
1807 |
+
image = Image.open(image_path).convert('RGB')
|
1808 |
+
target = (
|
1809 |
+
np.asarray(
|
1810 |
+
Image.open(
|
1811 |
+
os.path.join(
|
1812 |
+
self.mask_dir, self.images[index].replace('jpg', 'tif')
|
1813 |
+
)
|
1814 |
+
),
|
1815 |
+
dtype=np.int32,
|
1816 |
+
)
|
1817 |
+
+ 1
|
1818 |
+
)
|
1819 |
+
target = self.process_mask(target)
|
1820 |
+
|
1821 |
+
if self.transforms:
|
1822 |
+
image = self.transforms(image)
|
1823 |
+
|
1824 |
+
return image, image_path, target, index
|
1825 |
+
|
1826 |
+
def __len__(self):
|
1827 |
+
return len(self.images)
|
data/coco.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""COCO Stuff Dataset."""
|
17 |
+
|
18 |
+
import os
|
19 |
+
import numpy as np
|
20 |
+
from PIL import Image
|
21 |
+
import torch
|
22 |
+
|
23 |
+
|
24 |
+
COCO_OBJECT_CLASSES = [
|
25 |
+
'person with clothes,people,human',
|
26 |
+
'bicycle',
|
27 |
+
'car',
|
28 |
+
'motorbike',
|
29 |
+
'aeroplane',
|
30 |
+
'bus',
|
31 |
+
'train',
|
32 |
+
'truck',
|
33 |
+
'boat',
|
34 |
+
'traffic light',
|
35 |
+
'fire hydrant',
|
36 |
+
'stop sign',
|
37 |
+
'parking meter',
|
38 |
+
'bench',
|
39 |
+
'bird avian',
|
40 |
+
'cat',
|
41 |
+
'dog',
|
42 |
+
'horse',
|
43 |
+
'sheep',
|
44 |
+
'cow',
|
45 |
+
'elephant',
|
46 |
+
'bear',
|
47 |
+
'zebra',
|
48 |
+
'giraffe',
|
49 |
+
'backpack,bag',
|
50 |
+
'umbrella,parasol',
|
51 |
+
'handbag,purse',
|
52 |
+
'necktie',
|
53 |
+
'suitcase',
|
54 |
+
'frisbee',
|
55 |
+
'skis',
|
56 |
+
'sknowboard',
|
57 |
+
'sports ball',
|
58 |
+
'kite',
|
59 |
+
'baseball bat',
|
60 |
+
'glove',
|
61 |
+
'skateboard',
|
62 |
+
'surfboard',
|
63 |
+
'tennis racket',
|
64 |
+
'bottle',
|
65 |
+
'wine glass',
|
66 |
+
'cup',
|
67 |
+
'fork',
|
68 |
+
'knife',
|
69 |
+
'dessertspoon',
|
70 |
+
'bowl',
|
71 |
+
'banana',
|
72 |
+
'apple',
|
73 |
+
'sandwich',
|
74 |
+
'orange',
|
75 |
+
'broccoli',
|
76 |
+
'carrot',
|
77 |
+
'hot dog',
|
78 |
+
'pizza',
|
79 |
+
'donut',
|
80 |
+
'cake',
|
81 |
+
'chair seat',
|
82 |
+
'sofa',
|
83 |
+
'pottedplant',
|
84 |
+
'bed',
|
85 |
+
'diningtable',
|
86 |
+
'toilet',
|
87 |
+
'tvmonitor screen',
|
88 |
+
'laptop',
|
89 |
+
'mouse',
|
90 |
+
'remote control',
|
91 |
+
'keyboard',
|
92 |
+
'cell phone',
|
93 |
+
'microwave',
|
94 |
+
'oven',
|
95 |
+
'toaster',
|
96 |
+
'sink',
|
97 |
+
'refrigerator',
|
98 |
+
'book',
|
99 |
+
'clock',
|
100 |
+
'vase',
|
101 |
+
'scissors',
|
102 |
+
'teddy bear',
|
103 |
+
'hairdrier,blowdrier',
|
104 |
+
'toothbrush',
|
105 |
+
]
|
106 |
+
|
107 |
+
|
108 |
+
class COCODataset(torch.utils.data.Dataset):
|
109 |
+
"""COCO Object Dataset."""
|
110 |
+
|
111 |
+
def __init__(self, root, split='val', transform=None):
|
112 |
+
"""Construct COCO Object Dataset.
|
113 |
+
|
114 |
+
Args:
|
115 |
+
root (string): Root directory where images are downloaded.
|
116 |
+
split (string): Path to the annotation file.
|
117 |
+
transform (callable, optional): Optional transform to be applied on a
|
118 |
+
sample.
|
119 |
+
"""
|
120 |
+
self.root = root
|
121 |
+
self.image_dir = os.path.join(root, 'images', f'{split}2017')
|
122 |
+
self.ann_dir = os.path.join(root, 'annotations', f'{split}2017')
|
123 |
+
self.images = os.listdir(self.image_dir)
|
124 |
+
self.transform = transform
|
125 |
+
|
126 |
+
def __getitem__(self, index):
|
127 |
+
img_path = os.path.join(self.image_dir, self.images[index])
|
128 |
+
img = Image.open(img_path).convert('RGB')
|
129 |
+
img = np.asarray(img)
|
130 |
+
idx = self.images[index].split('.')[0]
|
131 |
+
ann_path = os.path.join(self.ann_dir, f'{idx}_instanceTrainIds.png')
|
132 |
+
ann = np.asarray(Image.open(ann_path), dtype=np.int32)
|
133 |
+
|
134 |
+
return img, img_path, ann, idx
|
135 |
+
|
136 |
+
def __len__(self):
|
137 |
+
return len(self.images)
|
data/context.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Pascal Context Dataset."""
|
17 |
+
|
18 |
+
from typing import Any, List, Tuple
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
from PIL import Image
|
22 |
+
# pylint: disable=g-importing-member
|
23 |
+
from torchvision.datasets.voc import _VOCBase
|
24 |
+
|
25 |
+
|
26 |
+
PASCAL_CONTEXT_CLASSES = [
|
27 |
+
'airplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', 'bird', 'boat',
|
28 |
+
'book', 'bottle', 'building', 'bus', 'cabinet', 'car', 'cat', 'ceiling',
|
29 |
+
'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain', 'dog', 'door',
|
30 |
+
'fence', 'floor', 'flower', 'food', 'grass', 'ground', 'horse', 'keyboard',
|
31 |
+
'light', 'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform',
|
32 |
+
'plant', 'road', 'rock', 'sheep', 'shelves', 'sidewalk', 'sign', 'sky',
|
33 |
+
'snow', 'sofa', 'table', 'track', 'train', 'tree', 'truck', 'monitor',
|
34 |
+
'wall', 'water', 'window', 'wood']
|
35 |
+
|
36 |
+
PASCAL_CONTEXT_STUFF_CLASS = [
|
37 |
+
'bedclothes', 'ceiling', 'cloth', 'curtain', 'floor', 'grass', 'ground',
|
38 |
+
'light', 'mountain', 'platform', 'road', 'sidewalk', 'sky', 'snow', 'wall',
|
39 |
+
'water', 'window', 'wood', 'door', 'fence', 'rock']
|
40 |
+
|
41 |
+
PASCAL_CONTEXT_THING_CLASS = [
|
42 |
+
'airplane', 'bag', 'bed', 'bench', 'bicycle', 'bird', 'boat', 'book',
|
43 |
+
'bottle', 'building', 'bus', 'cabinet', 'car', 'cat', 'chair', 'computer',
|
44 |
+
'cow', 'cup', 'dog', 'flower', 'food', 'horse', 'keyboard', 'motorbike',
|
45 |
+
'mouse', 'person', 'plate', 'plant', 'sheep', 'shelves', 'sign', 'sofa',
|
46 |
+
'table', 'track', 'train', 'tree', 'truck', 'monitor']
|
47 |
+
|
48 |
+
PASCAL_CONTEXT_STUFF_CLASS_ID = [
|
49 |
+
3, 15, 17, 21, 25, 28, 29, 32, 34, 38, 40, 44, 46, 47, 55, 56, 57, 58, 23,
|
50 |
+
24, 41]
|
51 |
+
|
52 |
+
PASCAL_CONTEXT_THING_CLASS_ID = [
|
53 |
+
0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 18, 19, 20, 22, 26, 27,
|
54 |
+
30, 31, 33, 35, 36, 37, 39, 42, 43, 45, 48, 49, 50, 51, 52, 53, 54]
|
55 |
+
|
56 |
+
|
57 |
+
class CONTEXTSegmentation(_VOCBase):
|
58 |
+
"""Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/> Segmentation Dataset.
|
59 |
+
|
60 |
+
Attributes:
|
61 |
+
root (string): Root directory of the VOC Dataset.
|
62 |
+
year (string, optional): The dataset year, supports years ``"2007"`` to
|
63 |
+
``"2012"``.
|
64 |
+
image_set (string, optional): Select the image_set to use, ``"train"``,
|
65 |
+
``"trainval"`` or ``"val"``. If ``year=="2007"``, can also be
|
66 |
+
``"test"``.
|
67 |
+
download (bool, optional): If true, downloads the dataset from the
|
68 |
+
internet and puts it in root directory. If dataset is already
|
69 |
+
downloaded, it is not downloaded again.
|
70 |
+
transform (callable, optional): A function/transform that takes in an PIL
|
71 |
+
image and returns a transformed version. E.g, ``transforms.RandomCrop``
|
72 |
+
target_transform (callable, optional): A function/transform that takes in
|
73 |
+
the target and transforms it.
|
74 |
+
transforms (callable, optional): A function/transform that takes input
|
75 |
+
sample and its target as entry and returns a transformed version.
|
76 |
+
"""
|
77 |
+
|
78 |
+
_SPLITS_DIR = 'SegmentationContext'
|
79 |
+
_TARGET_DIR = 'SegmentationClassContext'
|
80 |
+
_TARGET_FILE_EXT = '.png'
|
81 |
+
|
82 |
+
@property
|
83 |
+
def masks(self):
|
84 |
+
return self.targets
|
85 |
+
|
86 |
+
def __getitem__(self, index):
|
87 |
+
"""Get a sample of image and segmentation.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
index (int): Index
|
91 |
+
Returns:
|
92 |
+
tuple: (image, target) where target is the image segmentation.
|
93 |
+
"""
|
94 |
+
img = Image.open(self.images[index]).convert('RGB')
|
95 |
+
target = Image.open(self.masks[index])
|
96 |
+
|
97 |
+
if self.transforms is not None:
|
98 |
+
img, target = self.transforms(img, target)
|
99 |
+
|
100 |
+
return img, target
|
101 |
+
|
102 |
+
|
103 |
+
class CONTEXTDataset(CONTEXTSegmentation):
|
104 |
+
"""Pascal Context Dataset."""
|
105 |
+
|
106 |
+
def __init__(self, root, year='2012', split='val', transform=None):
|
107 |
+
super(CONTEXTDataset, self).__init__(
|
108 |
+
root=root,
|
109 |
+
image_set=split,
|
110 |
+
year=year,
|
111 |
+
transform=transform,
|
112 |
+
download=False,
|
113 |
+
)
|
114 |
+
# self.idx_to_class = {val: key for (key, val) in CLASS2ID.items()}
|
115 |
+
|
116 |
+
def __getitem__(self, index):
|
117 |
+
image_path = self.images[index]
|
118 |
+
image = Image.open(image_path).convert('RGB')
|
119 |
+
target = np.asarray(Image.open(self.masks[index]), dtype=np.int32)
|
120 |
+
# transpose the target width and height
|
121 |
+
# target = target.transpose(1, 0)
|
122 |
+
|
123 |
+
if self.transforms:
|
124 |
+
image = self.transform(image)
|
125 |
+
|
126 |
+
return image, str(image_path), target, index
|
data/gres.py
ADDED
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""grefer v0.1.
|
17 |
+
|
18 |
+
This interface provides access to gRefCOCO.
|
19 |
+
|
20 |
+
The following API functions are defined:
|
21 |
+
G_REFER - REFER api class
|
22 |
+
getRefIds - get ref ids that satisfy given filter conditions.
|
23 |
+
getAnnIds - get ann ids that satisfy given filter conditions.
|
24 |
+
getImgIds - get image ids that satisfy given filter conditions.
|
25 |
+
getCatIds - get category ids that satisfy given filter conditions.
|
26 |
+
loadRefs - load refs with the specified ref ids.
|
27 |
+
loadAnns - load anns with the specified ann ids.
|
28 |
+
loadImgs - load images with the specified image ids.
|
29 |
+
loadCats - load category names with the specified category ids.
|
30 |
+
getRefBox - get ref's bounding box [x, y, w, h] given the ref_id
|
31 |
+
showRef - show image, segmentation or box of the referred object with the
|
32 |
+
ref
|
33 |
+
getMaskByRef - get mask and area of the referred object given ref or ref ids
|
34 |
+
getMask - get mask and area of the referred object given ref
|
35 |
+
showMask - show mask of the referred object given ref
|
36 |
+
"""
|
37 |
+
# Adapted from
|
38 |
+
# https://github.com/yz93/LAVT-RIS/blob/main/data/dataset_refer_bert.py
|
39 |
+
|
40 |
+
# pylint: disable=all
|
41 |
+
import itertools
|
42 |
+
import json
|
43 |
+
import os
|
44 |
+
import os.path as osp
|
45 |
+
import pickle
|
46 |
+
import time
|
47 |
+
# pylint: disable=g-importing-member
|
48 |
+
from matplotlib.collections import PatchCollection
|
49 |
+
from matplotlib.patches import Polygon
|
50 |
+
from matplotlib.patches import Rectangle
|
51 |
+
import matplotlib.pyplot as plt
|
52 |
+
import numpy as np
|
53 |
+
from PIL import Image
|
54 |
+
from pycocotools import mask
|
55 |
+
from skimage import io
|
56 |
+
import torch
|
57 |
+
from torch.utils import data
|
58 |
+
|
59 |
+
|
60 |
+
class G_REFER:
|
61 |
+
"""GRES dataset."""
|
62 |
+
|
63 |
+
def __init__(self, data_root, dataset='grefcoco', splitBy='unc'):
|
64 |
+
# provide data_root folder which contains grefcoco
|
65 |
+
print('loading dataset %s into memory...' % dataset)
|
66 |
+
self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
|
67 |
+
self.DATA_DIR = osp.join(data_root, dataset)
|
68 |
+
if dataset in ['grefcoco']:
|
69 |
+
self.IMAGE_DIR = osp.join(data_root, 'images/mscoco/images/train2014')
|
70 |
+
else:
|
71 |
+
raise KeyError('No refer dataset is called [%s]' % dataset)
|
72 |
+
|
73 |
+
tic = time.time()
|
74 |
+
|
75 |
+
# load refs from data/dataset/refs(dataset).json
|
76 |
+
self.data = {}
|
77 |
+
self.data['dataset'] = dataset
|
78 |
+
|
79 |
+
ref_file = osp.join(self.DATA_DIR, f'grefs({splitBy}).p')
|
80 |
+
if osp.exists(ref_file):
|
81 |
+
self.data['refs'] = pickle.load(open(ref_file, 'rb'), fix_imports=True)
|
82 |
+
else:
|
83 |
+
ref_file = osp.join(self.DATA_DIR, f'grefs({splitBy}).json')
|
84 |
+
if osp.exists(ref_file):
|
85 |
+
self.data['refs'] = json.load(open(ref_file, 'rb'))
|
86 |
+
else:
|
87 |
+
raise FileNotFoundError('JSON file not found')
|
88 |
+
|
89 |
+
# load annotations from data/dataset/instances.json
|
90 |
+
instances_file = osp.join(self.DATA_DIR, 'instances.json')
|
91 |
+
instances = json.load(open(instances_file, 'r'))
|
92 |
+
self.data['images'] = instances['images']
|
93 |
+
self.data['annotations'] = instances['annotations']
|
94 |
+
self.data['categories'] = instances['categories']
|
95 |
+
|
96 |
+
# create index
|
97 |
+
self.createIndex()
|
98 |
+
print('DONE (t=%.2fs)' % (time.time() - tic))
|
99 |
+
|
100 |
+
@staticmethod
|
101 |
+
def _toList(x):
|
102 |
+
return x if isinstance(x, list) else [x]
|
103 |
+
|
104 |
+
@staticmethod
|
105 |
+
def match_any(a, b):
|
106 |
+
a = a if isinstance(a, list) else [a]
|
107 |
+
b = b if isinstance(b, list) else [b]
|
108 |
+
return set(a) & set(b)
|
109 |
+
|
110 |
+
def createIndex(self):
|
111 |
+
# create sets of mapping
|
112 |
+
# 1) Refs: {ref_id: ref}
|
113 |
+
# 2) Anns: {ann_id: ann}
|
114 |
+
# 3) Imgs: {image_id: image}
|
115 |
+
# 4) Cats: {category_id: category_name}
|
116 |
+
# 5) Sents: {sent_id: sent}
|
117 |
+
# 6) imgToRefs: {image_id: refs}
|
118 |
+
# 7) imgToAnns: {image_id: anns}
|
119 |
+
# 8) refToAnn: {ref_id: ann}
|
120 |
+
# 9) annToRef: {ann_id: ref}
|
121 |
+
# 10) catToRefs: {category_id: refs}
|
122 |
+
# 11) sentToRef: {sent_id: ref}
|
123 |
+
# 12) sentToTokens: {sent_id: tokens}
|
124 |
+
print('creating index...')
|
125 |
+
# fetch info from instances
|
126 |
+
Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
|
127 |
+
Anns[-1] = None
|
128 |
+
for ann in self.data['annotations']:
|
129 |
+
Anns[ann['id']] = ann
|
130 |
+
imgToAnns[ann['image_id']] = imgToAnns.get(ann['image_id'], []) + [ann]
|
131 |
+
for img in self.data['images']:
|
132 |
+
Imgs[img['id']] = img
|
133 |
+
for cat in self.data['categories']:
|
134 |
+
Cats[cat['id']] = cat['name']
|
135 |
+
|
136 |
+
# fetch info from refs
|
137 |
+
Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
|
138 |
+
Sents, sentToRef, sentToTokens = {}, {}, {}
|
139 |
+
availableSplits = []
|
140 |
+
for ref in self.data['refs']:
|
141 |
+
# ids
|
142 |
+
ref_id = ref['ref_id']
|
143 |
+
ann_id = ref['ann_id']
|
144 |
+
category_id = ref['category_id']
|
145 |
+
image_id = ref['image_id']
|
146 |
+
|
147 |
+
if ref['split'] not in availableSplits:
|
148 |
+
availableSplits.append(ref['split'])
|
149 |
+
|
150 |
+
# add mapping related to ref
|
151 |
+
if ref_id in Refs:
|
152 |
+
print('Duplicate ref id')
|
153 |
+
Refs[ref_id] = ref
|
154 |
+
imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
|
155 |
+
|
156 |
+
category_id = self._toList(category_id)
|
157 |
+
added_cats = []
|
158 |
+
for cat in category_id:
|
159 |
+
if cat not in added_cats:
|
160 |
+
added_cats.append(cat)
|
161 |
+
catToRefs[cat] = catToRefs.get(cat, []) + [ref]
|
162 |
+
|
163 |
+
ann_id = self._toList(ann_id)
|
164 |
+
refToAnn[ref_id] = [Anns[ann] for ann in ann_id]
|
165 |
+
for ann_id_n in ann_id:
|
166 |
+
annToRef[ann_id_n] = annToRef.get(ann_id_n, []) + [ref]
|
167 |
+
|
168 |
+
# add mapping of sent
|
169 |
+
for sent in ref['sentences']:
|
170 |
+
Sents[sent['sent_id']] = sent
|
171 |
+
sentToRef[sent['sent_id']] = ref
|
172 |
+
sentToTokens[sent['sent_id']] = sent['tokens']
|
173 |
+
|
174 |
+
# create class members
|
175 |
+
self.Refs = Refs
|
176 |
+
self.Anns = Anns
|
177 |
+
self.Imgs = Imgs
|
178 |
+
self.Cats = Cats
|
179 |
+
self.Sents = Sents
|
180 |
+
self.imgToRefs = imgToRefs
|
181 |
+
self.imgToAnns = imgToAnns
|
182 |
+
self.refToAnn = refToAnn
|
183 |
+
self.annToRef = annToRef
|
184 |
+
self.catToRefs = catToRefs
|
185 |
+
self.sentToRef = sentToRef
|
186 |
+
self.sentToTokens = sentToTokens
|
187 |
+
self.availableSplits = availableSplits
|
188 |
+
print('index created.')
|
189 |
+
|
190 |
+
def getRefIds(self, image_ids=[], cat_ids=[], split=[]):
|
191 |
+
image_ids = self._toList(image_ids)
|
192 |
+
cat_ids = self._toList(cat_ids)
|
193 |
+
split = self._toList(split)
|
194 |
+
|
195 |
+
for s in split:
|
196 |
+
if s not in self.availableSplits:
|
197 |
+
raise ValueError(f'Invalid split name: {s}')
|
198 |
+
|
199 |
+
refs = self.data['refs']
|
200 |
+
|
201 |
+
if len(image_ids) > 0:
|
202 |
+
lists = [self.imgToRefs[image_id] for image_id in image_ids]
|
203 |
+
refs = list(itertools.chain.from_iterable(lists))
|
204 |
+
if len(cat_ids) > 0:
|
205 |
+
refs = [
|
206 |
+
ref for ref in refs if self.match_any(ref['category_id'], cat_ids)
|
207 |
+
]
|
208 |
+
if len(split) > 0:
|
209 |
+
refs = [ref for ref in refs if ref['split'] in split]
|
210 |
+
|
211 |
+
ref_ids = [ref['ref_id'] for ref in refs]
|
212 |
+
return ref_ids
|
213 |
+
|
214 |
+
def getAnnIds(self, image_ids=[], ref_ids=[]):
|
215 |
+
image_ids = self._toList(image_ids)
|
216 |
+
ref_ids = self._toList(ref_ids)
|
217 |
+
|
218 |
+
if any([len(image_ids), len(ref_ids)]):
|
219 |
+
if len(image_ids) > 0:
|
220 |
+
lists = [
|
221 |
+
self.imgToAnns[image_id]
|
222 |
+
for image_id in image_ids
|
223 |
+
if image_id in self.imgToAnns
|
224 |
+
]
|
225 |
+
anns = list(itertools.chain.from_iterable(lists))
|
226 |
+
else:
|
227 |
+
anns = self.data['annotations']
|
228 |
+
ann_ids = [ann['id'] for ann in anns]
|
229 |
+
if len(ref_ids) > 0:
|
230 |
+
lists = [self.Refs[ref_id]['ann_id'] for ref_id in ref_ids]
|
231 |
+
anns_by_ref_id = list(itertools.chain.from_iterable(lists))
|
232 |
+
ann_ids = list(set(ann_ids).intersection(set(anns_by_ref_id)))
|
233 |
+
else:
|
234 |
+
ann_ids = [ann['id'] for ann in self.data['annotations']]
|
235 |
+
|
236 |
+
return ann_ids
|
237 |
+
|
238 |
+
def getImgIds(self, ref_ids=[]):
|
239 |
+
ref_ids = self._toList(ref_ids)
|
240 |
+
|
241 |
+
if len(ref_ids) > 0:
|
242 |
+
image_ids = list(
|
243 |
+
set([self.Refs[ref_id]['image_id'] for ref_id in ref_ids])
|
244 |
+
)
|
245 |
+
else:
|
246 |
+
image_ids = self.Imgs.keys()
|
247 |
+
return image_ids
|
248 |
+
|
249 |
+
def getCatIds(self):
|
250 |
+
return self.Cats.keys()
|
251 |
+
|
252 |
+
def loadRefs(self, ref_ids=[]):
|
253 |
+
return [self.Refs[ref_id] for ref_id in self._toList(ref_ids)]
|
254 |
+
|
255 |
+
def loadAnns(self, ann_ids=[]):
|
256 |
+
if isinstance(ann_ids, str):
|
257 |
+
ann_ids = int(ann_ids)
|
258 |
+
return [self.Anns[ann_id] for ann_id in self._toList(ann_ids)]
|
259 |
+
|
260 |
+
def loadImgs(self, image_ids=[]):
|
261 |
+
return [self.Imgs[image_id] for image_id in self._toList(image_ids)]
|
262 |
+
|
263 |
+
def loadCats(self, cat_ids=[]):
|
264 |
+
return [self.Cats[cat_id] for cat_id in self._toList(cat_ids)]
|
265 |
+
|
266 |
+
def getRefBox(self, ref_id):
|
267 |
+
anns = self.refToAnn[ref_id]
|
268 |
+
return [ann['bbox'] for ann in anns] # [x, y, w, h]
|
269 |
+
|
270 |
+
def showRef(self, ref, seg_box='seg'):
|
271 |
+
ax = plt.gca()
|
272 |
+
# show image
|
273 |
+
image = self.Imgs[ref['image_id']]
|
274 |
+
I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
|
275 |
+
ax.imshow(I)
|
276 |
+
# show refer expression
|
277 |
+
for sid, sent in enumerate(ref['sentences']):
|
278 |
+
print('%s. %s' % (sid + 1, sent['sent']))
|
279 |
+
# show segmentations
|
280 |
+
if seg_box == 'seg':
|
281 |
+
ann_id = ref['ann_id']
|
282 |
+
ann = self.Anns[ann_id]
|
283 |
+
polygons = []
|
284 |
+
color = []
|
285 |
+
c = 'none'
|
286 |
+
if type(ann['segmentation'][0]) == list:
|
287 |
+
# polygon used for refcoco*
|
288 |
+
for seg in ann['segmentation']:
|
289 |
+
poly = np.array(seg).reshape((len(seg) / 2, 2))
|
290 |
+
polygons.append(Polygon(poly, True, alpha=0.4))
|
291 |
+
color.append(c)
|
292 |
+
p = PatchCollection(
|
293 |
+
polygons,
|
294 |
+
facecolors=color,
|
295 |
+
edgecolors=(1, 1, 0, 0),
|
296 |
+
linewidths=3,
|
297 |
+
alpha=1,
|
298 |
+
)
|
299 |
+
ax.add_collection(p) # thick yellow polygon
|
300 |
+
p = PatchCollection(
|
301 |
+
polygons,
|
302 |
+
facecolors=color,
|
303 |
+
edgecolors=(1, 0, 0, 0),
|
304 |
+
linewidths=1,
|
305 |
+
alpha=1,
|
306 |
+
)
|
307 |
+
ax.add_collection(p) # thin red polygon
|
308 |
+
else:
|
309 |
+
# mask used for refclef
|
310 |
+
rle = ann['segmentation']
|
311 |
+
m = mask.decode(rle)
|
312 |
+
img = np.ones((m.shape[0], m.shape[1], 3))
|
313 |
+
color_mask = np.array([2.0, 166.0, 101.0]) / 255
|
314 |
+
for i in range(3):
|
315 |
+
img[:, :, i] = color_mask[i]
|
316 |
+
ax.imshow(np.dstack((img, m * 0.5)))
|
317 |
+
# show bounding-box
|
318 |
+
elif seg_box == 'box':
|
319 |
+
# ann_id = ref['ann_id']
|
320 |
+
# ann = self.Anns[ann_id]
|
321 |
+
bbox = self.getRefBox(ref['ref_id'])
|
322 |
+
box_plot = Rectangle(
|
323 |
+
(bbox[0], bbox[1]),
|
324 |
+
bbox[2],
|
325 |
+
bbox[3],
|
326 |
+
fill=False,
|
327 |
+
edgecolor='green',
|
328 |
+
linewidth=3,
|
329 |
+
)
|
330 |
+
ax.add_patch(box_plot)
|
331 |
+
|
332 |
+
def getMask(self, ann):
|
333 |
+
if not ann:
|
334 |
+
return None
|
335 |
+
if ann['iscrowd']:
|
336 |
+
raise ValueError('Crowd object')
|
337 |
+
image = self.Imgs[ann['image_id']]
|
338 |
+
if type(ann['segmentation'][0]) == list: # polygon
|
339 |
+
rle = mask.frPyObjects(
|
340 |
+
ann['segmentation'], image['height'], image['width']
|
341 |
+
)
|
342 |
+
else:
|
343 |
+
rle = ann['segmentation']
|
344 |
+
|
345 |
+
m = mask.decode(rle)
|
346 |
+
# sometimes there are multiple binary map (corresponding to multiple segs)
|
347 |
+
m = np.sum(m, axis=2)
|
348 |
+
m = m.astype(np.uint8) # convert to np.uint8
|
349 |
+
# compute area
|
350 |
+
area = sum(mask.area(rle)) # should be close to ann['area']
|
351 |
+
return {'mask': m, 'area': area}
|
352 |
+
|
353 |
+
def getMaskByRef(self, ref=None, ref_id=None, merge=False):
|
354 |
+
if not ref and not ref_id:
|
355 |
+
raise ValueError
|
356 |
+
if ref:
|
357 |
+
ann_ids = ref['ann_id']
|
358 |
+
ref_id = ref['ref_id']
|
359 |
+
else:
|
360 |
+
ann_ids = self.getAnnIds(ref_ids=ref_id)
|
361 |
+
|
362 |
+
if ann_ids == [-1]:
|
363 |
+
img = self.Imgs[self.Refs[ref_id]['image_id']]
|
364 |
+
return {
|
365 |
+
'mask': np.zeros([img['height'], img['width']], dtype=np.uint8),
|
366 |
+
'empty': True,
|
367 |
+
}
|
368 |
+
|
369 |
+
anns = self.loadAnns(ann_ids)
|
370 |
+
mask_list = [self.getMask(ann) for ann in anns if not ann['iscrowd']]
|
371 |
+
|
372 |
+
if merge:
|
373 |
+
merged_masks = sum([mask['mask'] for mask in mask_list])
|
374 |
+
merged_masks[np.where(merged_masks > 1)] = 1
|
375 |
+
return {'mask': merged_masks, 'empty': False}
|
376 |
+
else:
|
377 |
+
return mask_list
|
378 |
+
|
379 |
+
def showMask(self, ref):
|
380 |
+
M = self.getMask(ref)
|
381 |
+
msk = M['mask']
|
382 |
+
ax = plt.gca()
|
383 |
+
ax.imshow(msk)
|
384 |
+
|
385 |
+
|
386 |
+
class GReferDataset(data.Dataset):
|
387 |
+
|
388 |
+
def __init__(self, root, transform=None, split='val'):
|
389 |
+
|
390 |
+
self.classes = []
|
391 |
+
self.image_transforms = transform
|
392 |
+
self.split = split
|
393 |
+
self.refer = G_REFER(root)
|
394 |
+
|
395 |
+
ref_ids = self.refer.getRefIds(split=self.split)
|
396 |
+
img_ids = self.refer.getImgIds(ref_ids)
|
397 |
+
|
398 |
+
all_imgs = self.refer.Imgs
|
399 |
+
self.imgs = list(all_imgs[i] for i in img_ids)
|
400 |
+
self.ref_ids = []
|
401 |
+
# print(len(ref_ids))
|
402 |
+
# print(len(self.imgs))
|
403 |
+
self.sentence_raw = []
|
404 |
+
# if we are testing on a dataset, test all sentences of an object;
|
405 |
+
# o/w, we are validating during training, randomly sample one sentence
|
406 |
+
# for efficiency
|
407 |
+
for r in ref_ids:
|
408 |
+
ref = self.refer.Refs[r]
|
409 |
+
# ref_sentences = []
|
410 |
+
# for i, (el, sent_id) in enumerate(zip(ref['sentences'],
|
411 |
+
# ref['sent_ids'])):
|
412 |
+
for el in ref['sentences']:
|
413 |
+
sentence_raw = el['raw']
|
414 |
+
if len(sentence_raw) == 0:
|
415 |
+
continue
|
416 |
+
self.sentence_raw.append(sentence_raw)
|
417 |
+
self.ref_ids.append(r)
|
418 |
+
|
419 |
+
# print(len(self.sentence_raw))
|
420 |
+
|
421 |
+
def get_classes(self):
|
422 |
+
return self.classes
|
423 |
+
|
424 |
+
def __len__(self):
|
425 |
+
return len(self.ref_ids)
|
426 |
+
|
427 |
+
def __getitem__(self, index):
|
428 |
+
this_ref_id = self.ref_ids[index]
|
429 |
+
this_img_id = self.refer.getImgIds(this_ref_id)
|
430 |
+
this_img = self.refer.Imgs[this_img_id[0]]
|
431 |
+
# print(this_ref_id, this_img_id)
|
432 |
+
# print(len(self.ref_ids))
|
433 |
+
img_path = os.path.join(self.refer.IMAGE_DIR, this_img['file_name'])
|
434 |
+
img = Image.open(img_path).convert('RGB')
|
435 |
+
ref = self.refer.loadRefs(this_ref_id)
|
436 |
+
# print("ref",ref)
|
437 |
+
|
438 |
+
ref_mask_ann = self.refer.getMaskByRef(ref[0])
|
439 |
+
if type(ref_mask_ann) == list:
|
440 |
+
ref_mask_ann = ref_mask_ann[0]
|
441 |
+
ref_mask = ref_mask_ann['mask']
|
442 |
+
annot = np.zeros(ref_mask.shape)
|
443 |
+
annot[ref_mask == 1] = 1
|
444 |
+
|
445 |
+
target = Image.fromarray(annot.astype(np.uint8), mode='P')
|
446 |
+
# print(np.array(target), np.unique(np.array(target).flatten()))
|
447 |
+
if self.image_transforms is not None:
|
448 |
+
# resize, from PIL to tensor, and mean and std normalization
|
449 |
+
img = self.image_transforms(img)
|
450 |
+
# target = self.target_transforms(target)
|
451 |
+
target = torch.as_tensor(np.array(target, copy=True))
|
452 |
+
# target = target.permute((2, 0, 1))
|
453 |
+
sentence = self.sentence_raw[index]
|
454 |
+
|
455 |
+
return img, img_path, target, sentence
|
data/pascal459.py
ADDED
@@ -0,0 +1,998 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Pascal-459 Dataset."""
|
17 |
+
|
18 |
+
import os
|
19 |
+
import numpy as np
|
20 |
+
from PIL import Image
|
21 |
+
# pylint: disable=g-importing-member
|
22 |
+
from torch.utils.data import Dataset
|
23 |
+
|
24 |
+
|
25 |
+
PASCAL_459_CLASSES = [
|
26 |
+
'accordion',
|
27 |
+
'aeroplane',
|
28 |
+
'air conditioner',
|
29 |
+
'antenna',
|
30 |
+
'artillery',
|
31 |
+
'ashtray',
|
32 |
+
'atrium',
|
33 |
+
'baby carriage',
|
34 |
+
'bag',
|
35 |
+
'ball',
|
36 |
+
'balloon',
|
37 |
+
'bamboo weaving',
|
38 |
+
'barrel',
|
39 |
+
'baseball bat',
|
40 |
+
'basket',
|
41 |
+
'basketball backboard',
|
42 |
+
'bathtub',
|
43 |
+
'bed',
|
44 |
+
'bedclothes',
|
45 |
+
'beer',
|
46 |
+
'bell',
|
47 |
+
'bench',
|
48 |
+
'bicycle',
|
49 |
+
'binoculars',
|
50 |
+
'bird',
|
51 |
+
'bird cage',
|
52 |
+
'bird feeder',
|
53 |
+
'bird nest',
|
54 |
+
'blackboard',
|
55 |
+
'board',
|
56 |
+
'boat',
|
57 |
+
'bone',
|
58 |
+
'book',
|
59 |
+
'bottle',
|
60 |
+
'bottle opener',
|
61 |
+
'bowl',
|
62 |
+
'box',
|
63 |
+
'bracelet',
|
64 |
+
'brick',
|
65 |
+
'bridge',
|
66 |
+
'broom',
|
67 |
+
'brush',
|
68 |
+
'bucket',
|
69 |
+
'building',
|
70 |
+
'bus',
|
71 |
+
'cabinet',
|
72 |
+
'cabinet door',
|
73 |
+
'cage',
|
74 |
+
'cake',
|
75 |
+
'calculator',
|
76 |
+
'calendar',
|
77 |
+
'camel',
|
78 |
+
'camera',
|
79 |
+
'camera lens',
|
80 |
+
'can',
|
81 |
+
'candle',
|
82 |
+
'candle holder',
|
83 |
+
'cap',
|
84 |
+
'car',
|
85 |
+
'card',
|
86 |
+
'cart',
|
87 |
+
'case',
|
88 |
+
'casette recorder',
|
89 |
+
'cash register',
|
90 |
+
'cat',
|
91 |
+
'cd',
|
92 |
+
'cd player',
|
93 |
+
'ceiling',
|
94 |
+
'cell phone',
|
95 |
+
'cello',
|
96 |
+
'chain',
|
97 |
+
'chair',
|
98 |
+
'chessboard',
|
99 |
+
'chicken',
|
100 |
+
'chopstick',
|
101 |
+
'clip',
|
102 |
+
'clippers',
|
103 |
+
'clock',
|
104 |
+
'closet',
|
105 |
+
'cloth',
|
106 |
+
'clothes tree',
|
107 |
+
'coffee',
|
108 |
+
'coffee machine',
|
109 |
+
'comb',
|
110 |
+
'computer',
|
111 |
+
'concrete',
|
112 |
+
'cone',
|
113 |
+
'container',
|
114 |
+
'control booth',
|
115 |
+
'controller',
|
116 |
+
'cooker',
|
117 |
+
'copying machine',
|
118 |
+
'coral',
|
119 |
+
'cork',
|
120 |
+
'corkscrew',
|
121 |
+
'counter',
|
122 |
+
'court',
|
123 |
+
'cow',
|
124 |
+
'crabstick',
|
125 |
+
'crane',
|
126 |
+
'crate',
|
127 |
+
'cross',
|
128 |
+
'crutch',
|
129 |
+
'cup',
|
130 |
+
'curtain',
|
131 |
+
'cushion',
|
132 |
+
'cutting board',
|
133 |
+
'dais',
|
134 |
+
'disc',
|
135 |
+
'disc case',
|
136 |
+
'dishwasher',
|
137 |
+
'dock',
|
138 |
+
'dog',
|
139 |
+
'dolphin',
|
140 |
+
'door',
|
141 |
+
'drainer',
|
142 |
+
'dray',
|
143 |
+
'drink dispenser',
|
144 |
+
'drinking machine',
|
145 |
+
'drop',
|
146 |
+
'drug',
|
147 |
+
'drum',
|
148 |
+
'drum kit',
|
149 |
+
'duck',
|
150 |
+
'dumbbell',
|
151 |
+
'earphone',
|
152 |
+
'earrings',
|
153 |
+
'egg',
|
154 |
+
'electric fan',
|
155 |
+
'electric iron',
|
156 |
+
'electric pot',
|
157 |
+
'electric saw',
|
158 |
+
'electronic keyboard',
|
159 |
+
'engine',
|
160 |
+
'envelope',
|
161 |
+
'equipment',
|
162 |
+
'escalator',
|
163 |
+
'exhibition booth',
|
164 |
+
'extinguisher',
|
165 |
+
'eyeglass',
|
166 |
+
'fan',
|
167 |
+
'faucet',
|
168 |
+
'fax machine',
|
169 |
+
'fence',
|
170 |
+
'ferris wheel',
|
171 |
+
'fire extinguisher',
|
172 |
+
'fire hydrant',
|
173 |
+
'fire place',
|
174 |
+
'fish',
|
175 |
+
'fish tank',
|
176 |
+
'fishbowl',
|
177 |
+
'fishing net',
|
178 |
+
'fishing pole',
|
179 |
+
'flag',
|
180 |
+
'flagstaff',
|
181 |
+
'flame',
|
182 |
+
'flashlight',
|
183 |
+
'floor',
|
184 |
+
'flower',
|
185 |
+
'fly',
|
186 |
+
'foam',
|
187 |
+
'food',
|
188 |
+
'footbridge',
|
189 |
+
'forceps',
|
190 |
+
'fork',
|
191 |
+
'forklift',
|
192 |
+
'fountain',
|
193 |
+
'fox',
|
194 |
+
'frame',
|
195 |
+
'fridge',
|
196 |
+
'frog',
|
197 |
+
'fruit',
|
198 |
+
'funnel',
|
199 |
+
'furnace',
|
200 |
+
'game controller',
|
201 |
+
'game machine',
|
202 |
+
'gas cylinder',
|
203 |
+
'gas hood',
|
204 |
+
'gas stove',
|
205 |
+
'gift box',
|
206 |
+
'glass',
|
207 |
+
'glass marble',
|
208 |
+
'globe',
|
209 |
+
'glove',
|
210 |
+
'goal',
|
211 |
+
'grandstand',
|
212 |
+
'grass',
|
213 |
+
'gravestone',
|
214 |
+
'ground',
|
215 |
+
'guardrail',
|
216 |
+
'guitar',
|
217 |
+
'gun',
|
218 |
+
'hammer',
|
219 |
+
'hand cart',
|
220 |
+
'handle',
|
221 |
+
'handrail',
|
222 |
+
'hanger',
|
223 |
+
'hard disk drive',
|
224 |
+
'hat',
|
225 |
+
'hay',
|
226 |
+
'headphone',
|
227 |
+
'heater',
|
228 |
+
'helicopter',
|
229 |
+
'helmet',
|
230 |
+
'holder',
|
231 |
+
'hook',
|
232 |
+
'horse',
|
233 |
+
'horse-drawn carriage',
|
234 |
+
'hot-air balloon',
|
235 |
+
'hydrovalve',
|
236 |
+
'ice',
|
237 |
+
'inflator pump',
|
238 |
+
'ipod',
|
239 |
+
'iron',
|
240 |
+
'ironing board',
|
241 |
+
'jar',
|
242 |
+
'kart',
|
243 |
+
'kettle',
|
244 |
+
'key',
|
245 |
+
'keyboard',
|
246 |
+
'kitchen range',
|
247 |
+
'kite',
|
248 |
+
'knife',
|
249 |
+
'knife block',
|
250 |
+
'ladder',
|
251 |
+
'ladder truck',
|
252 |
+
'ladle',
|
253 |
+
'laptop',
|
254 |
+
'leaves',
|
255 |
+
'lid',
|
256 |
+
'life buoy',
|
257 |
+
'light',
|
258 |
+
'light bulb',
|
259 |
+
'lighter',
|
260 |
+
'line',
|
261 |
+
'lion',
|
262 |
+
'lobster',
|
263 |
+
'lock',
|
264 |
+
'machine',
|
265 |
+
'mailbox',
|
266 |
+
'mannequin',
|
267 |
+
'map',
|
268 |
+
'mask',
|
269 |
+
'mat',
|
270 |
+
'match book',
|
271 |
+
'mattress',
|
272 |
+
'menu',
|
273 |
+
'metal',
|
274 |
+
'meter box',
|
275 |
+
'microphone',
|
276 |
+
'microwave',
|
277 |
+
'mirror',
|
278 |
+
'missile',
|
279 |
+
'model',
|
280 |
+
'money',
|
281 |
+
'monkey',
|
282 |
+
'mop',
|
283 |
+
'motorbike',
|
284 |
+
'mountain',
|
285 |
+
'mouse',
|
286 |
+
'mouse pad',
|
287 |
+
'musical instrument',
|
288 |
+
'napkin',
|
289 |
+
'net',
|
290 |
+
'newspaper',
|
291 |
+
'oar',
|
292 |
+
'ornament',
|
293 |
+
'outlet',
|
294 |
+
'oven',
|
295 |
+
'oxygen bottle',
|
296 |
+
'pack',
|
297 |
+
'pan',
|
298 |
+
'paper',
|
299 |
+
'paper box',
|
300 |
+
'paper cutter',
|
301 |
+
'parachute',
|
302 |
+
'parasol',
|
303 |
+
'parterre',
|
304 |
+
'patio',
|
305 |
+
'pelage',
|
306 |
+
'pen',
|
307 |
+
'pen container',
|
308 |
+
'pencil',
|
309 |
+
'person',
|
310 |
+
'photo',
|
311 |
+
'piano',
|
312 |
+
'picture',
|
313 |
+
'pig',
|
314 |
+
'pillar',
|
315 |
+
'pillow',
|
316 |
+
'pipe',
|
317 |
+
'pitcher',
|
318 |
+
'plant',
|
319 |
+
'plastic',
|
320 |
+
'plate',
|
321 |
+
'platform',
|
322 |
+
'player',
|
323 |
+
'playground',
|
324 |
+
'pliers',
|
325 |
+
'plume',
|
326 |
+
'poker',
|
327 |
+
'poker chip',
|
328 |
+
'pole',
|
329 |
+
'pool table',
|
330 |
+
'postcard',
|
331 |
+
'poster',
|
332 |
+
'pot',
|
333 |
+
'pottedplant',
|
334 |
+
'printer',
|
335 |
+
'projector',
|
336 |
+
'pumpkin',
|
337 |
+
'rabbit',
|
338 |
+
'racket',
|
339 |
+
'radiator',
|
340 |
+
'radio',
|
341 |
+
'rail',
|
342 |
+
'rake',
|
343 |
+
'ramp',
|
344 |
+
'range hood',
|
345 |
+
'receiver',
|
346 |
+
'recorder',
|
347 |
+
'recreational machines',
|
348 |
+
'remote control',
|
349 |
+
'road',
|
350 |
+
'robot',
|
351 |
+
'rock',
|
352 |
+
'rocket',
|
353 |
+
'rocking horse',
|
354 |
+
'rope',
|
355 |
+
'rug',
|
356 |
+
'ruler',
|
357 |
+
'runway',
|
358 |
+
'saddle',
|
359 |
+
'sand',
|
360 |
+
'saw',
|
361 |
+
'scale',
|
362 |
+
'scanner',
|
363 |
+
'scissors',
|
364 |
+
'scoop',
|
365 |
+
'screen',
|
366 |
+
'screwdriver',
|
367 |
+
'sculpture',
|
368 |
+
'scythe',
|
369 |
+
'sewer',
|
370 |
+
'sewing machine',
|
371 |
+
'shed',
|
372 |
+
'sheep',
|
373 |
+
'shell',
|
374 |
+
'shelves',
|
375 |
+
'shoe',
|
376 |
+
'shopping cart',
|
377 |
+
'shovel',
|
378 |
+
'sidecar',
|
379 |
+
'sidewalk',
|
380 |
+
'sign',
|
381 |
+
'signal light',
|
382 |
+
'sink',
|
383 |
+
'skateboard',
|
384 |
+
'ski',
|
385 |
+
'sky',
|
386 |
+
'sled',
|
387 |
+
'slippers',
|
388 |
+
'smoke',
|
389 |
+
'snail',
|
390 |
+
'snake',
|
391 |
+
'snow',
|
392 |
+
'snowmobiles',
|
393 |
+
'sofa',
|
394 |
+
'spanner',
|
395 |
+
'spatula',
|
396 |
+
'speaker',
|
397 |
+
'speed bump',
|
398 |
+
'spice container',
|
399 |
+
'spoon',
|
400 |
+
'sprayer',
|
401 |
+
'squirrel',
|
402 |
+
'stage',
|
403 |
+
'stair',
|
404 |
+
'stapler',
|
405 |
+
'stick',
|
406 |
+
'sticky note',
|
407 |
+
'stone',
|
408 |
+
'stool',
|
409 |
+
'stove',
|
410 |
+
'straw',
|
411 |
+
'stretcher',
|
412 |
+
'sun',
|
413 |
+
'sunglass',
|
414 |
+
'sunshade',
|
415 |
+
'surveillance camera',
|
416 |
+
'swan',
|
417 |
+
'sweeper',
|
418 |
+
'swim ring',
|
419 |
+
'swimming pool',
|
420 |
+
'swing',
|
421 |
+
'switch',
|
422 |
+
'table',
|
423 |
+
'tableware',
|
424 |
+
'tank',
|
425 |
+
'tap',
|
426 |
+
'tape',
|
427 |
+
'tarp',
|
428 |
+
'telephone',
|
429 |
+
'telephone booth',
|
430 |
+
'tent',
|
431 |
+
'tire',
|
432 |
+
'toaster',
|
433 |
+
'toilet',
|
434 |
+
'tong',
|
435 |
+
'tool',
|
436 |
+
'toothbrush',
|
437 |
+
'towel',
|
438 |
+
'toy',
|
439 |
+
'toy car',
|
440 |
+
'track',
|
441 |
+
'train',
|
442 |
+
'trampoline',
|
443 |
+
'trash bin',
|
444 |
+
'tray',
|
445 |
+
'tree',
|
446 |
+
'tricycle',
|
447 |
+
'tripod',
|
448 |
+
'trophy',
|
449 |
+
'truck',
|
450 |
+
'tube',
|
451 |
+
'turtle',
|
452 |
+
'tvmonitor',
|
453 |
+
'tweezers',
|
454 |
+
'typewriter',
|
455 |
+
'umbrella',
|
456 |
+
'unknown',
|
457 |
+
'vacuum cleaner',
|
458 |
+
'vending machine',
|
459 |
+
'video camera',
|
460 |
+
'video game console',
|
461 |
+
'video player',
|
462 |
+
'video tape',
|
463 |
+
'violin',
|
464 |
+
'wakeboard',
|
465 |
+
'wall',
|
466 |
+
'wallet',
|
467 |
+
'wardrobe',
|
468 |
+
'washing machine',
|
469 |
+
'watch',
|
470 |
+
'water',
|
471 |
+
'water dispenser',
|
472 |
+
'water pipe',
|
473 |
+
'water skate board',
|
474 |
+
'watermelon',
|
475 |
+
'whale',
|
476 |
+
'wharf',
|
477 |
+
'wheel',
|
478 |
+
'wheelchair',
|
479 |
+
'window',
|
480 |
+
'window blinds',
|
481 |
+
'wineglass',
|
482 |
+
'wire',
|
483 |
+
'wood',
|
484 |
+
'wool',
|
485 |
+
]
|
486 |
+
|
487 |
+
PASCAL_459_CLASSE_ID = list(range(459))
|
488 |
+
|
489 |
+
|
490 |
+
PASCAL_459_STUFF_CLASS = [
|
491 |
+
'atrium',
|
492 |
+
'ceiling',
|
493 |
+
'concrete',
|
494 |
+
'coral',
|
495 |
+
'court',
|
496 |
+
'dock',
|
497 |
+
'floor',
|
498 |
+
'foam',
|
499 |
+
'grass',
|
500 |
+
'ground',
|
501 |
+
'ice',
|
502 |
+
'leaves',
|
503 |
+
'mountain',
|
504 |
+
'parterre',
|
505 |
+
'patio',
|
506 |
+
'road',
|
507 |
+
'rock',
|
508 |
+
'rug',
|
509 |
+
'sand',
|
510 |
+
'sky',
|
511 |
+
'snow',
|
512 |
+
'stone',
|
513 |
+
'sun',
|
514 |
+
'wall',
|
515 |
+
'water',
|
516 |
+
'wood',
|
517 |
+
]
|
518 |
+
|
519 |
+
PASCAL_459_THING_CLASS = [
|
520 |
+
'accordion',
|
521 |
+
'aeroplane',
|
522 |
+
'air conditioner',
|
523 |
+
'antenna',
|
524 |
+
'artillery',
|
525 |
+
'ashtray',
|
526 |
+
'baby carriage',
|
527 |
+
'bag',
|
528 |
+
'ball',
|
529 |
+
'balloon',
|
530 |
+
'bamboo weaving',
|
531 |
+
'barrel',
|
532 |
+
'baseball bat',
|
533 |
+
'basket',
|
534 |
+
'basketball backboard',
|
535 |
+
'bathtub',
|
536 |
+
'bed',
|
537 |
+
'bedclothes',
|
538 |
+
'beer',
|
539 |
+
'bell',
|
540 |
+
'bench',
|
541 |
+
'bicycle',
|
542 |
+
'binoculars',
|
543 |
+
'bird',
|
544 |
+
'bird cage',
|
545 |
+
'bird feeder',
|
546 |
+
'bird nest',
|
547 |
+
'blackboard',
|
548 |
+
'board',
|
549 |
+
'boat',
|
550 |
+
'bone',
|
551 |
+
'book',
|
552 |
+
'bottle',
|
553 |
+
'bottle opener',
|
554 |
+
'bowl',
|
555 |
+
'box',
|
556 |
+
'bracelet',
|
557 |
+
'brick',
|
558 |
+
'bridge',
|
559 |
+
'broom',
|
560 |
+
'brush',
|
561 |
+
'bucket',
|
562 |
+
'building',
|
563 |
+
'bus',
|
564 |
+
'cabinet',
|
565 |
+
'cabinet door',
|
566 |
+
'cage',
|
567 |
+
'cake',
|
568 |
+
'calculator',
|
569 |
+
'calendar',
|
570 |
+
'camel',
|
571 |
+
'camera',
|
572 |
+
'camera lens',
|
573 |
+
'can',
|
574 |
+
'candle',
|
575 |
+
'candle holder',
|
576 |
+
'cap',
|
577 |
+
'car',
|
578 |
+
'card',
|
579 |
+
'cart',
|
580 |
+
'case',
|
581 |
+
'casette recorder',
|
582 |
+
'cash register',
|
583 |
+
'cat',
|
584 |
+
'cd',
|
585 |
+
'cd player',
|
586 |
+
'cell phone',
|
587 |
+
'cello',
|
588 |
+
'chain',
|
589 |
+
'chair',
|
590 |
+
'chessboard',
|
591 |
+
'chicken',
|
592 |
+
'chopstick',
|
593 |
+
'clip',
|
594 |
+
'clippers',
|
595 |
+
'clock',
|
596 |
+
'closet',
|
597 |
+
'cloth',
|
598 |
+
'clothes tree',
|
599 |
+
'coffee',
|
600 |
+
'coffee machine',
|
601 |
+
'comb',
|
602 |
+
'computer',
|
603 |
+
'cone',
|
604 |
+
'container',
|
605 |
+
'control booth',
|
606 |
+
'controller',
|
607 |
+
'cooker',
|
608 |
+
'copying machine',
|
609 |
+
'cork',
|
610 |
+
'corkscrew',
|
611 |
+
'counter',
|
612 |
+
'cow',
|
613 |
+
'crabstick',
|
614 |
+
'crane',
|
615 |
+
'crate',
|
616 |
+
'cross',
|
617 |
+
'crutch',
|
618 |
+
'cup',
|
619 |
+
'curtain',
|
620 |
+
'cushion',
|
621 |
+
'cutting board',
|
622 |
+
'dais',
|
623 |
+
'disc',
|
624 |
+
'disc case',
|
625 |
+
'dishwasher',
|
626 |
+
'dog',
|
627 |
+
'dolphin',
|
628 |
+
'door',
|
629 |
+
'drainer',
|
630 |
+
'dray',
|
631 |
+
'drink dispenser',
|
632 |
+
'drinking machine',
|
633 |
+
'drop',
|
634 |
+
'drug',
|
635 |
+
'drum',
|
636 |
+
'drum kit',
|
637 |
+
'duck',
|
638 |
+
'dumbbell',
|
639 |
+
'earphone',
|
640 |
+
'earrings',
|
641 |
+
'egg',
|
642 |
+
'electric fan',
|
643 |
+
'electric iron',
|
644 |
+
'electric pot',
|
645 |
+
'electric saw',
|
646 |
+
'electronic keyboard',
|
647 |
+
'engine',
|
648 |
+
'envelope',
|
649 |
+
'equipment',
|
650 |
+
'escalator',
|
651 |
+
'exhibition booth',
|
652 |
+
'extinguisher',
|
653 |
+
'eyeglass',
|
654 |
+
'fan',
|
655 |
+
'faucet',
|
656 |
+
'fax machine',
|
657 |
+
'fence',
|
658 |
+
'ferris wheel',
|
659 |
+
'fire extinguisher',
|
660 |
+
'fire hydrant',
|
661 |
+
'fire place',
|
662 |
+
'fish',
|
663 |
+
'fish tank',
|
664 |
+
'fishbowl',
|
665 |
+
'fishing net',
|
666 |
+
'fishing pole',
|
667 |
+
'flag',
|
668 |
+
'flagstaff',
|
669 |
+
'flame',
|
670 |
+
'flashlight',
|
671 |
+
'flower',
|
672 |
+
'fly',
|
673 |
+
'food',
|
674 |
+
'footbridge',
|
675 |
+
'forceps',
|
676 |
+
'fork',
|
677 |
+
'forklift',
|
678 |
+
'fountain',
|
679 |
+
'fox',
|
680 |
+
'frame',
|
681 |
+
'fridge',
|
682 |
+
'frog',
|
683 |
+
'fruit',
|
684 |
+
'funnel',
|
685 |
+
'furnace',
|
686 |
+
'game controller',
|
687 |
+
'game machine',
|
688 |
+
'gas cylinder',
|
689 |
+
'gas hood',
|
690 |
+
'gas stove',
|
691 |
+
'gift box',
|
692 |
+
'glass',
|
693 |
+
'glass marble',
|
694 |
+
'globe',
|
695 |
+
'glove',
|
696 |
+
'goal',
|
697 |
+
'grandstand',
|
698 |
+
'gravestone',
|
699 |
+
'guardrail',
|
700 |
+
'guitar',
|
701 |
+
'gun',
|
702 |
+
'hammer',
|
703 |
+
'hand cart',
|
704 |
+
'handle',
|
705 |
+
'handrail',
|
706 |
+
'hanger',
|
707 |
+
'hard disk drive',
|
708 |
+
'hat',
|
709 |
+
'hay',
|
710 |
+
'headphone',
|
711 |
+
'heater',
|
712 |
+
'helicopter',
|
713 |
+
'helmet',
|
714 |
+
'holder',
|
715 |
+
'hook',
|
716 |
+
'horse',
|
717 |
+
'horse-drawn carriage',
|
718 |
+
'hot-air balloon',
|
719 |
+
'hydrovalve',
|
720 |
+
'inflator pump',
|
721 |
+
'ipod',
|
722 |
+
'iron',
|
723 |
+
'ironing board',
|
724 |
+
'jar',
|
725 |
+
'kart',
|
726 |
+
'kettle',
|
727 |
+
'key',
|
728 |
+
'keyboard',
|
729 |
+
'kitchen range',
|
730 |
+
'kite',
|
731 |
+
'knife',
|
732 |
+
'knife block',
|
733 |
+
'ladder',
|
734 |
+
'ladder truck',
|
735 |
+
'ladle',
|
736 |
+
'laptop',
|
737 |
+
'lid',
|
738 |
+
'life buoy',
|
739 |
+
'light',
|
740 |
+
'light bulb',
|
741 |
+
'lighter',
|
742 |
+
'line',
|
743 |
+
'lion',
|
744 |
+
'lobster',
|
745 |
+
'lock',
|
746 |
+
'machine',
|
747 |
+
'mailbox',
|
748 |
+
'mannequin',
|
749 |
+
'map',
|
750 |
+
'mask',
|
751 |
+
'mat',
|
752 |
+
'match book',
|
753 |
+
'mattress',
|
754 |
+
'menu',
|
755 |
+
'metal',
|
756 |
+
'meter box',
|
757 |
+
'microphone',
|
758 |
+
'microwave',
|
759 |
+
'mirror',
|
760 |
+
'missile',
|
761 |
+
'model',
|
762 |
+
'money',
|
763 |
+
'monkey',
|
764 |
+
'mop',
|
765 |
+
'motorbike',
|
766 |
+
'mouse',
|
767 |
+
'mouse pad',
|
768 |
+
'musical instrument',
|
769 |
+
'napkin',
|
770 |
+
'net',
|
771 |
+
'newspaper',
|
772 |
+
'oar',
|
773 |
+
'ornament',
|
774 |
+
'outlet',
|
775 |
+
'oven',
|
776 |
+
'oxygen bottle',
|
777 |
+
'pack',
|
778 |
+
'pan',
|
779 |
+
'paper',
|
780 |
+
'paper box',
|
781 |
+
'paper cutter',
|
782 |
+
'parachute',
|
783 |
+
'parasol',
|
784 |
+
'pelage',
|
785 |
+
'pen',
|
786 |
+
'pen container',
|
787 |
+
'pencil',
|
788 |
+
'person',
|
789 |
+
'photo',
|
790 |
+
'piano',
|
791 |
+
'picture',
|
792 |
+
'pig',
|
793 |
+
'pillar',
|
794 |
+
'pillow',
|
795 |
+
'pipe',
|
796 |
+
'pitcher',
|
797 |
+
'plant',
|
798 |
+
'plastic',
|
799 |
+
'plate',
|
800 |
+
'platform',
|
801 |
+
'player',
|
802 |
+
'playground',
|
803 |
+
'pliers',
|
804 |
+
'plume',
|
805 |
+
'poker',
|
806 |
+
'poker chip',
|
807 |
+
'pole',
|
808 |
+
'pool table',
|
809 |
+
'postcard',
|
810 |
+
'poster',
|
811 |
+
'pot',
|
812 |
+
'pottedplant',
|
813 |
+
'printer',
|
814 |
+
'projector',
|
815 |
+
'pumpkin',
|
816 |
+
'rabbit',
|
817 |
+
'racket',
|
818 |
+
'radiator',
|
819 |
+
'radio',
|
820 |
+
'rail',
|
821 |
+
'rake',
|
822 |
+
'ramp',
|
823 |
+
'range hood',
|
824 |
+
'receiver',
|
825 |
+
'recorder',
|
826 |
+
'recreational machines',
|
827 |
+
'remote control',
|
828 |
+
'robot',
|
829 |
+
'rocket',
|
830 |
+
'rocking horse',
|
831 |
+
'rope',
|
832 |
+
'ruler',
|
833 |
+
'runway',
|
834 |
+
'saddle',
|
835 |
+
'saw',
|
836 |
+
'scale',
|
837 |
+
'scanner',
|
838 |
+
'scissors',
|
839 |
+
'scoop',
|
840 |
+
'screen',
|
841 |
+
'screwdriver',
|
842 |
+
'sculpture',
|
843 |
+
'scythe',
|
844 |
+
'sewer',
|
845 |
+
'sewing machine',
|
846 |
+
'shed',
|
847 |
+
'sheep',
|
848 |
+
'shell',
|
849 |
+
'shelves',
|
850 |
+
'shoe',
|
851 |
+
'shopping cart',
|
852 |
+
'shovel',
|
853 |
+
'sidecar',
|
854 |
+
'sidewalk',
|
855 |
+
'sign',
|
856 |
+
'signal light',
|
857 |
+
'sink',
|
858 |
+
'skateboard',
|
859 |
+
'ski',
|
860 |
+
'sled',
|
861 |
+
'slippers',
|
862 |
+
'smoke',
|
863 |
+
'snail',
|
864 |
+
'snake',
|
865 |
+
'snowmobiles',
|
866 |
+
'sofa',
|
867 |
+
'spanner',
|
868 |
+
'spatula',
|
869 |
+
'speaker',
|
870 |
+
'speed bump',
|
871 |
+
'spice container',
|
872 |
+
'spoon',
|
873 |
+
'sprayer',
|
874 |
+
'squirrel',
|
875 |
+
'stage',
|
876 |
+
'stair',
|
877 |
+
'stapler',
|
878 |
+
'stick',
|
879 |
+
'sticky note',
|
880 |
+
'stool',
|
881 |
+
'stove',
|
882 |
+
'straw',
|
883 |
+
'stretcher',
|
884 |
+
'sunglass',
|
885 |
+
'sunshade',
|
886 |
+
'surveillance camera',
|
887 |
+
'swan',
|
888 |
+
'sweeper',
|
889 |
+
'swim ring',
|
890 |
+
'swimming pool',
|
891 |
+
'swing',
|
892 |
+
'switch',
|
893 |
+
'table',
|
894 |
+
'tableware',
|
895 |
+
'tank',
|
896 |
+
'tap',
|
897 |
+
'tape',
|
898 |
+
'tarp',
|
899 |
+
'telephone',
|
900 |
+
'telephone booth',
|
901 |
+
'tent',
|
902 |
+
'tire',
|
903 |
+
'toaster',
|
904 |
+
'toilet',
|
905 |
+
'tong',
|
906 |
+
'tool',
|
907 |
+
'toothbrush',
|
908 |
+
'towel',
|
909 |
+
'toy',
|
910 |
+
'toy car',
|
911 |
+
'track',
|
912 |
+
'train',
|
913 |
+
'trampoline',
|
914 |
+
'trash bin',
|
915 |
+
'tray',
|
916 |
+
'tree',
|
917 |
+
'tricycle',
|
918 |
+
'tripod',
|
919 |
+
'trophy',
|
920 |
+
'truck',
|
921 |
+
'tube',
|
922 |
+
'turtle',
|
923 |
+
'tvmonitor',
|
924 |
+
'tweezers',
|
925 |
+
'typewriter',
|
926 |
+
'umbrella',
|
927 |
+
'unknown',
|
928 |
+
'vacuum cleaner',
|
929 |
+
'vending machine',
|
930 |
+
'video camera',
|
931 |
+
'video game console',
|
932 |
+
'video player',
|
933 |
+
'video tape',
|
934 |
+
'violin',
|
935 |
+
'wakeboard',
|
936 |
+
'wallet',
|
937 |
+
'wardrobe',
|
938 |
+
'washing machine',
|
939 |
+
'watch',
|
940 |
+
'water dispenser',
|
941 |
+
'water pipe',
|
942 |
+
'water skate board',
|
943 |
+
'watermelon',
|
944 |
+
'whale',
|
945 |
+
'wharf',
|
946 |
+
'wheel',
|
947 |
+
'wheelchair',
|
948 |
+
'window',
|
949 |
+
'window blinds',
|
950 |
+
'wineglass',
|
951 |
+
'wire',
|
952 |
+
'wool',
|
953 |
+
]
|
954 |
+
|
955 |
+
PASCAL_459_STUFF_CLASS_ID = [
|
956 |
+
6, 67, 85, 92, 96, 111, 157, 160, 186, 188, 210, 228, 258, 277, 278, 323,
|
957 |
+
325, 329, 333, 359, 365, 381, 386, 439, 444, 457,
|
958 |
+
]
|
959 |
+
|
960 |
+
PASCAL_459_THING_CLASS_ID = [
|
961 |
+
i for i in range(459) if i not in PASCAL_459_STUFF_CLASS_ID
|
962 |
+
]
|
963 |
+
|
964 |
+
|
965 |
+
class Pascal459Dataset(Dataset):
|
966 |
+
"""PASCAL 459 dataset."""
|
967 |
+
|
968 |
+
def __init__(self, root, split='validation', transform=None):
|
969 |
+
super(Pascal459Dataset, self).__init__()
|
970 |
+
self.root = root
|
971 |
+
self.split = split
|
972 |
+
self.transforms = transform
|
973 |
+
self.image_dir = os.path.join(root, 'images', split)
|
974 |
+
self.mask_dir = os.path.join(root, 'annotations_ctx459', split)
|
975 |
+
self.images = os.listdir(self.image_dir)
|
976 |
+
|
977 |
+
def __getitem__(self, index):
|
978 |
+
image_path = os.path.join(self.image_dir, self.images[index])
|
979 |
+
image = Image.open(image_path).convert('RGB')
|
980 |
+
target = (
|
981 |
+
np.asarray(
|
982 |
+
Image.open(
|
983 |
+
os.path.join(
|
984 |
+
self.mask_dir, self.images[index].replace('jpg', 'tif')
|
985 |
+
)
|
986 |
+
),
|
987 |
+
dtype=np.int32,
|
988 |
+
)
|
989 |
+
+ 1
|
990 |
+
)
|
991 |
+
|
992 |
+
if self.transforms:
|
993 |
+
image = self.transforms(image)
|
994 |
+
|
995 |
+
return image, image_path, target, index
|
996 |
+
|
997 |
+
def __len__(self):
|
998 |
+
return len(self.images)
|
data/preprocess.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Preprocess for referring datasets.
|
17 |
+
|
18 |
+
Adapted from
|
19 |
+
https://github.com/yz93/LAVT-RIS/blob/main/data/dataset_refer_bert.py
|
20 |
+
"""
|
21 |
+
# pylint: disable=all
|
22 |
+
from refer.refer import REFER
|
23 |
+
from torch.utils import data
|
24 |
+
|
25 |
+
|
26 |
+
class ReferDataset(data.Dataset):
|
27 |
+
"""Refer dataset."""
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
root,
|
32 |
+
dataset='refcoco',
|
33 |
+
splitBy='unc',
|
34 |
+
image_transforms=None,
|
35 |
+
target_transforms=None,
|
36 |
+
split='train',
|
37 |
+
eval_mode=False,
|
38 |
+
):
|
39 |
+
|
40 |
+
self.classes = []
|
41 |
+
self.image_transforms = image_transforms
|
42 |
+
self.target_transforms = target_transforms
|
43 |
+
self.split = split
|
44 |
+
self.refer = REFER(root, dataset=dataset, splitBy=splitBy)
|
45 |
+
|
46 |
+
ref_ids = self.refer.getRefIds(split=self.split)
|
47 |
+
img_ids = self.refer.getImgIds(ref_ids)
|
48 |
+
|
49 |
+
all_imgs = self.refer.Imgs
|
50 |
+
self.imgs = list(all_imgs[i] for i in img_ids)
|
51 |
+
self.ref_ids = ref_ids
|
52 |
+
print(len(ref_ids))
|
53 |
+
print(len(self.imgs))
|
54 |
+
# print(self.imgs)
|
55 |
+
self.sentence_raw = []
|
56 |
+
|
57 |
+
self.eval_mode = eval_mode
|
58 |
+
# if we are testing on a dataset, test all sentences of an object;
|
59 |
+
# o/w, we are validating during training, randomly sample one sentence for
|
60 |
+
# efficiency
|
61 |
+
for r in ref_ids:
|
62 |
+
ref = self.refer.Refs[r]
|
63 |
+
ref_sentences = []
|
64 |
+
for el, _ in zip(ref['sentences'], ref['sent_ids']):
|
65 |
+
sentence_raw = el['raw']
|
66 |
+
ref_sentences.append(sentence_raw)
|
67 |
+
|
68 |
+
self.sentence_raw.append(ref_sentences)
|
69 |
+
# print(len(self.sentence_raw))
|
70 |
+
|
71 |
+
def get_classes(self):
|
72 |
+
return self.classes
|
73 |
+
|
74 |
+
def __len__(self):
|
75 |
+
return len(self.imgs)
|
76 |
+
|
77 |
+
def __getitem__(self, index):
|
78 |
+
this_img_id = self.imgs[index]['id']
|
79 |
+
this_ref_ids = self.refer.getRefIds(this_img_id)
|
80 |
+
this_img = self.refer.Imgs[this_img_id]
|
81 |
+
refs = [self.refer.loadRefs(this_ref_id) for this_ref_id in this_ref_ids]
|
82 |
+
|
83 |
+
batch_sentences = {}
|
84 |
+
# batch_targets = {}
|
85 |
+
for ref in refs:
|
86 |
+
# Get sentence
|
87 |
+
sentence_lis = []
|
88 |
+
for el, _ in zip(ref[0]['sentences'], ref[0]['sent_ids']):
|
89 |
+
sentence_raw = el['raw']
|
90 |
+
sentence_lis.append(sentence_raw)
|
91 |
+
batch_sentences.update({ref[0]['ref_id']: sentence_lis})
|
92 |
+
|
93 |
+
return [this_img['file_name']], batch_sentences
|
94 |
+
|
95 |
+
def get_ref(self):
|
96 |
+
name_lis = []
|
97 |
+
for i in range(len(self.ref_ids)):
|
98 |
+
rid = self.ref_ids[i]
|
99 |
+
# print(rid)
|
100 |
+
ref = self.refer.loadRefs(rid)
|
101 |
+
if ref[0]['file_name'] == '':
|
102 |
+
print(1)
|
103 |
+
# print(ref[0]['file_name'])
|
104 |
+
# if ref[0]['file_name'] in name_lis:
|
105 |
+
# print("md")
|
106 |
+
name_lis.append(ref[0]['file_name'])
|
107 |
+
print(ref[0]['file_name'])
|
108 |
+
# print(name_lis)
|
109 |
+
print(len(name_lis))
|
110 |
+
print(len(list(set(name_lis))))
|
data/refcoco.py
ADDED
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""RefCOCO dataset."""
|
17 |
+
|
18 |
+
# Adapted from
|
19 |
+
# https://github.com/yz93/LAVT-RIS/blob/main/data/dataset_refer_bert.py
|
20 |
+
# pylint: disable=all
|
21 |
+
import itertools
|
22 |
+
import json
|
23 |
+
import os
|
24 |
+
import os.path as osp
|
25 |
+
import pickle as pickle
|
26 |
+
import sys
|
27 |
+
import time
|
28 |
+
# pylint: disable=g-importing-member
|
29 |
+
from matplotlib.collections import PatchCollection
|
30 |
+
from matplotlib.patches import Polygon
|
31 |
+
from matplotlib.patches import Rectangle
|
32 |
+
import matplotlib.pyplot as plt
|
33 |
+
import numpy as np
|
34 |
+
from PIL import Image
|
35 |
+
from pycocotools import mask
|
36 |
+
import skimage.io as io
|
37 |
+
import torch
|
38 |
+
import torch.utils.data as data
|
39 |
+
from torchvision import transforms
|
40 |
+
|
41 |
+
|
42 |
+
class REFER:
|
43 |
+
"""RefCOCO dataset."""
|
44 |
+
|
45 |
+
def __init__(self, data_root, dataset='refcoco', splitBy='unc', split='val'):
|
46 |
+
# provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog
|
47 |
+
# also provide dataset name and splitBy information
|
48 |
+
# e.g., dataset = 'refcoco', splitBy = 'unc'
|
49 |
+
print('loading dataset %s into memory...' % dataset)
|
50 |
+
if dataset == 'refcocog':
|
51 |
+
print('Split by {}!'.format(splitBy))
|
52 |
+
self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
|
53 |
+
self.DATA_DIR = osp.join(data_root, dataset)
|
54 |
+
if dataset in ['refcoco', 'refcoco+', 'refcocog']:
|
55 |
+
self.IMAGE_DIR = osp.join(data_root, 'images/mscoco/images/train2014')
|
56 |
+
elif dataset == 'refclef':
|
57 |
+
self.IMAGE_DIR = osp.join(data_root, 'images/saiapr_tc-12')
|
58 |
+
else:
|
59 |
+
print('No refer dataset is called [%s]' % dataset)
|
60 |
+
sys.exit()
|
61 |
+
|
62 |
+
# load refs from data/dataset/refs(dataset).json
|
63 |
+
tic = time.time()
|
64 |
+
ref_file = osp.join(self.DATA_DIR, 'refs(' + splitBy + ').p')
|
65 |
+
self.data = {}
|
66 |
+
self.data['dataset'] = dataset
|
67 |
+
# f = open(ref_file, 'r')
|
68 |
+
self.data['refs'] = pickle.load(open(ref_file, 'rb'))
|
69 |
+
|
70 |
+
# load annotations from data/dataset/instances.json
|
71 |
+
instances_file = osp.join(self.DATA_DIR, 'instances.json')
|
72 |
+
instances = json.load(open(instances_file, 'r'))
|
73 |
+
self.data['images'] = instances['images']
|
74 |
+
self.data['annotations'] = instances['annotations']
|
75 |
+
self.data['categories'] = instances['categories']
|
76 |
+
|
77 |
+
# create index
|
78 |
+
self.createIndex()
|
79 |
+
self.split = split
|
80 |
+
print('DONE (t=%.2fs)' % (time.time() - tic))
|
81 |
+
|
82 |
+
def createIndex(self):
|
83 |
+
# create sets of mapping
|
84 |
+
# 1) Refs: {ref_id: ref}
|
85 |
+
# 2) Anns: {ann_id: ann}
|
86 |
+
# 3) Imgs: {image_id: image}
|
87 |
+
# 4) Cats: {category_id: category_name}
|
88 |
+
# 5) Sents: {sent_id: sent}
|
89 |
+
# 6) imgToRefs: {image_id: refs}
|
90 |
+
# 7) imgToAnns: {image_id: anns}
|
91 |
+
# 8) refToAnn: {ref_id: ann}
|
92 |
+
# 9) annToRef: {ann_id: ref}
|
93 |
+
# 10) catToRefs: {category_id: refs}
|
94 |
+
# 11) sentToRef: {sent_id: ref}
|
95 |
+
# 12) sentToTokens: {sent_id: tokens}
|
96 |
+
print('creating index...')
|
97 |
+
# fetch info from instances
|
98 |
+
Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
|
99 |
+
for ann in self.data['annotations']:
|
100 |
+
Anns[ann['id']] = ann
|
101 |
+
imgToAnns[ann['image_id']] = imgToAnns.get(ann['image_id'], []) + [ann]
|
102 |
+
for img in self.data['images']:
|
103 |
+
Imgs[img['id']] = img
|
104 |
+
for cat in self.data['categories']:
|
105 |
+
Cats[cat['id']] = cat['name']
|
106 |
+
|
107 |
+
# fetch info from refs
|
108 |
+
Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
|
109 |
+
Sents, sentToRef, sentToTokens = {}, {}, {}
|
110 |
+
for ref in self.data['refs']:
|
111 |
+
# ids
|
112 |
+
ref_id = ref['ref_id']
|
113 |
+
ann_id = ref['ann_id']
|
114 |
+
category_id = ref['category_id']
|
115 |
+
image_id = ref['image_id']
|
116 |
+
|
117 |
+
# add mapping related to ref
|
118 |
+
Refs[ref_id] = ref
|
119 |
+
imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
|
120 |
+
catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
|
121 |
+
refToAnn[ref_id] = Anns[ann_id]
|
122 |
+
annToRef[ann_id] = ref
|
123 |
+
|
124 |
+
# add mapping of sent
|
125 |
+
for sent in ref['sentences']:
|
126 |
+
Sents[sent['sent_id']] = sent
|
127 |
+
sentToRef[sent['sent_id']] = ref
|
128 |
+
sentToTokens[sent['sent_id']] = sent['tokens']
|
129 |
+
|
130 |
+
# create class members
|
131 |
+
self.Refs = Refs
|
132 |
+
self.Anns = Anns
|
133 |
+
self.Imgs = Imgs
|
134 |
+
self.Cats = Cats
|
135 |
+
self.Sents = Sents
|
136 |
+
self.imgToRefs = imgToRefs
|
137 |
+
self.imgToAnns = imgToAnns
|
138 |
+
self.refToAnn = refToAnn
|
139 |
+
self.annToRef = annToRef
|
140 |
+
self.catToRefs = catToRefs
|
141 |
+
self.sentToRef = sentToRef
|
142 |
+
self.sentToTokens = sentToTokens
|
143 |
+
print('index created.')
|
144 |
+
|
145 |
+
def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
|
146 |
+
image_ids = image_ids if type(image_ids) == list else [image_ids]
|
147 |
+
cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
|
148 |
+
ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
|
149 |
+
|
150 |
+
if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
|
151 |
+
refs = self.data['refs']
|
152 |
+
else:
|
153 |
+
if not len(image_ids) == 0:
|
154 |
+
refs = [self.imgToRefs[image_id] for image_id in image_ids]
|
155 |
+
ref_ids = []
|
156 |
+
for img_ref in refs:
|
157 |
+
ref_ids.extend([ref['ref_id'] for ref in img_ref])
|
158 |
+
return ref_ids
|
159 |
+
else:
|
160 |
+
refs = self.data['refs']
|
161 |
+
if not len(cat_ids) == 0:
|
162 |
+
refs = [ref for ref in refs if ref['category_id'] in cat_ids]
|
163 |
+
if not len(ref_ids) == 0:
|
164 |
+
refs = [ref for ref in refs if ref['ref_id'] in ref_ids]
|
165 |
+
if not len(split) == 0:
|
166 |
+
if split in ['testA', 'testB', 'testC']:
|
167 |
+
# we also consider testAB, testBC, ...
|
168 |
+
refs = [ref for ref in refs if split[-1] in ref['split']]
|
169 |
+
elif split in ['testAB', 'testBC', 'testAC']:
|
170 |
+
# rarely used I guess...
|
171 |
+
refs = [ref for ref in refs if ref['split'] == split]
|
172 |
+
elif split == 'test':
|
173 |
+
refs = [ref for ref in refs if 'test' in ref['split']]
|
174 |
+
elif split == 'train' or split == 'val':
|
175 |
+
refs = [ref for ref in refs if ref['split'] == split]
|
176 |
+
else:
|
177 |
+
print('No such split [%s]' % split)
|
178 |
+
sys.exit()
|
179 |
+
ref_ids = [ref['ref_id'] for ref in refs]
|
180 |
+
return ref_ids
|
181 |
+
|
182 |
+
def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
|
183 |
+
image_ids = image_ids if type(image_ids) == list else [image_ids]
|
184 |
+
cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
|
185 |
+
ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
|
186 |
+
|
187 |
+
if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
|
188 |
+
ann_ids = [ann['id'] for ann in self.data['annotations']]
|
189 |
+
else:
|
190 |
+
if not len(image_ids) == 0:
|
191 |
+
lists = [
|
192 |
+
self.imgToAnns[image_id]
|
193 |
+
for image_id in image_ids
|
194 |
+
if image_id in self.imgToAnns
|
195 |
+
] # list of [anns]
|
196 |
+
anns = list(itertools.chain.from_iterable(lists))
|
197 |
+
else:
|
198 |
+
anns = self.data['annotations']
|
199 |
+
if not len(cat_ids) == 0:
|
200 |
+
anns = [ann for ann in anns if ann['category_id'] in cat_ids]
|
201 |
+
ann_ids = [ann['id'] for ann in anns]
|
202 |
+
# if not len(ref_ids) == 0:
|
203 |
+
# ids = set(ann_ids).intersection(
|
204 |
+
# set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids])
|
205 |
+
# )
|
206 |
+
return ann_ids
|
207 |
+
|
208 |
+
def getImgIds(self, ref_ids=[]):
|
209 |
+
ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
|
210 |
+
|
211 |
+
if not len(ref_ids) == 0:
|
212 |
+
image_ids = list(
|
213 |
+
set([self.Refs[ref_id]['image_id'] for ref_id in ref_ids])
|
214 |
+
)
|
215 |
+
else:
|
216 |
+
image_ids = self.Imgs.keys()
|
217 |
+
return image_ids
|
218 |
+
|
219 |
+
def getCatIds(self):
|
220 |
+
return self.Cats.keys()
|
221 |
+
|
222 |
+
def loadRefs(self, ref_ids=[]):
|
223 |
+
if type(ref_ids) == list:
|
224 |
+
return [self.Refs[ref_id] for ref_id in ref_ids]
|
225 |
+
elif type(ref_ids) == int:
|
226 |
+
return [self.Refs[ref_ids]]
|
227 |
+
|
228 |
+
def loadAnns(self, ann_ids=[]):
|
229 |
+
if type(ann_ids) == list:
|
230 |
+
return [self.Anns[ann_id] for ann_id in ann_ids]
|
231 |
+
elif type(ann_ids) == int or type(ann_ids) == unicode:
|
232 |
+
return [self.Anns[ann_ids]]
|
233 |
+
|
234 |
+
def loadImgs(self, image_ids=[]):
|
235 |
+
if type(image_ids) == list:
|
236 |
+
return [self.Imgs[image_id] for image_id in image_ids]
|
237 |
+
elif type(image_ids) == int:
|
238 |
+
return [self.Imgs[image_ids]]
|
239 |
+
|
240 |
+
def loadCats(self, cat_ids=[]):
|
241 |
+
if type(cat_ids) == list:
|
242 |
+
return [self.Cats[cat_id] for cat_id in cat_ids]
|
243 |
+
elif type(cat_ids) == int:
|
244 |
+
return [self.Cats[cat_ids]]
|
245 |
+
|
246 |
+
def getRefBox(self, ref_id):
|
247 |
+
# ref = self.Refs[ref_id]
|
248 |
+
ann = self.refToAnn[ref_id]
|
249 |
+
return ann['bbox'] # [x, y, w, h]
|
250 |
+
|
251 |
+
def showRef(self, ref, seg_box='seg'):
|
252 |
+
ax = plt.gca()
|
253 |
+
# show image
|
254 |
+
image = self.Imgs[ref['image_id']]
|
255 |
+
I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
|
256 |
+
ax.imshow(I)
|
257 |
+
# show refer expression
|
258 |
+
for sid, sent in enumerate(ref['sentences']):
|
259 |
+
print('%s. %s' % (sid + 1, sent['sent']))
|
260 |
+
# show segmentations
|
261 |
+
if seg_box == 'seg':
|
262 |
+
ann_id = ref['ann_id']
|
263 |
+
ann = self.Anns[ann_id]
|
264 |
+
polygons = []
|
265 |
+
color = []
|
266 |
+
c = 'none'
|
267 |
+
if type(ann['segmentation'][0]) == list:
|
268 |
+
# polygon used for refcoco*
|
269 |
+
for seg in ann['segmentation']:
|
270 |
+
poly = np.array(seg).reshape((len(seg) / 2, 2))
|
271 |
+
polygons.append(Polygon(poly, True, alpha=0.4))
|
272 |
+
color.append(c)
|
273 |
+
p = PatchCollection(
|
274 |
+
polygons,
|
275 |
+
facecolors=color,
|
276 |
+
edgecolors=(1, 1, 0, 0),
|
277 |
+
linewidths=3,
|
278 |
+
alpha=1,
|
279 |
+
)
|
280 |
+
ax.add_collection(p) # thick yellow polygon
|
281 |
+
p = PatchCollection(
|
282 |
+
polygons,
|
283 |
+
facecolors=color,
|
284 |
+
edgecolors=(1, 0, 0, 0),
|
285 |
+
linewidths=1,
|
286 |
+
alpha=1,
|
287 |
+
)
|
288 |
+
ax.add_collection(p) # thin red polygon
|
289 |
+
else:
|
290 |
+
# mask used for refclef
|
291 |
+
rle = ann['segmentation']
|
292 |
+
m = mask.decode(rle)
|
293 |
+
img = np.ones((m.shape[0], m.shape[1], 3))
|
294 |
+
color_mask = np.array([2.0, 166.0, 101.0]) / 255
|
295 |
+
for i in range(3):
|
296 |
+
img[:, :, i] = color_mask[i]
|
297 |
+
ax.imshow(np.dstack((img, m * 0.5)))
|
298 |
+
# show bounding-box
|
299 |
+
elif seg_box == 'box':
|
300 |
+
# ann_id = ref['ann_id']
|
301 |
+
# ann = self.Anns[ann_id]
|
302 |
+
bbox = self.getRefBox(ref['ref_id'])
|
303 |
+
box_plot = Rectangle(
|
304 |
+
(bbox[0], bbox[1]),
|
305 |
+
bbox[2],
|
306 |
+
bbox[3],
|
307 |
+
fill=False,
|
308 |
+
edgecolor='green',
|
309 |
+
linewidth=3,
|
310 |
+
)
|
311 |
+
ax.add_patch(box_plot)
|
312 |
+
|
313 |
+
def getMask(self, ref):
|
314 |
+
# return mask, area and mask-center
|
315 |
+
ann = self.refToAnn[ref['ref_id']]
|
316 |
+
image = self.Imgs[ref['image_id']]
|
317 |
+
|
318 |
+
if type(ann['segmentation'][0]) == list: # polygon
|
319 |
+
rle = mask.frPyObjects(
|
320 |
+
ann['segmentation'], image['height'], image['width']
|
321 |
+
)
|
322 |
+
else:
|
323 |
+
rle = ann['segmentation']
|
324 |
+
|
325 |
+
m = mask.decode(rle)
|
326 |
+
# sometimes there are multiple binary map (corresponding to multiple segs)
|
327 |
+
m = np.sum(m, axis=2)
|
328 |
+
m = m.astype(np.uint8) # convert to np.uint8
|
329 |
+
# compute area
|
330 |
+
area = sum(mask.area(rle)) # should be close to ann['area']
|
331 |
+
return {'mask': m, 'area': area}
|
332 |
+
|
333 |
+
def showMask(self, ref):
|
334 |
+
M = self.getMask(ref)
|
335 |
+
msk = M['mask']
|
336 |
+
ax = plt.gca()
|
337 |
+
ax.imshow(msk)
|
338 |
+
|
339 |
+
|
340 |
+
class ReferDataset(data.Dataset):
|
341 |
+
|
342 |
+
def __init__(
|
343 |
+
self,
|
344 |
+
root,
|
345 |
+
dataset='refcoco',
|
346 |
+
splitBy='google',
|
347 |
+
image_transforms=None,
|
348 |
+
target_transforms=None,
|
349 |
+
split='train',
|
350 |
+
eval_mode=False,
|
351 |
+
):
|
352 |
+
|
353 |
+
self.classes = []
|
354 |
+
self.image_transforms = image_transforms
|
355 |
+
self.target_transforms = target_transforms
|
356 |
+
self.split = split
|
357 |
+
self.refer = REFER(root, dataset=dataset, splitBy=splitBy)
|
358 |
+
|
359 |
+
ref_ids = self.refer.getRefIds(split=self.split)
|
360 |
+
img_ids = self.refer.getImgIds(ref_ids)
|
361 |
+
|
362 |
+
all_imgs = self.refer.Imgs
|
363 |
+
self.imgs = list(all_imgs[i] for i in img_ids)
|
364 |
+
self.ref_ids = ref_ids
|
365 |
+
# print(len(ref_ids))
|
366 |
+
# print(len(self.imgs))
|
367 |
+
self.sentence_raw = []
|
368 |
+
|
369 |
+
self.eval_mode = eval_mode
|
370 |
+
# if we are testing on a dataset, test all sentences of an object;
|
371 |
+
# o/w, we are validating during training, randomly sample one sentence
|
372 |
+
# for efficiency
|
373 |
+
for r in ref_ids:
|
374 |
+
ref = self.refer.Refs[r]
|
375 |
+
# ref_sentences = []
|
376 |
+
# for i, (el, sent_id) in enumerate(zip(ref['sentences'],
|
377 |
+
# ref['sent_ids'])):
|
378 |
+
for el in ref['sentences']:
|
379 |
+
sentence_raw = el['raw']
|
380 |
+
ref_sentences.append(sentence_raw)
|
381 |
+
self.sentence_raw.append(ref_sentences)
|
382 |
+
# print(len(self.sentence_raw))
|
383 |
+
|
384 |
+
def get_classes(self):
|
385 |
+
return self.classes
|
386 |
+
|
387 |
+
def __len__(self):
|
388 |
+
return len(self.ref_ids)
|
389 |
+
|
390 |
+
def __getitem__(self, index):
|
391 |
+
this_ref_id = self.ref_ids[index]
|
392 |
+
this_img_id = self.refer.getImgIds(this_ref_id)
|
393 |
+
this_img = self.refer.Imgs[this_img_id[0]]
|
394 |
+
# print(this_ref_id, this_img_id)
|
395 |
+
# print(len(self.ref_ids))
|
396 |
+
img_path = os.path.join(self.refer.IMAGE_DIR, this_img['file_name'])
|
397 |
+
img = Image.open(img_path).convert('RGB')
|
398 |
+
ref = self.refer.loadRefs(this_ref_id)
|
399 |
+
# print("ref",ref)
|
400 |
+
|
401 |
+
ref_mask = np.array(self.refer.getMask(ref[0])['mask'])
|
402 |
+
annot = np.zeros(ref_mask.shape)
|
403 |
+
annot[ref_mask == 1] = 1
|
404 |
+
|
405 |
+
target = Image.fromarray(annot.astype(np.uint8), mode='P')
|
406 |
+
# print(np.array(target), np.unique(np.array(target).flatten()))
|
407 |
+
if self.image_transforms is not None:
|
408 |
+
# resize, from PIL to tensor, and mean and std normalization
|
409 |
+
img = self.image_transforms(img)
|
410 |
+
# target = self.target_transforms(target)
|
411 |
+
target = torch.as_tensor(np.array(target, copy=True))
|
412 |
+
# target = target.permute((2, 0, 1))
|
413 |
+
sentence = self.sentence_raw[index]
|
414 |
+
|
415 |
+
return img, img_path, target, sentence
|
416 |
+
|
417 |
+
|
418 |
+
if __name__ == '__main__':
|
419 |
+
|
420 |
+
def get_transform():
|
421 |
+
transform = [
|
422 |
+
transforms.Resize((224, 224)),
|
423 |
+
transforms.ToTensor(),
|
424 |
+
# T.Normalize(mean=[0.485, 0.456, 0.406],
|
425 |
+
# std=[0.229, 0.224, 0.225])
|
426 |
+
]
|
427 |
+
|
428 |
+
return transforms.Compose(transform)
|
429 |
+
|
430 |
+
transform = get_transform()
|
431 |
+
dataset_test = ReferDataset(
|
432 |
+
root='/datasets/refseg',
|
433 |
+
dataset='refcoco+',
|
434 |
+
splitBy='google',
|
435 |
+
image_transforms=transform,
|
436 |
+
target_transforms=transform,
|
437 |
+
split='train',
|
438 |
+
eval_mode=False,
|
439 |
+
)
|
440 |
+
print('loaded')
|
441 |
+
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
|
442 |
+
data_loader_test = torch.utils.data.DataLoader(
|
443 |
+
dataset_test, batch_size=1, sampler=test_sampler, num_workers=1
|
444 |
+
)
|
445 |
+
|
446 |
+
for img, target, sentence in data_loader_test:
|
447 |
+
# print(type(img),type(target))
|
448 |
+
print(sentence)
|
449 |
+
break
|
data/voc.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Pascal VOC dataset."""
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
from PIL import Image
|
20 |
+
# pylint: disable=g-importing-member
|
21 |
+
from torchvision.datasets import VOCSegmentation
|
22 |
+
|
23 |
+
CLASS2ID = {
|
24 |
+
'Background': 0,
|
25 |
+
'Aero plane': 1,
|
26 |
+
'Bicycle': 2,
|
27 |
+
'Bird': 3,
|
28 |
+
'Boat': 4,
|
29 |
+
'Bottle': 5,
|
30 |
+
'Bus': 6,
|
31 |
+
'Car': 7,
|
32 |
+
'Cat': 8,
|
33 |
+
'Chair': 9,
|
34 |
+
'Cow': 10,
|
35 |
+
'Dining table': 11,
|
36 |
+
'Dog': 12,
|
37 |
+
'Horse': 13,
|
38 |
+
'Motorbike': 14,
|
39 |
+
'Person': 15,
|
40 |
+
'Potted plant': 16,
|
41 |
+
'Sheep': 17,
|
42 |
+
'Sofa': 18,
|
43 |
+
'Train': 19,
|
44 |
+
'Tv/Monitor': 20,
|
45 |
+
# ... add more entries as needed
|
46 |
+
'Border': 255,
|
47 |
+
}
|
48 |
+
|
49 |
+
|
50 |
+
VOC_CLASSES = [
|
51 |
+
'aeroplane',
|
52 |
+
'bicycle',
|
53 |
+
'bird avian',
|
54 |
+
'boat',
|
55 |
+
'bottle',
|
56 |
+
'bus',
|
57 |
+
'car',
|
58 |
+
'cat',
|
59 |
+
'chair seat',
|
60 |
+
'cow',
|
61 |
+
'diningtable',
|
62 |
+
'dog',
|
63 |
+
'horse',
|
64 |
+
'motorbike',
|
65 |
+
'person with clothes,people,human',
|
66 |
+
'pottedplant',
|
67 |
+
'sheep',
|
68 |
+
'sofa',
|
69 |
+
'train',
|
70 |
+
'tvmonitor screen',
|
71 |
+
]
|
72 |
+
|
73 |
+
|
74 |
+
BACKGROUND_CATEGORY = [
|
75 |
+
'ground',
|
76 |
+
'land',
|
77 |
+
'grass',
|
78 |
+
'tree',
|
79 |
+
'building',
|
80 |
+
'wall',
|
81 |
+
'sky',
|
82 |
+
'lake',
|
83 |
+
'water',
|
84 |
+
'river',
|
85 |
+
'sea',
|
86 |
+
'keyboard',
|
87 |
+
'helmet',
|
88 |
+
'cloud',
|
89 |
+
'house',
|
90 |
+
'mountain',
|
91 |
+
'ocean',
|
92 |
+
'road',
|
93 |
+
'rock',
|
94 |
+
'street',
|
95 |
+
'valley',
|
96 |
+
'bridge',
|
97 |
+
'sign',
|
98 |
+
]
|
99 |
+
|
100 |
+
|
101 |
+
class VOCDataset(VOCSegmentation):
|
102 |
+
"""Pascal VOC dataset."""
|
103 |
+
|
104 |
+
def __init__(
|
105 |
+
self,
|
106 |
+
root='/datasets/jianhaoy/PASCAL/',
|
107 |
+
year='2012',
|
108 |
+
split='val',
|
109 |
+
target_transform=None,
|
110 |
+
download=False,
|
111 |
+
transform=None,
|
112 |
+
):
|
113 |
+
super(VOCDataset, self).__init__(
|
114 |
+
root=root,
|
115 |
+
image_set=split,
|
116 |
+
year=year,
|
117 |
+
target_transform=transform,
|
118 |
+
download=download,
|
119 |
+
transform=transform,
|
120 |
+
)
|
121 |
+
self.idx_to_class = {val: key for (key, val) in CLASS2ID.items()}
|
122 |
+
|
123 |
+
def __getitem__(self, index):
|
124 |
+
image_path = self.images[index]
|
125 |
+
image = Image.open(image_path).convert('RGB')
|
126 |
+
target = np.asarray(Image.open(self.masks[index]), dtype=np.int32)
|
127 |
+
|
128 |
+
_, unique_values = self.process_target(np.array(target))
|
129 |
+
classnames = [self.idx_to_class[idx] for idx in unique_values]
|
130 |
+
|
131 |
+
if self.transforms:
|
132 |
+
image = self.transform(image)
|
133 |
+
|
134 |
+
return image, str(image_path), target, classnames
|
135 |
+
|
136 |
+
def process_target(self, arr):
|
137 |
+
# Set values 0 and 255 to 1
|
138 |
+
arr[(arr == 0) | (arr == 255)] = 0
|
139 |
+
|
140 |
+
# Find unique values (excluding 0 and 255)
|
141 |
+
unique_values = np.unique(arr[(arr != 0) & (arr != 255)])
|
142 |
+
|
143 |
+
# Create separate masks for each unique value
|
144 |
+
masks = [arr == value for value in unique_values]
|
145 |
+
masks = [Image.fromarray(arr) for arr in masks]
|
146 |
+
masks = [self.target_transform(arr) for arr in masks]
|
147 |
+
|
148 |
+
return masks, unique_values
|
demo.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Run a demo of the CaR model on a single image."""
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
import argparse
|
6 |
+
from functools import reduce
|
7 |
+
import PIL.Image as Image
|
8 |
+
import torch
|
9 |
+
from modeling.model import CaR
|
10 |
+
from utils.utils import Config, load_yaml
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
import colorsys
|
13 |
+
from modeling.post_process.post_process import (
|
14 |
+
match_masks,
|
15 |
+
generate_masks_from_sam,
|
16 |
+
)
|
17 |
+
from sam.sam import SAMPipeline
|
18 |
+
from sam.utils import build_sam_config
|
19 |
+
import random
|
20 |
+
import time
|
21 |
+
|
22 |
+
|
23 |
+
def generate_distinct_colors(n):
|
24 |
+
colors = []
|
25 |
+
# generate a random number from 0 to 1
|
26 |
+
random_color_bias = random.random()
|
27 |
+
|
28 |
+
for i in range(n):
|
29 |
+
hue = float(i) / n
|
30 |
+
hue += random_color_bias
|
31 |
+
hue = hue % 1.0
|
32 |
+
rgb = colorsys.hsv_to_rgb(hue, 1.0, 1.0)
|
33 |
+
# Convert RGB values from [0, 1] range to [0, 255]
|
34 |
+
colors.append(tuple(int(val * 255) for val in rgb))
|
35 |
+
return colors
|
36 |
+
|
37 |
+
|
38 |
+
def overlap_masks(masks):
|
39 |
+
"""
|
40 |
+
Overlap masks to generate a single mask for visualization.
|
41 |
+
|
42 |
+
Parameters:
|
43 |
+
- masks: list of np.arrays of shape (H, W) representing binary masks
|
44 |
+
for each class.
|
45 |
+
|
46 |
+
Returns:
|
47 |
+
- overlap_mask: list of np.array of shape (H, W) that have no overlaps
|
48 |
+
"""
|
49 |
+
overlap_mask = torch.zeros_like(masks[0])
|
50 |
+
for mask_idx, mask in enumerate(masks):
|
51 |
+
overlap_mask[mask > 0] = mask_idx + 1
|
52 |
+
|
53 |
+
clean_masks = [
|
54 |
+
overlap_mask == mask_idx + 1 for mask_idx in range(len(masks))
|
55 |
+
]
|
56 |
+
clean_masks = torch.stack(clean_masks, dim=0)
|
57 |
+
|
58 |
+
return clean_masks
|
59 |
+
|
60 |
+
|
61 |
+
def visualize_segmentation(
|
62 |
+
image, masks, class_names, alpha=0.45, y_list=None, x_list=None
|
63 |
+
):
|
64 |
+
"""
|
65 |
+
Visualize segmentation masks on an image.
|
66 |
+
|
67 |
+
Parameters:
|
68 |
+
- image: np.array of shape (H, W, 3) representing the RGB image
|
69 |
+
- masks: list of np.arrays of shape (H, W) representing binary masks
|
70 |
+
for each class.
|
71 |
+
- class_names: list of strings representing names of each class
|
72 |
+
- alpha: float, transparency level of masks on the image
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
- visualization: plt.figure object
|
76 |
+
"""
|
77 |
+
# Create a figure and axis
|
78 |
+
fig, ax = plt.subplots(1, figsize=(12, 9))
|
79 |
+
# Display the image
|
80 |
+
# ax.imshow(image)
|
81 |
+
# Generate distinct colors for each mask
|
82 |
+
final_mask = np.zeros(
|
83 |
+
(masks.shape[1], masks.shape[2], 3), dtype=np.float32
|
84 |
+
)
|
85 |
+
colors = generate_distinct_colors(len(class_names))
|
86 |
+
idx = 0
|
87 |
+
for mask, color, class_name in zip(masks, colors, class_names):
|
88 |
+
# Overlay the mask
|
89 |
+
final_mask += np.dstack([mask * c for c in color])
|
90 |
+
# Find a representative point (e.g., centroid) for placing the label
|
91 |
+
if y_list is None or x_list is None:
|
92 |
+
y, x = np.argwhere(mask).mean(axis=0)
|
93 |
+
else:
|
94 |
+
y, x = y_list[idx], x_list[idx]
|
95 |
+
ax.text(
|
96 |
+
x,
|
97 |
+
y,
|
98 |
+
class_name,
|
99 |
+
color="white",
|
100 |
+
fontsize=36,
|
101 |
+
va="center",
|
102 |
+
ha="center",
|
103 |
+
bbox=dict(facecolor="black", alpha=0.7, edgecolor="none"),
|
104 |
+
)
|
105 |
+
|
106 |
+
idx += 1
|
107 |
+
|
108 |
+
final_image = image * (1 - alpha) + final_mask * alpha
|
109 |
+
final_image = final_image.astype(np.uint8)
|
110 |
+
ax.imshow(final_image)
|
111 |
+
# Remove axis ticks and labels
|
112 |
+
ax.axis("off")
|
113 |
+
return fig
|
114 |
+
|
115 |
+
|
116 |
+
def get_sam_masks(config, image_path, masks, img_sam=None, pipeline=None):
|
117 |
+
print("generating sam masks online")
|
118 |
+
mask_tensor, mask_list = generate_masks_from_sam(
|
119 |
+
image_path,
|
120 |
+
save_path="./",
|
121 |
+
pipeline=pipeline,
|
122 |
+
img_sam=img_sam,
|
123 |
+
visualize=False,
|
124 |
+
)
|
125 |
+
mask_tensor = mask_tensor.to(masks.device)
|
126 |
+
# only conduct sam on masks that is not all zero
|
127 |
+
attn_map, mask_ids = [], []
|
128 |
+
for mask_id, mask in enumerate(masks):
|
129 |
+
if torch.sum(mask) > 0:
|
130 |
+
attn_map.append(mask.unsqueeze(0))
|
131 |
+
mask_ids.append(mask_id)
|
132 |
+
matched_masks = [
|
133 |
+
match_masks(
|
134 |
+
mask_tensor,
|
135 |
+
attn,
|
136 |
+
mask_list,
|
137 |
+
iom_thres=config.car.iom_thres,
|
138 |
+
min_pred_threshold=config.sam.min_pred_threshold,
|
139 |
+
)
|
140 |
+
for attn in attn_map
|
141 |
+
]
|
142 |
+
for matched_mask, mask_id in zip(matched_masks, mask_ids):
|
143 |
+
sam_masks = np.array([item["segmentation"] for item in matched_mask])
|
144 |
+
sam_mask = np.any(sam_masks, axis=0)
|
145 |
+
masks[mask_id] = torch.from_numpy(sam_mask).to(masks.device)
|
146 |
+
return masks
|
147 |
+
|
148 |
+
|
149 |
+
def load_sam(config, sam_device):
|
150 |
+
sam_checkpoint, model_type = build_sam_config(config)
|
151 |
+
pipelines = SAMPipeline(
|
152 |
+
sam_checkpoint,
|
153 |
+
model_type,
|
154 |
+
device=sam_device,
|
155 |
+
points_per_side=config.sam.points_per_side,
|
156 |
+
pred_iou_thresh=config.sam.pred_iou_thresh,
|
157 |
+
stability_score_thresh=config.sam.stability_score_thresh,
|
158 |
+
box_nms_thresh=config.sam.box_nms_thresh,
|
159 |
+
)
|
160 |
+
return pipelines
|
161 |
+
|
162 |
+
|
163 |
+
if __name__ == "__main__":
|
164 |
+
parser = argparse.ArgumentParser("CaR")
|
165 |
+
# default arguments
|
166 |
+
|
167 |
+
# additional arguments
|
168 |
+
parser.add_argument(
|
169 |
+
"--output_path", type=str, default="", help="path to save outputs"
|
170 |
+
)
|
171 |
+
parser.add_argument(
|
172 |
+
"--cfg-path",
|
173 |
+
default="configs/voc_test.yaml",
|
174 |
+
help="path to configuration file.",
|
175 |
+
)
|
176 |
+
args = parser.parse_args()
|
177 |
+
|
178 |
+
cfg = Config(**load_yaml(args.cfg_path))
|
179 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
180 |
+
# device = 'cpu'
|
181 |
+
folder_name = reduce(
|
182 |
+
lambda x, y: x.replace(" ", "_") + "_" + y, cfg.image_caption
|
183 |
+
)
|
184 |
+
if len(folder_name) > 20:
|
185 |
+
folder_name = folder_name[:20]
|
186 |
+
|
187 |
+
car_model = CaR(
|
188 |
+
cfg, visualize=True, seg_mode=cfg.test.seg_mode, device=device
|
189 |
+
)
|
190 |
+
|
191 |
+
sam_pipeline = load_sam(cfg, device)
|
192 |
+
|
193 |
+
img = Image.open(cfg.image_path).convert("RGB")
|
194 |
+
import pdb; pdb.set_trace()
|
195 |
+
# resize image by dividing 2 if the size is larger than 1000
|
196 |
+
if img.size[0] > 1000:
|
197 |
+
img = img.resize((img.size[0] // 3, img.size[1] // 3))
|
198 |
+
|
199 |
+
label_space = cfg.image_caption
|
200 |
+
pseudo_masks, scores, _ = car_model(img, label_space)
|
201 |
+
|
202 |
+
|
203 |
+
if not cfg.test.use_pseudo:
|
204 |
+
t1 = time.time()
|
205 |
+
pseudo_masks = get_sam_masks(
|
206 |
+
cfg,
|
207 |
+
cfg.image_path,
|
208 |
+
pseudo_masks,
|
209 |
+
img_sam=np.array(img),
|
210 |
+
pipeline=sam_pipeline,
|
211 |
+
)
|
212 |
+
pseudo_masks = overlap_masks(pseudo_masks)
|
213 |
+
t2 = time.time()
|
214 |
+
print(f"sam time: {t2 - t1}")
|
215 |
+
|
216 |
+
# visualize segmentation masks
|
217 |
+
demo_fig = visualize_segmentation(
|
218 |
+
np.array(img),
|
219 |
+
pseudo_masks.detach().cpu().numpy(),
|
220 |
+
label_space,
|
221 |
+
)
|
222 |
+
save_path = f"vis_results/{folder_name}"
|
223 |
+
if not os.path.exists(save_path):
|
224 |
+
os.makedirs(save_path)
|
225 |
+
demo_fig.savefig(os.path.join(save_path, "demo.png"), bbox_inches="tight")
|
226 |
+
|
227 |
+
print(f"results saved to {save_path}.")
|
evaluate.py
ADDED
@@ -0,0 +1,511 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Evaluate CaR on segmentation benchmarks."""
|
17 |
+
# pylint: disable=g-importing-member
|
18 |
+
import argparse
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
from torch.utils import tensorboard
|
22 |
+
import torch.utils.data
|
23 |
+
from torch.utils.data import Subset
|
24 |
+
import torchvision.transforms as T
|
25 |
+
|
26 |
+
# pylint: disable=g-bad-import-order
|
27 |
+
from modeling.model.car import CaR
|
28 |
+
from sam.utils import build_sam_config
|
29 |
+
from utils.utils import Config
|
30 |
+
from utils.utils import load_yaml
|
31 |
+
from utils.utils import MetricLogger
|
32 |
+
from utils.utils import SmoothedValue
|
33 |
+
from utils.inference_pipeline import inference_car
|
34 |
+
from utils.merge_mask import merge_masks_simple
|
35 |
+
|
36 |
+
# Datasets
|
37 |
+
# pylint: disable=g-multiple-import
|
38 |
+
from data.ade import ADE_THING_CLASS, ADE_STUFF_CLASS, ADE_THING_CLASS_ID, ADE_STUFF_CLASS_ID, ADEDataset
|
39 |
+
from data.ade847 import ADE_847_THING_CLASS_ID, ADE_847_STUFF_CLASS_ID, ADE_847_THING_CLASS, ADE_847_STUFF_CLASS, ADE847Dataset
|
40 |
+
from data.coco import COCO_OBJECT_CLASSES, COCODataset
|
41 |
+
from data.context import PASCAL_CONTEXT_STUFF_CLASS_ID, PASCAL_CONTEXT_THING_CLASS_ID, PASCAL_CONTEXT_STUFF_CLASS, PASCAL_CONTEXT_THING_CLASS, CONTEXTDataset
|
42 |
+
from data.gres import GReferDataset
|
43 |
+
from data.pascal459 import PASCAL_459_THING_CLASS_ID, PASCAL_459_STUFF_CLASS_ID, PASCAL_459_THING_CLASS, PASCAL_459_STUFF_CLASS, Pascal459Dataset
|
44 |
+
from data.refcoco import ReferDataset
|
45 |
+
from data.voc import VOC_CLASSES, VOCDataset
|
46 |
+
|
47 |
+
|
48 |
+
IMAGE_WIDTH, IMAGE_HEIGHT = 512, 512
|
49 |
+
|
50 |
+
# set random seed
|
51 |
+
torch.manual_seed(0)
|
52 |
+
np.random.seed(0)
|
53 |
+
|
54 |
+
|
55 |
+
def get_dataset(cfg, ds_name, split, transform, data_root=None):
|
56 |
+
"""Get dataset."""
|
57 |
+
data_args = dict(root=data_root) if data_root is not None else {}
|
58 |
+
if 'refcoco' in ds_name:
|
59 |
+
splitby = cfg.test.splitby if hasattr(cfg.test, 'splitby') else 'unc'
|
60 |
+
ds = ReferDataset(
|
61 |
+
dataset=ds_name,
|
62 |
+
splitBy=splitby,
|
63 |
+
split=split,
|
64 |
+
image_transforms=transform,
|
65 |
+
target_transforms=transform,
|
66 |
+
eval_mode=True,
|
67 |
+
prompts_augment=cfg.test.prompts_augment,
|
68 |
+
**data_args,
|
69 |
+
)
|
70 |
+
elif ds_name == 'gres':
|
71 |
+
ds = GReferDataset(split=split, transform=transform, **data_args)
|
72 |
+
elif ds_name == 'voc':
|
73 |
+
ds = VOCDataset(
|
74 |
+
year='2012',
|
75 |
+
split=split,
|
76 |
+
transform=transform,
|
77 |
+
target_transform=transform,
|
78 |
+
**data_args,
|
79 |
+
)
|
80 |
+
|
81 |
+
elif ds_name == 'cocostuff':
|
82 |
+
ds = COCODataset(transform=transform, **data_args)
|
83 |
+
|
84 |
+
elif ds_name == 'context':
|
85 |
+
ds = CONTEXTDataset(
|
86 |
+
year='2010', transform=transform, split=split, **data_args
|
87 |
+
)
|
88 |
+
elif ds_name == 'ade':
|
89 |
+
ds = ADEDataset(split=split, transform=transform, **data_args)
|
90 |
+
elif ds_name == 'pascal_459':
|
91 |
+
ds = Pascal459Dataset(split=split, transform=transform, **data_args)
|
92 |
+
elif ds_name == 'ade_847':
|
93 |
+
ds = ADE847Dataset(split=split, transform=transform, **data_args)
|
94 |
+
else:
|
95 |
+
raise ValueError(f'Dataset {ds_name} not implemented')
|
96 |
+
return ds
|
97 |
+
|
98 |
+
|
99 |
+
def get_transform():
|
100 |
+
transforms = [
|
101 |
+
T.Resize((IMAGE_WIDTH, IMAGE_HEIGHT)),
|
102 |
+
T.ToTensor(),
|
103 |
+
]
|
104 |
+
|
105 |
+
return T.Compose(transforms)
|
106 |
+
|
107 |
+
|
108 |
+
def assign_label(
|
109 |
+
all_masks,
|
110 |
+
scores,
|
111 |
+
stuff_masks=None,
|
112 |
+
stuff_scores=None,
|
113 |
+
id_mapping=None,
|
114 |
+
stuff_id_mapping=None,
|
115 |
+
):
|
116 |
+
"""Assign labels."""
|
117 |
+
label_preds = np.zeros_like(all_masks[0]).astype(np.int32)
|
118 |
+
if stuff_masks is not None:
|
119 |
+
sorted_idxs = np.argsort(stuff_scores.detach().cpu().numpy())
|
120 |
+
stuff_masks = stuff_masks[sorted_idxs]
|
121 |
+
stuff_scores = stuff_scores.detach().cpu().numpy()[sorted_idxs]
|
122 |
+
for sorted_idx, mask, score in zip(sorted_idxs, stuff_masks, stuff_scores):
|
123 |
+
if score > 0:
|
124 |
+
# convert mask to boolean
|
125 |
+
mask = mask > 0.5
|
126 |
+
# assign label
|
127 |
+
if stuff_id_mapping is not None:
|
128 |
+
label_preds[mask] = stuff_id_mapping[sorted_idx] + 1
|
129 |
+
else:
|
130 |
+
label_preds[mask] = sorted_idx + 1
|
131 |
+
sorted_idxs = np.argsort(scores.detach().cpu().numpy())
|
132 |
+
all_masks = all_masks[sorted_idxs]
|
133 |
+
scores = scores.detach().cpu().numpy()[sorted_idxs]
|
134 |
+
for sorted_idx, mask, score in zip(sorted_idxs, all_masks, scores):
|
135 |
+
if score > 0:
|
136 |
+
# convert mask to boolean
|
137 |
+
mask = mask > 0.5
|
138 |
+
# assign label
|
139 |
+
if id_mapping is not None:
|
140 |
+
label_preds[mask] = id_mapping[sorted_idx] + 1
|
141 |
+
else:
|
142 |
+
label_preds[mask] = sorted_idx + 1
|
143 |
+
|
144 |
+
return label_preds
|
145 |
+
|
146 |
+
|
147 |
+
def eval_semantic(
|
148 |
+
label_space,
|
149 |
+
algo,
|
150 |
+
cfg,
|
151 |
+
model,
|
152 |
+
image_path,
|
153 |
+
stuff_label_space=None,
|
154 |
+
sam_pipeline=None,
|
155 |
+
):
|
156 |
+
"""Semantic segmentation evaluation."""
|
157 |
+
|
158 |
+
if label_space is None:
|
159 |
+
raise ValueError(
|
160 |
+
'label_space must be provided for semantic segmentation evaluation'
|
161 |
+
)
|
162 |
+
if algo == 'car':
|
163 |
+
all_masks, scores = inference_car(
|
164 |
+
cfg, model, image_path, label_space, sam_pipeline=sam_pipeline
|
165 |
+
)
|
166 |
+
if stuff_label_space is not None:
|
167 |
+
if cfg.test.ds_name == 'context':
|
168 |
+
thing_id_mapping = PASCAL_CONTEXT_THING_CLASS_ID
|
169 |
+
stuff_id_mapping = PASCAL_CONTEXT_STUFF_CLASS_ID
|
170 |
+
elif cfg.test.ds_name == 'ade':
|
171 |
+
thing_id_mapping = ADE_THING_CLASS_ID
|
172 |
+
stuff_id_mapping = ADE_STUFF_CLASS_ID
|
173 |
+
elif cfg.test.ds_name == 'pascal_459':
|
174 |
+
thing_id_mapping = PASCAL_459_THING_CLASS_ID
|
175 |
+
stuff_id_mapping = PASCAL_459_STUFF_CLASS_ID
|
176 |
+
elif cfg.test.ds_name == 'ade_847':
|
177 |
+
thing_id_mapping = ADE_847_THING_CLASS_ID
|
178 |
+
stuff_id_mapping = ADE_847_STUFF_CLASS_ID
|
179 |
+
else:
|
180 |
+
raise ValueError(f'Dataset {cfg.test.ds_name} not supported')
|
181 |
+
|
182 |
+
model.mask_generator.set_bg_cls(label_space)
|
183 |
+
model.set_visual_prompt_type(cfg.car.stuff_visual_prompt_type)
|
184 |
+
model.set_bg_factor(cfg.car.stuff_bg_factor)
|
185 |
+
stuff_masks, stuff_scores = inference_car(
|
186 |
+
cfg, model, image_path, stuff_label_space, sam_pipeline=sam_pipeline
|
187 |
+
)
|
188 |
+
model.mask_generator.set_bg_cls(cfg.car.bg_cls)
|
189 |
+
model.set_visual_prompt_type(cfg.car.visual_prompt_type)
|
190 |
+
model.set_bg_factor(cfg.car.bg_factor)
|
191 |
+
all_masks = all_masks.detach().cpu().numpy()
|
192 |
+
stuff_masks = stuff_masks.detach().cpu().numpy()
|
193 |
+
label_preds = assign_label(
|
194 |
+
all_masks,
|
195 |
+
scores,
|
196 |
+
stuff_masks=stuff_masks,
|
197 |
+
stuff_scores=stuff_scores,
|
198 |
+
id_mapping=thing_id_mapping,
|
199 |
+
stuff_id_mapping=stuff_id_mapping,
|
200 |
+
)
|
201 |
+
else:
|
202 |
+
all_masks = all_masks.detach().cpu().numpy()
|
203 |
+
label_preds = assign_label(all_masks, scores)
|
204 |
+
return label_preds.squeeze()
|
205 |
+
else:
|
206 |
+
raise NotImplementedError(f'algo {algo} not implemented')
|
207 |
+
|
208 |
+
|
209 |
+
def _fast_hist(label_true, label_pred, n_class=21):
|
210 |
+
mask = (label_true >= 0) & (label_true < n_class)
|
211 |
+
hist = np.bincount(
|
212 |
+
n_class * label_true[mask].astype(int) + label_pred[mask],
|
213 |
+
minlength=n_class**2,
|
214 |
+
).reshape(n_class, n_class)
|
215 |
+
return hist
|
216 |
+
|
217 |
+
|
218 |
+
def semantic_iou(label_trues, label_preds, n_class=21, ignore_background=False):
|
219 |
+
"""Semantic segmentation IOU."""
|
220 |
+
|
221 |
+
hist = np.zeros((n_class, n_class))
|
222 |
+
for lt, lp in zip(label_trues, label_preds):
|
223 |
+
hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
|
224 |
+
if ignore_background:
|
225 |
+
hist = hist[1:, 1:]
|
226 |
+
acc = np.diag(hist).sum() / hist.sum()
|
227 |
+
acc_cls = np.diag(hist) / hist.sum(axis=1)
|
228 |
+
acc_cls = np.nanmean(acc_cls)
|
229 |
+
iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
|
230 |
+
valid = hist.sum(axis=1) > 0 # added
|
231 |
+
if valid.sum() == 0:
|
232 |
+
mean_iu = 0
|
233 |
+
else:
|
234 |
+
mean_iu = np.nanmean(iu[valid])
|
235 |
+
freq = hist.sum(axis=1) / hist.sum()
|
236 |
+
fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
|
237 |
+
if ignore_background:
|
238 |
+
cls_iu = dict(zip(range(1, n_class), iu))
|
239 |
+
else:
|
240 |
+
cls_iu = dict(zip(range(n_class), iu))
|
241 |
+
|
242 |
+
return {
|
243 |
+
'Pixel Accuracy': acc,
|
244 |
+
'Mean Accuracy': acc_cls,
|
245 |
+
'Frequency Weighted IoU': fwavacc,
|
246 |
+
'mIoU': mean_iu,
|
247 |
+
'Class IoU': cls_iu,
|
248 |
+
}
|
249 |
+
|
250 |
+
|
251 |
+
def evaluate(
|
252 |
+
data_loader,
|
253 |
+
cfg,
|
254 |
+
model,
|
255 |
+
test_cfg,
|
256 |
+
label_space=None,
|
257 |
+
stuff_label_space=None,
|
258 |
+
sam_pipeline=None,
|
259 |
+
):
|
260 |
+
"""Run evaluation."""
|
261 |
+
|
262 |
+
if (
|
263 |
+
test_cfg.ds_name
|
264 |
+
not in ['voc', 'cocostuff', 'context', 'ade', 'pascal_459', 'ade_847']
|
265 |
+
and test_cfg.seg_mode == 'semantic'
|
266 |
+
):
|
267 |
+
raise ValueError((
|
268 |
+
'Semantic segmentation evaluation is only implemented for voc, '
|
269 |
+
'context, coco object, ade, pascal459, ade847 dataset'
|
270 |
+
))
|
271 |
+
|
272 |
+
metric_logger = MetricLogger(delimiter=' ')
|
273 |
+
metric_logger.add_meter(
|
274 |
+
'mIoU', SmoothedValue(window_size=1, fmt='{value:.4f} ({global_avg:.4f})')
|
275 |
+
)
|
276 |
+
# evaluation variables
|
277 |
+
cum_i, cum_u = 0, 0
|
278 |
+
eval_seg_iou_list = [0.5, 0.6, 0.7, 0.8, 0.9]
|
279 |
+
seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
|
280 |
+
seg_total = 0
|
281 |
+
mean_iou = []
|
282 |
+
header = 'Test:'
|
283 |
+
|
284 |
+
# all_masks = []
|
285 |
+
label_preds, label_gts = [], []
|
286 |
+
print(len(data_loader))
|
287 |
+
cc = 0
|
288 |
+
use_tensorboard = False
|
289 |
+
if hasattr(cfg.test, 'use_tensorboard'):
|
290 |
+
use_tensorboard = cfg.test.use_tensorboard
|
291 |
+
|
292 |
+
if use_tensorboard:
|
293 |
+
writer = tensorboard.SummaryWriter(log_dir=cfg.test.output_path)
|
294 |
+
for data in metric_logger.log_every(data_loader, 1, header):
|
295 |
+
_, image_paths, target_list, sentences_list = data
|
296 |
+
# print(type(target_lis))
|
297 |
+
|
298 |
+
if not isinstance(target_list, list):
|
299 |
+
target_list, sentences_list = [target_list], [sentences_list]
|
300 |
+
for target, sentences in zip(target_list, sentences_list):
|
301 |
+
image_path = image_paths[0]
|
302 |
+
# print(image_path)
|
303 |
+
if test_cfg.seg_mode == 'refer':
|
304 |
+
all_masks, all_scores = inference_car(
|
305 |
+
cfg, model, image_path, sentences, sam_pipeline=sam_pipeline
|
306 |
+
)
|
307 |
+
# final_mask = merge_masks(all_masks, *target.shape[1:])
|
308 |
+
final_mask = merge_masks_simple(
|
309 |
+
all_masks, *target.shape[1:], scores=all_scores
|
310 |
+
)
|
311 |
+
intersection, union, cur_iou = compute_iou(final_mask, target)
|
312 |
+
# cur_iou = IoU(final_mask, target, 0)
|
313 |
+
metric_logger.update(mIoU=cur_iou)
|
314 |
+
mean_iou.append(cur_iou)
|
315 |
+
if use_tensorboard:
|
316 |
+
writer.add_scalar('Mean IoU', cur_iou, cc)
|
317 |
+
cum_i += intersection
|
318 |
+
cum_u += union
|
319 |
+
for n_eval_iou in range(len(eval_seg_iou_list)):
|
320 |
+
eval_seg_iou = eval_seg_iou_list[n_eval_iou]
|
321 |
+
seg_correct[n_eval_iou] += cur_iou >= eval_seg_iou
|
322 |
+
seg_total += 1
|
323 |
+
elif test_cfg.seg_mode == 'semantic':
|
324 |
+
# torch.cuda.empty_cache()
|
325 |
+
label_pred = eval_semantic(
|
326 |
+
label_space,
|
327 |
+
test_cfg.algo,
|
328 |
+
cfg,
|
329 |
+
model,
|
330 |
+
image_path,
|
331 |
+
stuff_label_space,
|
332 |
+
)
|
333 |
+
label_gt = target.squeeze().cpu().numpy()
|
334 |
+
cur_iou = semantic_iou(
|
335 |
+
[label_gt],
|
336 |
+
[label_pred],
|
337 |
+
n_class=cfg.test.n_class,
|
338 |
+
ignore_background=cfg.test.ignore_background,
|
339 |
+
)['mIoU']
|
340 |
+
metric_logger.update(mIoU=cur_iou)
|
341 |
+
label_preds.append(label_pred)
|
342 |
+
label_gts.append(label_gt)
|
343 |
+
|
344 |
+
cc += 1
|
345 |
+
|
346 |
+
if test_cfg.seg_mode == 'refer':
|
347 |
+
mean_iou = np.array(mean_iou)
|
348 |
+
m_iou = np.mean(mean_iou)
|
349 |
+
if use_tensorboard:
|
350 |
+
writer.add_scalar('mIoU', m_iou.item(), len(data_loader))
|
351 |
+
print('Final results:')
|
352 |
+
print('Mean IoU is %.2f\n' % (m_iou * 100.0))
|
353 |
+
results_str = ''
|
354 |
+
for n_eval_iou in range(len(eval_seg_iou_list)):
|
355 |
+
results_str += ' precision@%s = %.2f\n' % (
|
356 |
+
str(eval_seg_iou_list[n_eval_iou]),
|
357 |
+
seg_correct[n_eval_iou] * 100.0 / seg_total,
|
358 |
+
)
|
359 |
+
o_iou = cum_i * 100.0 / cum_u
|
360 |
+
results_str += ' overall IoU = %.2f\n' % o_iou
|
361 |
+
if use_tensorboard:
|
362 |
+
writer.add_scalar('oIoU', o_iou, 0)
|
363 |
+
print(results_str)
|
364 |
+
elif test_cfg.seg_mode == 'semantic':
|
365 |
+
iou_score = semantic_iou(
|
366 |
+
label_gts,
|
367 |
+
label_preds,
|
368 |
+
n_class=cfg.test.n_class,
|
369 |
+
ignore_background=cfg.test.ignore_background,
|
370 |
+
)
|
371 |
+
if use_tensorboard:
|
372 |
+
writer.add_scalar('mIoU', iou_score['mIoU'].item(), len(data_loader))
|
373 |
+
|
374 |
+
print(iou_score)
|
375 |
+
if use_tensorboard:
|
376 |
+
writer.close()
|
377 |
+
|
378 |
+
|
379 |
+
def compute_iou(pred_seg, gd_seg):
|
380 |
+
"""Compute IoU."""
|
381 |
+
intersection = torch.sum(torch.logical_and(pred_seg, gd_seg))
|
382 |
+
union = torch.sum(torch.logical_or(pred_seg, gd_seg))
|
383 |
+
iou = intersection * 1.0 / union
|
384 |
+
if union == 0:
|
385 |
+
iou = 0
|
386 |
+
return intersection, union, iou
|
387 |
+
|
388 |
+
|
389 |
+
def list_of_strings(arg):
|
390 |
+
return [a.strip() for a in arg.split(',')]
|
391 |
+
|
392 |
+
|
393 |
+
# pylint: disable=redefined-outer-name
|
394 |
+
def parse_args():
|
395 |
+
"""Parse arguments."""
|
396 |
+
parser = argparse.ArgumentParser(description='Training')
|
397 |
+
parser.add_argument(
|
398 |
+
'--cfg-path',
|
399 |
+
default='configs/refcoco_test_prompt.yaml',
|
400 |
+
help='path to configuration file.',
|
401 |
+
)
|
402 |
+
parser.add_argument('--index', default=0, type=int, help='split task')
|
403 |
+
parser.add_argument('--mask_threshold', default=0.0, type=float)
|
404 |
+
parser.add_argument('--confidence_threshold', default=0.0, type=float)
|
405 |
+
parser.add_argument('--clipes_threshold', default=0.0, type=float)
|
406 |
+
parser.add_argument('--stuff_bg_factor', default=0.0, type=float)
|
407 |
+
parser.add_argument('--bg_factor', default=0.0, type=float)
|
408 |
+
parser.add_argument('--output_path', default=None, type=str)
|
409 |
+
parser.add_argument(
|
410 |
+
'--visual_prompt_type', default=None, type=list_of_strings
|
411 |
+
)
|
412 |
+
parser.add_argument(
|
413 |
+
'--stuff_visual_prompt_type', default=None, type=list_of_strings
|
414 |
+
)
|
415 |
+
|
416 |
+
args = parser.parse_args()
|
417 |
+
|
418 |
+
return args
|
419 |
+
|
420 |
+
|
421 |
+
def main(args):
|
422 |
+
cfg = Config(**load_yaml(args.cfg_path))
|
423 |
+
if args.mask_threshold > 0:
|
424 |
+
cfg.car.mask_threshold = args.mask_threshold
|
425 |
+
if args.confidence_threshold > 0:
|
426 |
+
cfg.car.confidence_threshold = args.confidence_threshold
|
427 |
+
if args.clipes_threshold > 0:
|
428 |
+
cfg.car.clipes_threshold = args.clipes_threshold
|
429 |
+
if args.bg_factor > 0:
|
430 |
+
cfg.car.bg_factor = args.bg_factor
|
431 |
+
if args.stuff_bg_factor > 0:
|
432 |
+
cfg.car.stuff_bg_factor = args.stuff_bg_factor
|
433 |
+
if args.output_path is not None:
|
434 |
+
cfg.test.output_path = args.output_path
|
435 |
+
if args.visual_prompt_type is not None:
|
436 |
+
cfg.car.visual_prompt_type = args.visual_prompt_type
|
437 |
+
if args.stuff_visual_prompt_type is not None:
|
438 |
+
cfg.car.stuff_visual_prompt_type = args.stuff_visual_prompt_type
|
439 |
+
|
440 |
+
try:
|
441 |
+
data_root = cfg.test.data_root
|
442 |
+
except ValueError:
|
443 |
+
data_root = None
|
444 |
+
|
445 |
+
dataset_test = get_dataset(
|
446 |
+
cfg, cfg.test.ds_name, cfg.test.split, get_transform(), data_root
|
447 |
+
)
|
448 |
+
|
449 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
450 |
+
|
451 |
+
stuff_label_space = None
|
452 |
+
if cfg.test.ds_name == 'voc':
|
453 |
+
label_space = VOC_CLASSES
|
454 |
+
elif cfg.test.ds_name == 'cocostuff':
|
455 |
+
label_space = COCO_OBJECT_CLASSES
|
456 |
+
elif cfg.test.ds_name == 'context':
|
457 |
+
# label_space = PASCAL_CONTEXT_CLASSES
|
458 |
+
label_space = PASCAL_CONTEXT_THING_CLASS
|
459 |
+
stuff_label_space = PASCAL_CONTEXT_STUFF_CLASS
|
460 |
+
elif cfg.test.ds_name == 'ade':
|
461 |
+
label_space = ADE_THING_CLASS
|
462 |
+
stuff_label_space = ADE_STUFF_CLASS
|
463 |
+
elif cfg.test.ds_name == 'pascal_459':
|
464 |
+
label_space = PASCAL_459_THING_CLASS
|
465 |
+
stuff_label_space = PASCAL_459_STUFF_CLASS
|
466 |
+
elif cfg.test.ds_name == 'ade_847':
|
467 |
+
label_space = ADE_847_THING_CLASS
|
468 |
+
stuff_label_space = ADE_847_STUFF_CLASS
|
469 |
+
else:
|
470 |
+
label_space = None
|
471 |
+
|
472 |
+
num_chunks, chunk_index = 1, 0
|
473 |
+
if hasattr(cfg.test, 'num_chunks'):
|
474 |
+
num_chunks = cfg.test.num_chunks
|
475 |
+
if hasattr(cfg.test, 'chunk_index'):
|
476 |
+
chunk_index = cfg.test.chunk_index
|
477 |
+
# Size of each chunk
|
478 |
+
chunk_size = len(dataset_test) // num_chunks
|
479 |
+
# Choose which chunk to load (0-indexed)
|
480 |
+
# Define a subset of the dataset
|
481 |
+
subset_indices = range(
|
482 |
+
chunk_index * chunk_size, (chunk_index + 1) * chunk_size
|
483 |
+
)
|
484 |
+
subset_dataset = Subset(dataset_test, indices=subset_indices)
|
485 |
+
|
486 |
+
data_loader_test = torch.utils.data.DataLoader(
|
487 |
+
subset_dataset, batch_size=1, shuffle=False, num_workers=1
|
488 |
+
)
|
489 |
+
|
490 |
+
car_model = CaR(cfg, device=device, seg_mode=cfg.test.seg_mode)
|
491 |
+
|
492 |
+
car_model = car_model.to(device)
|
493 |
+
|
494 |
+
if not cfg.test.use_pseudo and cfg.test.sam_mask_root is None:
|
495 |
+
print('Using sam online')
|
496 |
+
# sam_checkpoint, model_type = build_sam_config(cfg)
|
497 |
+
build_sam_config(cfg)
|
498 |
+
|
499 |
+
evaluate(
|
500 |
+
data_loader_test,
|
501 |
+
cfg,
|
502 |
+
car_model,
|
503 |
+
test_cfg=cfg.test,
|
504 |
+
label_space=label_space,
|
505 |
+
stuff_label_space=stuff_label_space,
|
506 |
+
)
|
507 |
+
|
508 |
+
|
509 |
+
if __name__ == '__main__':
|
510 |
+
args = parse_args()
|
511 |
+
main(args)
|
modeling/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
modeling/model/cam.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Get CAM activation."""
|
17 |
+
|
18 |
+
import cv2
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
|
22 |
+
|
23 |
+
_EPSILON = 1e-15
|
24 |
+
|
25 |
+
|
26 |
+
def scale_cam_image(cam, target_size=None):
|
27 |
+
"""Normalize and rescale cam image."""
|
28 |
+
result = []
|
29 |
+
for img in cam:
|
30 |
+
img = img - np.min(img)
|
31 |
+
img = img / (_EPSILON + np.max(img))
|
32 |
+
if target_size is not None:
|
33 |
+
img = cv2.resize(img, target_size)
|
34 |
+
result.append(img)
|
35 |
+
result = np.float32(result)
|
36 |
+
|
37 |
+
return result
|
38 |
+
|
39 |
+
|
40 |
+
class ActivationsAndGradients:
|
41 |
+
"""Class for extracting activations and registering gradients from targetted intermediate layers."""
|
42 |
+
|
43 |
+
def __init__(self, model, target_layers, reshape_transform, stride=16):
|
44 |
+
self.model = model
|
45 |
+
self.gradients = []
|
46 |
+
self.activations = []
|
47 |
+
self.reshape_transform = reshape_transform
|
48 |
+
self.handles = []
|
49 |
+
self.stride = stride
|
50 |
+
for target_layer in target_layers:
|
51 |
+
self.handles.append(
|
52 |
+
target_layer.register_forward_hook(self.save_activation)
|
53 |
+
)
|
54 |
+
# Because of https://github.com/pytorch/pytorch/issues/61519,
|
55 |
+
# we don't use backward hook to record gradients.
|
56 |
+
self.handles.append(
|
57 |
+
target_layer.register_forward_hook(self.save_gradient)
|
58 |
+
)
|
59 |
+
|
60 |
+
# pylint: disable=unused-argument
|
61 |
+
# pylint: disable=redefined-builtin
|
62 |
+
def save_activation(self, module, input, output):
|
63 |
+
"""Saves activations from targetted layer."""
|
64 |
+
activation = output
|
65 |
+
|
66 |
+
if self.reshape_transform is not None:
|
67 |
+
activation = self.reshape_transform(activation, self.height, self.width)
|
68 |
+
self.activations.append(activation.cpu().detach())
|
69 |
+
|
70 |
+
def save_gradient(self, module, input, output):
|
71 |
+
if not hasattr(output, "requires_grad") or not output.requires_grad:
|
72 |
+
# You can only register hooks on tensor requires grad.
|
73 |
+
return
|
74 |
+
|
75 |
+
# Gradients are computed in reverse order
|
76 |
+
def _store_grad(grad):
|
77 |
+
if self.reshape_transform is not None:
|
78 |
+
grad = self.reshape_transform(grad, self.height, self.width)
|
79 |
+
self.gradients = [grad.cpu().detach()] + self.gradients
|
80 |
+
|
81 |
+
output.register_hook(_store_grad)
|
82 |
+
|
83 |
+
# pylint: enable=unused-argument
|
84 |
+
# pylint: enable=redefined-builtin
|
85 |
+
|
86 |
+
def __call__(self, x, h, w):
|
87 |
+
self.height = h // self.stride
|
88 |
+
self.width = w // self.stride
|
89 |
+
self.gradients = []
|
90 |
+
self.activations = []
|
91 |
+
if isinstance(x, tuple) or isinstance(x, list):
|
92 |
+
return self.model.forward_last_layer(x[0], x[1])
|
93 |
+
else:
|
94 |
+
return self.model(x)
|
95 |
+
|
96 |
+
def release(self):
|
97 |
+
for handle in self.handles:
|
98 |
+
handle.remove()
|
99 |
+
|
100 |
+
|
101 |
+
# pylint: disable=g-bare-generic
|
102 |
+
class CAM:
|
103 |
+
"""CAM module."""
|
104 |
+
|
105 |
+
def __init__(
|
106 |
+
self,
|
107 |
+
model,
|
108 |
+
target_layers,
|
109 |
+
use_cuda=False,
|
110 |
+
reshape_transform=None,
|
111 |
+
compute_input_gradient=False,
|
112 |
+
stride=16,
|
113 |
+
):
|
114 |
+
self.model = model.eval()
|
115 |
+
self.target_layers = target_layers
|
116 |
+
self.cuda = use_cuda
|
117 |
+
self.model = model.cuda() if self.cuda else self.model
|
118 |
+
self.reshape_transform = reshape_transform
|
119 |
+
self.compute_input_gradient = compute_input_gradient
|
120 |
+
self.activations_and_grads = ActivationsAndGradients(
|
121 |
+
self.model, target_layers, reshape_transform, stride=stride
|
122 |
+
)
|
123 |
+
|
124 |
+
def get_cam(self, activations, grads):
|
125 |
+
weights = np.mean(grads, axis=(2, 3))
|
126 |
+
weighted_activations = weights[:, :, None, None] * activations
|
127 |
+
cam = weighted_activations.sum(axis=1)
|
128 |
+
return cam
|
129 |
+
|
130 |
+
def forward(
|
131 |
+
self,
|
132 |
+
input_tensor,
|
133 |
+
targets,
|
134 |
+
target_size,
|
135 |
+
):
|
136 |
+
"""CAM forward pass."""
|
137 |
+
if self.compute_input_gradient:
|
138 |
+
input_tensor = torch.autograd.Variable(input_tensor, requires_grad=True)
|
139 |
+
|
140 |
+
w, h = self.get_target_width_height(input_tensor)
|
141 |
+
outputs = self.activations_and_grads(input_tensor, h, w)
|
142 |
+
|
143 |
+
self.model.zero_grad()
|
144 |
+
if isinstance(input_tensor, (tuple, list)):
|
145 |
+
loss = sum(
|
146 |
+
[target(output[0]) for target, output in zip(targets, outputs)]
|
147 |
+
)
|
148 |
+
else:
|
149 |
+
loss = sum([target(output) for target, output in zip(targets, outputs)])
|
150 |
+
loss.backward(retain_graph=True)
|
151 |
+
cam_per_layer = self.compute_cam_per_layer(target_size)
|
152 |
+
if isinstance(input_tensor, (tuple, list)):
|
153 |
+
return (
|
154 |
+
self.aggregate_multi_layers(cam_per_layer),
|
155 |
+
outputs[0],
|
156 |
+
outputs[1],
|
157 |
+
)
|
158 |
+
else:
|
159 |
+
return self.aggregate_multi_layers(cam_per_layer), outputs
|
160 |
+
|
161 |
+
def get_target_width_height(self, input_tensor):
|
162 |
+
width = None
|
163 |
+
height = None
|
164 |
+
if isinstance(input_tensor, (tuple, list)):
|
165 |
+
width, height = input_tensor[-1], input_tensor[-2]
|
166 |
+
return width, height
|
167 |
+
|
168 |
+
def compute_cam_per_layer(self, target_size):
|
169 |
+
"""Computes cam per target layer."""
|
170 |
+
activations_list = [
|
171 |
+
a.cpu().data.numpy() for a in self.activations_and_grads.activations
|
172 |
+
]
|
173 |
+
grads_list = [
|
174 |
+
g.cpu().data.numpy() for g in self.activations_and_grads.gradients
|
175 |
+
]
|
176 |
+
|
177 |
+
cam_per_target_layer = []
|
178 |
+
# Loop over the saliency image from every layer
|
179 |
+
for i in range(len(self.target_layers)):
|
180 |
+
layer_activations = None
|
181 |
+
layer_grads = None
|
182 |
+
if i < len(activations_list):
|
183 |
+
layer_activations = activations_list[i]
|
184 |
+
if i < len(grads_list):
|
185 |
+
layer_grads = grads_list[i]
|
186 |
+
|
187 |
+
cam = self.get_cam(layer_activations, layer_grads)
|
188 |
+
cam = np.maximum(cam, 0).astype(np.float32) # float16->32
|
189 |
+
scaled = scale_cam_image(cam, target_size)
|
190 |
+
cam_per_target_layer.append(scaled[:, None, :])
|
191 |
+
|
192 |
+
return cam_per_target_layer
|
193 |
+
|
194 |
+
def aggregate_multi_layers(self, cam_per_target_layer):
|
195 |
+
cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
|
196 |
+
cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
|
197 |
+
result = np.mean(cam_per_target_layer, axis=1)
|
198 |
+
return scale_cam_image(result)
|
199 |
+
|
200 |
+
def __call__(
|
201 |
+
self,
|
202 |
+
input_tensor,
|
203 |
+
targets=None,
|
204 |
+
target_size=None,
|
205 |
+
):
|
206 |
+
return self.forward(input_tensor, targets, target_size)
|
207 |
+
|
208 |
+
def __del__(self):
|
209 |
+
self.activations_and_grads.release()
|
210 |
+
|
211 |
+
def __enter__(self):
|
212 |
+
return self
|
213 |
+
|
214 |
+
def __exit__(self, exc_type, exc_value, exc_tb):
|
215 |
+
self.activations_and_grads.release()
|
216 |
+
if isinstance(exc_value, IndexError):
|
217 |
+
# Handle IndexError here...
|
218 |
+
print(
|
219 |
+
f"An exception occurred in CAM with block: {exc_type}. "
|
220 |
+
f"Message: {exc_value}"
|
221 |
+
)
|
222 |
+
return True
|
modeling/model/car.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Implementation of CaR."""
|
17 |
+
|
18 |
+
import os
|
19 |
+
|
20 |
+
import clip
|
21 |
+
import numpy as np
|
22 |
+
import torch
|
23 |
+
from torch import nn
|
24 |
+
import torch.nn.functional as F
|
25 |
+
|
26 |
+
# pylint: disable=g-importing-member
|
27 |
+
# pylint: disable=g-bad-import-order
|
28 |
+
from modeling.model.clip_wrapper import CLIPWrapper
|
29 |
+
from modeling.model.clip_wrapper import forward_clip
|
30 |
+
from modeling.model.clipcam import CLIPCAM
|
31 |
+
from modeling.model.crf import PostProcess
|
32 |
+
from modeling.model.utils import apply_visual_prompts
|
33 |
+
from utils.visualize import viz_attn
|
34 |
+
|
35 |
+
|
36 |
+
class CaR(nn.Module):
|
37 |
+
"""CaR module."""
|
38 |
+
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
cfg,
|
42 |
+
device="cpu",
|
43 |
+
visualize=False,
|
44 |
+
confidence_threshold=0.45,
|
45 |
+
save_path="save_path",
|
46 |
+
seg_mode="refer",
|
47 |
+
semantic_clip_model_name=None,
|
48 |
+
semantic_pretrained_data=None,
|
49 |
+
semantic_templates=None,
|
50 |
+
text_template=None,
|
51 |
+
visual_prompt_type="circle",
|
52 |
+
clipes_threshold=0.4,
|
53 |
+
cam_text_template="a clean origami {}.",
|
54 |
+
bg_cls=None,
|
55 |
+
iom_thres=0.6,
|
56 |
+
min_pred_threshold=0.01,
|
57 |
+
bg_factor=1.0,
|
58 |
+
mask_threshold=0.5,
|
59 |
+
):
|
60 |
+
"""CaR model for image segmentation.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
cfg: the config file.
|
64 |
+
device: the device to run the model.
|
65 |
+
visualize: whether to visualize the intermediate results
|
66 |
+
confidence_threshold: the confidence threshold for semantic
|
67 |
+
segmentation. If the confidence score is lower than the threshold, the
|
68 |
+
mask will be discarded.
|
69 |
+
save_path: the path to save the intermediate results
|
70 |
+
seg_mode: the segmentation mode, can be 'refer' or 'semantic'
|
71 |
+
semantic_clip_model_name: the name of the semantic segmentation model.
|
72 |
+
semantic_pretrained_data: the path to the pretrained semantic
|
73 |
+
segmentation model.
|
74 |
+
semantic_templates: the templates for semantic segmentation.
|
75 |
+
text_template: the template for visual prompting.
|
76 |
+
visual_prompt_type: the type of visual prompting.
|
77 |
+
clipes_threshold: the threshold for CLIPES.
|
78 |
+
cam_text_template: the template for CAM.
|
79 |
+
bg_cls: background classes.
|
80 |
+
iom_thres: IoM threshold.
|
81 |
+
min_pred_threshold: Prediction threshold.
|
82 |
+
bg_factor: Background factor.
|
83 |
+
mask_threshold: Mask threshold.
|
84 |
+
"""
|
85 |
+
super(CaR, self).__init__()
|
86 |
+
# CLIP parameters
|
87 |
+
self.confidence_threshold = confidence_threshold
|
88 |
+
self.device = device
|
89 |
+
self.visualize = visualize
|
90 |
+
self.save_path = save_path
|
91 |
+
self.seg_mode = seg_mode
|
92 |
+
self.semantic_clip_model_name = semantic_clip_model_name
|
93 |
+
self.semantic_pretrained_data = semantic_pretrained_data
|
94 |
+
self.semantic_templates = semantic_templates
|
95 |
+
self.text_template = text_template
|
96 |
+
self.visual_prompt_type = visual_prompt_type
|
97 |
+
self.clipes_threshold = clipes_threshold
|
98 |
+
self.cam_text_template = cam_text_template
|
99 |
+
self.iom_thres = iom_thres
|
100 |
+
self.min_pred_threshold = min_pred_threshold
|
101 |
+
self.bg_cls = bg_cls
|
102 |
+
self.bg_factor = bg_factor
|
103 |
+
self.mask_threshold = mask_threshold
|
104 |
+
|
105 |
+
if not hasattr(cfg, "clip"):
|
106 |
+
raise ValueError("The config file should contain the CLIP parameters.")
|
107 |
+
|
108 |
+
if not hasattr(cfg, "car"):
|
109 |
+
raise ValueError("The config file should contain the car parameters.")
|
110 |
+
|
111 |
+
if hasattr(cfg, "cam"):
|
112 |
+
raise ValueError("cfg.cam is deprecated, please use cfg.car ")
|
113 |
+
|
114 |
+
for k, v in vars(cfg.clip).items():
|
115 |
+
setattr(self, k, v)
|
116 |
+
|
117 |
+
for k, v in vars(cfg.car).items():
|
118 |
+
setattr(self, k, v)
|
119 |
+
|
120 |
+
if hasattr(cfg, "sam"):
|
121 |
+
for k, v in vars(cfg.sam).items():
|
122 |
+
setattr(self, k, v)
|
123 |
+
if not self.bg_cls:
|
124 |
+
self.bg_cls = None
|
125 |
+
print(f"The model is running on {self.device}")
|
126 |
+
self.clip_model, self.preprocess = clip.load(
|
127 |
+
self.clip_model_name, device=self.device
|
128 |
+
)
|
129 |
+
self.clip_model = CLIPWrapper(self.clip_model)
|
130 |
+
self.post_process = PostProcess(device=self.device)
|
131 |
+
self.mask_generator = CLIPCAM(
|
132 |
+
self.clip_model,
|
133 |
+
device=self.device,
|
134 |
+
text_template=self.text_template,
|
135 |
+
threshold=self.clipes_threshold,
|
136 |
+
bg_cls=self.bg_cls,
|
137 |
+
)
|
138 |
+
self.semantic_clip_model, self.semantic_preprocess = clip.load(
|
139 |
+
self.semantic_clip_model_name, device=self.device
|
140 |
+
)
|
141 |
+
self.semantic_clip_model = CLIPWrapper(self.semantic_clip_model)
|
142 |
+
|
143 |
+
def get_confidence(self, cam_map, binary_cam_map):
|
144 |
+
confidence_map = torch.sum(cam_map * binary_cam_map[None], dim=[2, 3])
|
145 |
+
confidence_map = confidence_map / torch.sum(binary_cam_map, dim=[1, 2])
|
146 |
+
confidence_score = confidence_map.squeeze()
|
147 |
+
return confidence_score
|
148 |
+
|
149 |
+
def set_visual_prompt_type(self, visual_prompt_type):
|
150 |
+
self.visual_prompt_type = visual_prompt_type
|
151 |
+
|
152 |
+
def set_bg_factor(self, bg_factor):
|
153 |
+
self.bg_factor = bg_factor
|
154 |
+
|
155 |
+
def set_confidence_threshold(self, confidence_threshold):
|
156 |
+
self.confidence_threshold = confidence_threshold
|
157 |
+
|
158 |
+
def set_mask_threshold(self, mask_threshold):
|
159 |
+
self.mask_threshold = mask_threshold
|
160 |
+
|
161 |
+
def apply_visual_prompts(self, image, mask):
|
162 |
+
if torch.sum(mask).item() <= 1:
|
163 |
+
return image
|
164 |
+
image_array = np.array(image)
|
165 |
+
img_h = image_array.shape[0]
|
166 |
+
img_w = image_array.shape[1]
|
167 |
+
mask = (
|
168 |
+
F.interpolate(mask[None][None], size=(img_h, img_w), mode="nearest")
|
169 |
+
.squeeze()
|
170 |
+
.detach()
|
171 |
+
.cpu()
|
172 |
+
.numpy()
|
173 |
+
)
|
174 |
+
mask = (mask > self.mask_threshold).astype(np.uint8)
|
175 |
+
prompted_image = apply_visual_prompts(
|
176 |
+
image_array, mask, self.visual_prompt_type, self.visualize
|
177 |
+
)
|
178 |
+
return prompted_image
|
179 |
+
|
180 |
+
def get_mask_confidence(self, prompted_images, prompt_text):
|
181 |
+
"""Get the confidene for each mask with visual prompting."""
|
182 |
+
# get the center, width and height of the mask
|
183 |
+
prompted_tensor = torch.stack(
|
184 |
+
[self.semantic_preprocess(img) for img in prompted_images], dim=0
|
185 |
+
)
|
186 |
+
prompted_tensor = prompted_tensor.to(self.device)
|
187 |
+
h, w = prompted_tensor.shape[-2:]
|
188 |
+
text_prediction = forward_clip(
|
189 |
+
self.semantic_clip_model, prompted_tensor, prompt_text, h, w
|
190 |
+
)
|
191 |
+
return text_prediction
|
192 |
+
|
193 |
+
def _filter_texts(self, ori_mask_id, sem_scores, prompt_text):
|
194 |
+
"""Remove false positive masks by score filtering and recall the backbone to get the CAM maps for the filtered texts."""
|
195 |
+
if not ori_mask_id:
|
196 |
+
max_id = np.argmax(sem_scores)
|
197 |
+
ori_mask_id.append(max_id)
|
198 |
+
filtered_text = [prompt_text[i] for i in ori_mask_id]
|
199 |
+
return filtered_text
|
200 |
+
|
201 |
+
def _forward_stage(self, ori_img, cam_text, clip_text, semantic_prompt_text):
|
202 |
+
mask_proposals = self.get_mask_proposals(ori_img, cam_text)
|
203 |
+
num_texts = len(cam_text)
|
204 |
+
ori_mask_id = []
|
205 |
+
sem_scores = torch.zeros((num_texts,), device=self.device).float()
|
206 |
+
prompted_imgs = [
|
207 |
+
self.apply_visual_prompts(ori_img, cam_map)
|
208 |
+
for cam_map in mask_proposals
|
209 |
+
]
|
210 |
+
text_scores = self.get_mask_confidence(prompted_imgs, semantic_prompt_text)
|
211 |
+
mask_scores = torch.diagonal(text_scores)
|
212 |
+
for mask_idx, mask_score in enumerate(mask_scores):
|
213 |
+
# record mask idx
|
214 |
+
if mask_score > self.confidence_threshold:
|
215 |
+
ori_mask_id.append(mask_idx)
|
216 |
+
sem_scores[mask_idx] = mask_score
|
217 |
+
sem_scores = sem_scores.cpu().detach().numpy()
|
218 |
+
filtered_texts = self._filter_texts(ori_mask_id, sem_scores, clip_text)
|
219 |
+
# if isinstance(ori_img, list):
|
220 |
+
# ori_img = [ori_img[i] for i in ori_mask_id]
|
221 |
+
|
222 |
+
all_scores = torch.zeros((num_texts,), device=self.device).float()
|
223 |
+
sem_scores = torch.from_numpy(sem_scores).to(self.device)
|
224 |
+
for new_id, ori_id in enumerate(ori_mask_id):
|
225 |
+
if new_id >= len(mask_proposals):
|
226 |
+
# the mask is filtered out.
|
227 |
+
continue
|
228 |
+
all_scores[ori_id] = sem_scores[ori_id]
|
229 |
+
return filtered_texts, all_scores, mask_proposals
|
230 |
+
|
231 |
+
def _get_save_path(self, text):
|
232 |
+
folder_name = "_".join([t.replace(" ", "_") for t in text])
|
233 |
+
if len(folder_name) > 20:
|
234 |
+
folder_name = folder_name[:20]
|
235 |
+
output_path = os.path.join(self.save_path, folder_name)
|
236 |
+
sub_output_path = [
|
237 |
+
os.path.join(output_path, t.replace(" ", "_")) for t in text
|
238 |
+
]
|
239 |
+
return output_path, sub_output_path
|
240 |
+
|
241 |
+
def get_mask_proposals(self, img, text):
|
242 |
+
if self.seg_mode == "refer":
|
243 |
+
if isinstance(img, list):
|
244 |
+
cam_map_list = [self.mask_generator(i, t)[0] for i, t in zip(img, text)]
|
245 |
+
else:
|
246 |
+
cam_map_list = [self.mask_generator(img, t)[0] for t in text]
|
247 |
+
return torch.cat(cam_map_list, dim=0)
|
248 |
+
elif self.seg_mode == "semantic":
|
249 |
+
return self.mask_generator(img, text)[0]
|
250 |
+
else:
|
251 |
+
raise ValueError(
|
252 |
+
"Unknown segmentation mode. Only refer and semantic segmentation are"
|
253 |
+
" supported."
|
254 |
+
)
|
255 |
+
|
256 |
+
def _forward_car(self, ori_img, text):
|
257 |
+
if isinstance(text, str):
|
258 |
+
text = [text]
|
259 |
+
_, sub_output_path = self._get_save_path(text)
|
260 |
+
image_array = np.array(ori_img)
|
261 |
+
clip_text = [self.cam_text_template.format(t) for t in text]
|
262 |
+
cam_text = text
|
263 |
+
init_clip_text = clip_text # the text prompts of CLIP is different.
|
264 |
+
semantic_prompt_text = clip_text
|
265 |
+
# Apply semantic prompting augmentation.
|
266 |
+
if self.semantic_templates is not None:
|
267 |
+
semantic_prompt_text = []
|
268 |
+
for template in self.semantic_templates:
|
269 |
+
templated_text = [template.format(t) for t in text]
|
270 |
+
semantic_prompt_text.append(templated_text)
|
271 |
+
|
272 |
+
num_positive_last = 0
|
273 |
+
run = 0
|
274 |
+
while True:
|
275 |
+
run += 1
|
276 |
+
cur_texts, all_scores, mask_proposals = self._forward_stage(
|
277 |
+
ori_img, cam_text, clip_text, semantic_prompt_text
|
278 |
+
)
|
279 |
+
if cur_texts: # if there is no text, skip the refinement
|
280 |
+
cam_text = cur_texts
|
281 |
+
clip_text = cur_texts
|
282 |
+
|
283 |
+
num_positive = (all_scores > 0).sum().item()
|
284 |
+
if num_positive == num_positive_last:
|
285 |
+
# stop the refinement if the number of positive masks
|
286 |
+
# does not change.
|
287 |
+
break
|
288 |
+
num_positive_last = num_positive
|
289 |
+
# Apply densecrf for refinement.
|
290 |
+
# SAM is optional and is applied outside the model.
|
291 |
+
refined_masks = self.post_process(
|
292 |
+
ori_img,
|
293 |
+
mask_proposals,
|
294 |
+
separate=self.seg_mode == "refer",
|
295 |
+
bg_factor=self.bg_factor,
|
296 |
+
)
|
297 |
+
predicted_class_idx = [init_clip_text.index(t) for t in cur_texts]
|
298 |
+
if self.visualize:
|
299 |
+
_ = [
|
300 |
+
viz_attn(
|
301 |
+
image_array,
|
302 |
+
attn,
|
303 |
+
prefix=sub_output_path[aid],
|
304 |
+
img_name="semantic_mask",
|
305 |
+
)
|
306 |
+
for aid, attn in enumerate(refined_masks)
|
307 |
+
]
|
308 |
+
final_predicted_masks = torch.zeros(len(text), *refined_masks[0].shape)
|
309 |
+
final_all_scores = torch.zeros(len(text))
|
310 |
+
for idx, mask, score in zip(predicted_class_idx, refined_masks, all_scores):
|
311 |
+
final_predicted_masks[idx] = mask
|
312 |
+
final_all_scores[idx] = score
|
313 |
+
return final_predicted_masks, final_all_scores
|
314 |
+
|
315 |
+
def forward(self, im_ori, text):
|
316 |
+
# raw_image_np is the padded image input with shape (512, 512, 3)
|
317 |
+
pseudo_masks, conf_scores = self._forward_car(im_ori, text)
|
318 |
+
return pseudo_masks, conf_scores
|
modeling/model/clip_wrapper.py
ADDED
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""A wrapper for CLIP model to support forward with a list of text inputs."""
|
17 |
+
|
18 |
+
# pylint: disable=g-importing-member
|
19 |
+
import clip
|
20 |
+
import numpy as np
|
21 |
+
import torch
|
22 |
+
from torch import nn
|
23 |
+
import torch.nn.functional as F
|
24 |
+
|
25 |
+
_CONTEXT_LENGTH = 77
|
26 |
+
|
27 |
+
|
28 |
+
def forward_clip_single(model, image, text, h, w):
|
29 |
+
"""Forward a single text input.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
model (CLIPWrapper or CLIP): the CLIP model.
|
33 |
+
image (torch.Tensor): the image tensor.
|
34 |
+
text (List[str]): the text input.
|
35 |
+
h (int): the height of the image.
|
36 |
+
w (int): the width of the image.
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
torch.Tensor: the logits.
|
40 |
+
"""
|
41 |
+
if isinstance(text, str):
|
42 |
+
text = [text]
|
43 |
+
text_tokens = clip.tokenize(text).to(image.device)
|
44 |
+
text_prediction = model(image, text_tokens, h, w)
|
45 |
+
return text_prediction.detach().cpu()
|
46 |
+
|
47 |
+
|
48 |
+
def forward_clip(model, image, text, h, w):
|
49 |
+
"""Forward a list of text inputs.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
model (CLIPWrapper or CLIP): the CLIP model.
|
53 |
+
image (torch.Tensor): the image tensor.
|
54 |
+
text (List[str] or List[List[str]]): the text input.
|
55 |
+
h (int): the height of the image.
|
56 |
+
w (int): the width of the image.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
torch.Tensor: the logits.
|
60 |
+
"""
|
61 |
+
if isinstance(text[0], list):
|
62 |
+
text_prediction = torch.stack(
|
63 |
+
[forward_clip_single(model, image, t, h, w) for t in text], dim=0
|
64 |
+
)
|
65 |
+
text_prediction = torch.sum(text_prediction, dim=0)
|
66 |
+
text_prediction = F.softmax(text_prediction.float(), dim=-1)
|
67 |
+
else:
|
68 |
+
text_prediction = forward_clip_single(model, image, text, h, w)
|
69 |
+
return text_prediction.float()
|
70 |
+
|
71 |
+
|
72 |
+
def upsample_position_embedding(embed, new_size):
|
73 |
+
"""Upsample the pretrained embedding to a higher resolution.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
embed (torch.Tensor): the pretrained embedding.
|
77 |
+
new_size (Tuple[int, int]): the new size of the embedding.
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
torch.Tensor: the upsampled embedding.
|
81 |
+
"""
|
82 |
+
# emb size NxD
|
83 |
+
first = embed[:1, :]
|
84 |
+
embed = embed[1:, :]
|
85 |
+
n = embed.size(0)
|
86 |
+
d = embed.size(1)
|
87 |
+
size = int(np.sqrt(n))
|
88 |
+
if size * size != n:
|
89 |
+
raise ValueError(f'The size of embed {n} is not a perfect square number.')
|
90 |
+
# new_size = size * self.upsample
|
91 |
+
embed = embed.permute(1, 0)
|
92 |
+
embed = embed.view(1, d, size, size).contiguous()
|
93 |
+
embed = F.upsample(
|
94 |
+
embed,
|
95 |
+
size=new_size,
|
96 |
+
mode='bilinear',
|
97 |
+
)
|
98 |
+
embed = embed.view(d, -1).contiguous()
|
99 |
+
embed = embed.permute(1, 0)
|
100 |
+
embed = torch.cat([first, embed], 0)
|
101 |
+
embed = nn.parameter.Parameter(embed.half())
|
102 |
+
return embed
|
103 |
+
|
104 |
+
|
105 |
+
class CustomBlock(nn.Module):
|
106 |
+
"""A customized attention block."""
|
107 |
+
|
108 |
+
def __init__(self, block):
|
109 |
+
super().__init__()
|
110 |
+
for k, v in vars(block).items():
|
111 |
+
setattr(self, k, v)
|
112 |
+
|
113 |
+
def attention(self, x):
|
114 |
+
self.attn_mask = (
|
115 |
+
self.attn_mask.to(dtype=x.dtype, device=x.device)
|
116 |
+
if self.attn_mask is not None
|
117 |
+
else None
|
118 |
+
)
|
119 |
+
self.attn = self.attn.to(dtype=x.dtype, device=x.device)
|
120 |
+
# Setting need_weights to True also returns the attention weights
|
121 |
+
return self.attn(x, x, x, need_weights=True, attn_mask=self.attn_mask)
|
122 |
+
|
123 |
+
def forward(self, x):
|
124 |
+
# attn_output: (L,N,E), attn_weight: (N,L,L)
|
125 |
+
attn_output, attn_weight = self.attention(self.ln_1(x))
|
126 |
+
x = x + attn_output
|
127 |
+
x = x + self.mlp(self.ln_2(x))
|
128 |
+
return x, attn_weight
|
129 |
+
|
130 |
+
|
131 |
+
class CustomTransformer(nn.Module):
|
132 |
+
"""A customized Transformer to support CAM calculation."""
|
133 |
+
|
134 |
+
def __init__(self, transformer):
|
135 |
+
"""Initialize the wrapper.
|
136 |
+
|
137 |
+
Args:
|
138 |
+
transformer (nn.Module): the Transformer to be wrapped.
|
139 |
+
"""
|
140 |
+
super().__init__()
|
141 |
+
for k, v in vars(transformer).items():
|
142 |
+
setattr(self, k, v)
|
143 |
+
|
144 |
+
self.resblocks = nn.Sequential(
|
145 |
+
*[CustomBlock(block) for block in self.resblocks]
|
146 |
+
)
|
147 |
+
|
148 |
+
def forward(self, x):
|
149 |
+
attn_weights = []
|
150 |
+
with torch.no_grad():
|
151 |
+
layers = self.layers if x.shape[0] == _CONTEXT_LENGTH else self.layers - 1
|
152 |
+
for i in range(layers):
|
153 |
+
x, attn_weight = self.resblocks[i](x)
|
154 |
+
attn_weights.append(attn_weight)
|
155 |
+
return x, attn_weights
|
156 |
+
|
157 |
+
|
158 |
+
class CustomVisionTransformer(nn.Module):
|
159 |
+
"""A customized VisionTransformer to support CAM calculation."""
|
160 |
+
|
161 |
+
def __init__(self, model):
|
162 |
+
"""Initialize the wrapper.
|
163 |
+
|
164 |
+
Args:
|
165 |
+
model (VisionTransformer): the VisionTransformer to be wrapped.
|
166 |
+
"""
|
167 |
+
super().__init__()
|
168 |
+
for k, v in vars(model).items():
|
169 |
+
setattr(self, k, v)
|
170 |
+
self.patch_size = self.conv1.kernel_size[0]
|
171 |
+
self.transformer = CustomTransformer(self.transformer)
|
172 |
+
|
173 |
+
def forward(self, x, h, w):
|
174 |
+
self.positional_embedding_new = upsample_position_embedding(
|
175 |
+
self.positional_embedding, (h // self.patch_size, w // self.patch_size)
|
176 |
+
)
|
177 |
+
# shape = [*, width, grid, grid]
|
178 |
+
x = self.conv1(x)
|
179 |
+
# shape = [*, width, grid ** 2]
|
180 |
+
x = x.reshape(x.shape[0], x.shape[1], -1)
|
181 |
+
# shape = [*, grid ** 2, width]
|
182 |
+
x = x.permute(0, 2, 1)
|
183 |
+
zeros = torch.zeros(
|
184 |
+
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
|
185 |
+
)
|
186 |
+
# shape = [*, grid ** 2 + 1, width]
|
187 |
+
x = torch.cat([self.class_embedding.to(x.dtype) + zeros, x], dim=1)
|
188 |
+
x = x + self.positional_embedding_new.to(x.dtype)
|
189 |
+
x = self.ln_pre(x)
|
190 |
+
# NLD -> LND
|
191 |
+
x = x.permute(1, 0, 2)
|
192 |
+
x, attn_weight = self.transformer(x)
|
193 |
+
return x, attn_weight
|
194 |
+
|
195 |
+
|
196 |
+
class CLIPWrapper(nn.Module):
|
197 |
+
"""A wrapper for CLIP to support forward with a list of text inputs."""
|
198 |
+
|
199 |
+
def __init__(self, clip_model):
|
200 |
+
"""Initialize the wrapper.
|
201 |
+
|
202 |
+
Args:
|
203 |
+
clip_model (CLIP): the CLIP model to be wrapped.
|
204 |
+
"""
|
205 |
+
super().__init__()
|
206 |
+
# copy all attributes from clip_model to self
|
207 |
+
for k, v in vars(clip_model).items():
|
208 |
+
setattr(self, k, v)
|
209 |
+
self.visual = CustomVisionTransformer(self.visual)
|
210 |
+
self.transformer = CustomTransformer(self.transformer)
|
211 |
+
|
212 |
+
@property
|
213 |
+
def dtype(self):
|
214 |
+
return self.visual.conv1.weight.dtype
|
215 |
+
|
216 |
+
def encode_image(self, image, h, w):
|
217 |
+
return self.visual(image.type(self.dtype), h, w)
|
218 |
+
|
219 |
+
def encode_text(self, text):
|
220 |
+
x = self.token_embedding(text).type(
|
221 |
+
self.dtype
|
222 |
+
) # [batch_size, n_ctx, d_model]
|
223 |
+
|
224 |
+
x = x + self.positional_embedding.type(self.dtype)
|
225 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
226 |
+
x, _ = self.transformer(x)
|
227 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
228 |
+
x = self.ln_final(x).type(self.dtype)
|
229 |
+
|
230 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
231 |
+
# take features from the eot embedding
|
232 |
+
# (eot_token is the highest number in each sequence)
|
233 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
234 |
+
|
235 |
+
return x
|
236 |
+
|
237 |
+
def pool_visual(self, x, use_cls_token=False):
|
238 |
+
if use_cls_token:
|
239 |
+
return x[:, 0]
|
240 |
+
else:
|
241 |
+
return torch.mean(x[:, 1:, :], dim=1)
|
242 |
+
|
243 |
+
def forward_last_layer(
|
244 |
+
self, image_features, text_features, use_cls_token=False, repeat_last=True
|
245 |
+
):
|
246 |
+
"""Forward the last layer of CLIP.
|
247 |
+
|
248 |
+
Args:
|
249 |
+
image_features (torch.Tensor): the image features.
|
250 |
+
text_features (torch.Tensor): the text features.
|
251 |
+
use_cls_token (bool, optional): whether to use the CLS token. Defaults
|
252 |
+
to False.
|
253 |
+
repeat_last (bool, optional): whether to repeat the last layer. Defaults
|
254 |
+
to True.
|
255 |
+
|
256 |
+
Returns:
|
257 |
+
torch.Tensor: the logits.
|
258 |
+
torch.Tensor: the attention weights.
|
259 |
+
"""
|
260 |
+
if repeat_last:
|
261 |
+
x, attention_weight = self.visual.transformer.resblocks[
|
262 |
+
self.visual.transformer.layers - 1
|
263 |
+
](image_features)
|
264 |
+
else:
|
265 |
+
x = image_features
|
266 |
+
attention_weight = None
|
267 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
268 |
+
|
269 |
+
x = self.visual.ln_post(x)
|
270 |
+
x = self.pool_visual(x, use_cls_token=use_cls_token)
|
271 |
+
|
272 |
+
if self.visual.proj is not None:
|
273 |
+
x = x @ self.visual.proj
|
274 |
+
|
275 |
+
image_features = x
|
276 |
+
|
277 |
+
# normalized features
|
278 |
+
image_features = image_features / image_features.norm(dim=1, keepdim=True)
|
279 |
+
text_features = text_features / text_features.norm(dim=1, keepdim=True)
|
280 |
+
# cosine similarity as logits
|
281 |
+
logit_scale = self.logit_scale.exp()
|
282 |
+
logits_per_image = logit_scale * image_features @ text_features.t()
|
283 |
+
|
284 |
+
# shape = [global_batch_size, global_batch_size]
|
285 |
+
logits_per_image = F.softmax(logits_per_image.float(), dim=-1)
|
286 |
+
|
287 |
+
return logits_per_image, attention_weight
|
288 |
+
|
289 |
+
def forward(self, image, text, h=224, w=224):
|
290 |
+
with torch.no_grad():
|
291 |
+
text_features = self.encode_text(text)
|
292 |
+
feature_map, _ = self.visual(image.type(self.dtype), h, w)
|
293 |
+
|
294 |
+
logits_per_image, _ = self.forward_last_layer(
|
295 |
+
feature_map, text_features, use_cls_token=True, repeat_last=False
|
296 |
+
)
|
297 |
+
return logits_per_image
|
modeling/model/clipcam.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Calculate CAM with CLIP model."""
|
17 |
+
|
18 |
+
import warnings
|
19 |
+
|
20 |
+
import clip
|
21 |
+
import cv2
|
22 |
+
import numpy as np
|
23 |
+
import torch
|
24 |
+
|
25 |
+
# pylint: disable=g-importing-member
|
26 |
+
# pylint: disable=g-bad-import-order
|
27 |
+
from modeling.model.cam import CAM
|
28 |
+
from modeling.model.cam import scale_cam_image
|
29 |
+
from modeling.model.utils import img_ms_and_flip
|
30 |
+
from modeling.model.utils import reshape_transform
|
31 |
+
from modeling.model.utils import scoremap2bbox
|
32 |
+
|
33 |
+
warnings.filterwarnings("ignore")
|
34 |
+
|
35 |
+
|
36 |
+
class ClipOutputTarget:
|
37 |
+
|
38 |
+
def __init__(self, category):
|
39 |
+
self.category = category
|
40 |
+
|
41 |
+
def __call__(self, model_output):
|
42 |
+
if len(model_output.shape) == 1:
|
43 |
+
return model_output[self.category]
|
44 |
+
return model_output[:, self.category]
|
45 |
+
|
46 |
+
|
47 |
+
def zeroshot_classifier(classnames, templates, model, device):
|
48 |
+
"""Zeroshot classifier."""
|
49 |
+
with torch.no_grad():
|
50 |
+
zeroshot_weights = []
|
51 |
+
for classname in classnames:
|
52 |
+
if templates is None:
|
53 |
+
texts = [classname]
|
54 |
+
else:
|
55 |
+
# format with class
|
56 |
+
texts = [template.format(classname) for template in templates]
|
57 |
+
texts = clip.tokenize(texts).to(device) # tokenize
|
58 |
+
class_embeddings = model.encode_text(texts) # embed with text encoder
|
59 |
+
class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
|
60 |
+
class_embedding = class_embeddings.mean(dim=0)
|
61 |
+
class_embedding /= class_embedding.norm()
|
62 |
+
zeroshot_weights.append(class_embedding)
|
63 |
+
zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
|
64 |
+
return zeroshot_weights.t()
|
65 |
+
|
66 |
+
|
67 |
+
class CLIPCAM:
|
68 |
+
"""Generate CAM with CLIP model."""
|
69 |
+
|
70 |
+
def __init__(
|
71 |
+
self,
|
72 |
+
clip_model,
|
73 |
+
device,
|
74 |
+
text_template=None,
|
75 |
+
threshold=0.4,
|
76 |
+
bg_cls=None,
|
77 |
+
):
|
78 |
+
self.device = device
|
79 |
+
self.clip_model = clip_model.to(device)
|
80 |
+
self.text_template = text_template
|
81 |
+
self.threshold = threshold
|
82 |
+
self.stride = self.clip_model.visual.patch_size
|
83 |
+
|
84 |
+
# if self.dataset_name == 'voc' else BACKGROUND_CATEGORY_COCO
|
85 |
+
self.bg_cls = bg_cls
|
86 |
+
self.bg_text_features = None
|
87 |
+
if self.bg_cls is not None:
|
88 |
+
self.bg_text_features = zeroshot_classifier(
|
89 |
+
self.bg_cls,
|
90 |
+
("a clean origami {}.",),
|
91 |
+
self.clip_model,
|
92 |
+
self.device,
|
93 |
+
).to(self.device)
|
94 |
+
self.target_layers = [self.clip_model.visual.transformer.resblocks[-1].ln_1]
|
95 |
+
self.cam = CAM(
|
96 |
+
model=self.clip_model,
|
97 |
+
target_layers=self.target_layers,
|
98 |
+
reshape_transform=reshape_transform,
|
99 |
+
use_cuda="cuda" in device,
|
100 |
+
stride=self.stride,
|
101 |
+
)
|
102 |
+
|
103 |
+
def set_bg_cls(self, bg_cls):
|
104 |
+
# if len(bg_cls) == 0:
|
105 |
+
if not bg_cls:
|
106 |
+
self.bg_cls = None
|
107 |
+
self.bg_text_features = None
|
108 |
+
else:
|
109 |
+
self.bg_cls = bg_cls
|
110 |
+
self.bg_text_features = zeroshot_classifier(
|
111 |
+
self.bg_cls,
|
112 |
+
("a clean origami {}.",),
|
113 |
+
self.clip_model,
|
114 |
+
self.device,
|
115 |
+
).to(self.device)
|
116 |
+
|
117 |
+
def __call__(self, ori_img, text, scale=1.0):
|
118 |
+
"""Get CAM masks and features.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
ori_img(Image): image to be searched.
|
122 |
+
text (str): text to be searched.
|
123 |
+
scale (float): image scale.
|
124 |
+
Returns:
|
125 |
+
CAM masks and features.
|
126 |
+
"""
|
127 |
+
ori_width = ori_img.size[0]
|
128 |
+
ori_height = ori_img.size[1]
|
129 |
+
if isinstance(text, str):
|
130 |
+
text = [text]
|
131 |
+
|
132 |
+
# convert image to bgr channel
|
133 |
+
ms_imgs = img_ms_and_flip(ori_img, ori_height, ori_width, scales=[scale])
|
134 |
+
image = ms_imgs[0]
|
135 |
+
|
136 |
+
image = image.unsqueeze(0)
|
137 |
+
h, w = image.shape[-2], image.shape[-1]
|
138 |
+
image = image.to(self.device)
|
139 |
+
image_features, attn_weight_list = self.clip_model.encode_image(image, h, w)
|
140 |
+
|
141 |
+
highres_cam_to_save = []
|
142 |
+
refined_cam_to_save = []
|
143 |
+
# keys = []
|
144 |
+
|
145 |
+
# [bg_id_for_each_image[im_idx]].to(device_id)
|
146 |
+
bg_features_temp = None
|
147 |
+
if self.bg_text_features is not None:
|
148 |
+
bg_features_temp = self.bg_text_features.to(self.device)
|
149 |
+
fg_features_temp = zeroshot_classifier(
|
150 |
+
text, self.text_template, self.clip_model, self.device
|
151 |
+
).to(self.device)
|
152 |
+
if bg_features_temp is None:
|
153 |
+
text_features_temp = fg_features_temp
|
154 |
+
else:
|
155 |
+
text_features_temp = torch.cat(
|
156 |
+
[fg_features_temp, bg_features_temp], dim=0
|
157 |
+
)
|
158 |
+
input_tensor = [
|
159 |
+
image_features,
|
160 |
+
text_features_temp.to(self.device),
|
161 |
+
h,
|
162 |
+
w,
|
163 |
+
]
|
164 |
+
|
165 |
+
# for idx, label in enumerate(label_list):
|
166 |
+
# keys.append(new_class_names.index(label))
|
167 |
+
for idx, _ in enumerate(text):
|
168 |
+
targets = [ClipOutputTarget(idx)]
|
169 |
+
|
170 |
+
# torch.cuda.empty_cache()
|
171 |
+
grayscale_cam, _, attn_weight_last = self.cam(
|
172 |
+
input_tensor=input_tensor, targets=targets, target_size=None
|
173 |
+
) # (ori_width, ori_height))
|
174 |
+
|
175 |
+
grayscale_cam = grayscale_cam[0, :]
|
176 |
+
if grayscale_cam.max() == 0:
|
177 |
+
input_tensor_fg = (
|
178 |
+
image_features,
|
179 |
+
fg_features_temp.to(self.device),
|
180 |
+
h,
|
181 |
+
w,
|
182 |
+
)
|
183 |
+
grayscale_cam, _, attn_weight_last = self.cam(
|
184 |
+
input_tensor=input_tensor_fg,
|
185 |
+
targets=targets,
|
186 |
+
target_size=None,
|
187 |
+
)
|
188 |
+
grayscale_cam = grayscale_cam[0, :]
|
189 |
+
|
190 |
+
grayscale_cam_highres = cv2.resize(grayscale_cam, (ori_width, ori_height))
|
191 |
+
highres_cam_to_save.append(torch.tensor(grayscale_cam_highres))
|
192 |
+
|
193 |
+
if idx == 0:
|
194 |
+
attn_weight_list.append(attn_weight_last)
|
195 |
+
attn_weight = [
|
196 |
+
aw[:, 1:, 1:] for aw in attn_weight_list
|
197 |
+
] # (b, hxw, hxw)
|
198 |
+
attn_weight = torch.stack(attn_weight, dim=0)[-8:]
|
199 |
+
attn_weight = torch.mean(attn_weight, dim=0)
|
200 |
+
attn_weight = attn_weight[0].cpu().detach()
|
201 |
+
attn_weight = attn_weight.float()
|
202 |
+
|
203 |
+
box, cnt = scoremap2bbox(
|
204 |
+
scoremap=grayscale_cam,
|
205 |
+
threshold=self.threshold,
|
206 |
+
multi_contour_eval=True,
|
207 |
+
)
|
208 |
+
aff_mask = torch.zeros((grayscale_cam.shape[0], grayscale_cam.shape[1]))
|
209 |
+
for i_ in range(cnt):
|
210 |
+
x0_, y0_, x1_, y1_ = box[i_]
|
211 |
+
aff_mask[y0_:y1_, x0_:x1_] = 1
|
212 |
+
|
213 |
+
aff_mask = aff_mask.view(
|
214 |
+
1, grayscale_cam.shape[0] * grayscale_cam.shape[1]
|
215 |
+
)
|
216 |
+
aff_mat = attn_weight
|
217 |
+
|
218 |
+
trans_mat = aff_mat / torch.sum(aff_mat, dim=0, keepdim=True)
|
219 |
+
trans_mat = trans_mat / torch.sum(trans_mat, dim=1, keepdim=True)
|
220 |
+
|
221 |
+
for _ in range(2):
|
222 |
+
trans_mat = trans_mat / torch.sum(trans_mat, dim=0, keepdim=True)
|
223 |
+
trans_mat = trans_mat / torch.sum(trans_mat, dim=1, keepdim=True)
|
224 |
+
trans_mat = (trans_mat + trans_mat.transpose(1, 0)) / 2
|
225 |
+
|
226 |
+
# This is copied from CLIP-ES
|
227 |
+
for _ in range(1):
|
228 |
+
trans_mat = torch.matmul(trans_mat, trans_mat)
|
229 |
+
|
230 |
+
trans_mat = trans_mat * aff_mask
|
231 |
+
|
232 |
+
cam_to_refine = torch.FloatTensor(grayscale_cam)
|
233 |
+
cam_to_refine = cam_to_refine.view(-1, 1)
|
234 |
+
|
235 |
+
# (n,n) * (n,1)->(n,1)
|
236 |
+
cam_refined = torch.matmul(trans_mat, cam_to_refine).reshape(
|
237 |
+
h // self.stride, w // self.stride
|
238 |
+
)
|
239 |
+
cam_refined = cam_refined.cpu().numpy().astype(np.float32)
|
240 |
+
cam_refined_highres = scale_cam_image(
|
241 |
+
[cam_refined], (ori_width, ori_height)
|
242 |
+
)[0]
|
243 |
+
refined_cam_to_save.append(torch.tensor(cam_refined_highres))
|
244 |
+
|
245 |
+
# post process the cam map
|
246 |
+
# label = process(raw_image, refined_cam, postprocessor)
|
247 |
+
# vis_img = vis_mask(np.asarray(raw_image), label, [0, 255, 0])
|
248 |
+
# vis_img.save(f'clip_es_crf_{idx}.jpg')
|
249 |
+
|
250 |
+
# keys = torch.tensor(keys)
|
251 |
+
# cam_all_scales.append(torch.stack(cam_to_save,dim=0))
|
252 |
+
|
253 |
+
cam_masks = torch.stack(refined_cam_to_save, dim=0)
|
254 |
+
|
255 |
+
return cam_masks.to(self.device), fg_features_temp.to(self.device)
|
modeling/model/crf.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""DenseCRF."""
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
from pydensecrf import densecrf as dcrf
|
20 |
+
from pydensecrf import utils
|
21 |
+
import torch
|
22 |
+
import torch.nn.functional as F
|
23 |
+
|
24 |
+
|
25 |
+
class DenseCRF(object):
|
26 |
+
"""DenseCRF class."""
|
27 |
+
|
28 |
+
def __init__(self, iter_max, pos_w, pos_xy_std, bi_w, bi_xy_std, bi_rgb_std):
|
29 |
+
self.iter_max = iter_max
|
30 |
+
self.pos_w = pos_w
|
31 |
+
self.pos_xy_std = pos_xy_std
|
32 |
+
self.bi_w = bi_w
|
33 |
+
self.bi_xy_std = bi_xy_std
|
34 |
+
self.bi_rgb_std = bi_rgb_std
|
35 |
+
|
36 |
+
def __call__(self, image, probmap):
|
37 |
+
c, h, w = probmap.shape
|
38 |
+
|
39 |
+
u = utils.unary_from_softmax(probmap)
|
40 |
+
u = np.ascontiguousarray(u)
|
41 |
+
|
42 |
+
image = np.ascontiguousarray(image)
|
43 |
+
|
44 |
+
d = dcrf.DenseCRF2D(w, h, c)
|
45 |
+
d.setUnaryEnergy(u)
|
46 |
+
d.addPairwiseGaussian(sxy=self.pos_xy_std, compat=self.pos_w)
|
47 |
+
d.addPairwiseBilateral(
|
48 |
+
sxy=self.bi_xy_std,
|
49 |
+
srgb=self.bi_rgb_std,
|
50 |
+
rgbim=image,
|
51 |
+
compat=self.bi_w,
|
52 |
+
)
|
53 |
+
|
54 |
+
q = d.inference(self.iter_max)
|
55 |
+
q = np.array(q).reshape((c, h, w))
|
56 |
+
|
57 |
+
return q
|
58 |
+
|
59 |
+
|
60 |
+
class PostProcess:
|
61 |
+
"""Post processing with dense CRF."""
|
62 |
+
|
63 |
+
def __init__(self, device):
|
64 |
+
self.device = device
|
65 |
+
self.postprocessor = DenseCRF(
|
66 |
+
iter_max=10,
|
67 |
+
pos_xy_std=1,
|
68 |
+
pos_w=3,
|
69 |
+
bi_xy_std=67,
|
70 |
+
bi_rgb_std=3,
|
71 |
+
bi_w=4,
|
72 |
+
)
|
73 |
+
|
74 |
+
def apply_crf(self, image, cams, bg_factor=1.0):
|
75 |
+
"""Apply dense CRF."""
|
76 |
+
bg_score = np.power(1 - np.max(cams, axis=0, keepdims=True), bg_factor)
|
77 |
+
cams = np.concatenate((bg_score, cams), axis=0)
|
78 |
+
prob = cams
|
79 |
+
|
80 |
+
image = image.astype(np.uint8).transpose(1, 2, 0)
|
81 |
+
prob = self.postprocessor(image, prob)
|
82 |
+
|
83 |
+
label = np.argmax(prob, axis=0)
|
84 |
+
|
85 |
+
label_tensor = torch.from_numpy(label).long()
|
86 |
+
refined_mask = F.one_hot(label_tensor).to(device=self.device)
|
87 |
+
refined_mask = refined_mask.permute(2, 0, 1)
|
88 |
+
refined_mask = refined_mask[1:].float()
|
89 |
+
return refined_mask
|
90 |
+
|
91 |
+
def __call__(self, image, cams, separate=False, bg_factor=1.0):
|
92 |
+
mean_bgr = (104.008, 116.669, 122.675)
|
93 |
+
# covert Image to numpy array
|
94 |
+
image = np.array(image).astype(np.float32)
|
95 |
+
|
96 |
+
# RGB -> BGR
|
97 |
+
image = image[:, :, ::-1]
|
98 |
+
# Mean subtraction
|
99 |
+
image -= mean_bgr
|
100 |
+
# HWC -> CHW
|
101 |
+
image = image.transpose(2, 0, 1)
|
102 |
+
|
103 |
+
if isinstance(cams, torch.Tensor):
|
104 |
+
cams = cams.cpu().detach().numpy()
|
105 |
+
if separate:
|
106 |
+
refined_mask = [
|
107 |
+
self.apply_crf(image, cam[None], bg_factor) for cam in cams
|
108 |
+
]
|
109 |
+
refined_mask = torch.cat(refined_mask, dim=0)
|
110 |
+
else:
|
111 |
+
refined_mask = self.apply_crf(image, cams, bg_factor)
|
112 |
+
|
113 |
+
return refined_mask
|
modeling/model/utils.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""CAM utils."""
|
17 |
+
|
18 |
+
# pylint: disable=g-importing-member
|
19 |
+
import os
|
20 |
+
|
21 |
+
import cv2
|
22 |
+
import numpy as np
|
23 |
+
from PIL import Image
|
24 |
+
from scipy.ndimage import binary_fill_holes
|
25 |
+
import torch
|
26 |
+
from torchvision.transforms import Compose
|
27 |
+
from torchvision.transforms import Normalize
|
28 |
+
from torchvision.transforms import Resize
|
29 |
+
from torchvision.transforms import ToTensor
|
30 |
+
|
31 |
+
# pylint: disable=g-import-not-at-top
|
32 |
+
try:
|
33 |
+
from torchvision.transforms import InterpolationMode
|
34 |
+
|
35 |
+
BICUBIC = InterpolationMode.BICUBIC
|
36 |
+
except ImportError:
|
37 |
+
BICUBIC = Image.BICUBIC
|
38 |
+
|
39 |
+
_CONTOUR_INDEX = 1 if cv2.__version__.split('.')[0] == '3' else 0
|
40 |
+
|
41 |
+
|
42 |
+
def _convert_image_to_rgb(image):
|
43 |
+
return image.convert('RGB')
|
44 |
+
|
45 |
+
|
46 |
+
def _transform_resize(h, w):
|
47 |
+
return Compose([
|
48 |
+
Resize((h, w), interpolation=BICUBIC),
|
49 |
+
_convert_image_to_rgb,
|
50 |
+
ToTensor(),
|
51 |
+
Normalize(
|
52 |
+
(0.48145466, 0.4578275, 0.40821073),
|
53 |
+
(0.26862954, 0.26130258, 0.27577711),
|
54 |
+
),
|
55 |
+
])
|
56 |
+
|
57 |
+
|
58 |
+
def img_ms_and_flip(image, ori_height, ori_width, scales=1.0, patch_size=16):
|
59 |
+
"""Resizes and flips the image."""
|
60 |
+
if isinstance(scales, float):
|
61 |
+
scales = [scales]
|
62 |
+
|
63 |
+
all_imgs = []
|
64 |
+
for scale in scales:
|
65 |
+
preprocess = _transform_resize(
|
66 |
+
int(np.ceil(scale * int(ori_height) / patch_size) * patch_size),
|
67 |
+
int(np.ceil(scale * int(ori_width) / patch_size) * patch_size),
|
68 |
+
)
|
69 |
+
image = preprocess(image)
|
70 |
+
image_ori = image
|
71 |
+
image_flip = torch.flip(image, [-1])
|
72 |
+
all_imgs.append(image_ori)
|
73 |
+
all_imgs.append(image_flip)
|
74 |
+
return all_imgs
|
75 |
+
|
76 |
+
|
77 |
+
def reshape_transform(tensor, height=28, width=28):
|
78 |
+
tensor = tensor.permute(1, 0, 2)
|
79 |
+
result = tensor[:, 1:, :].reshape(
|
80 |
+
tensor.size(0), height, width, tensor.size(2)
|
81 |
+
)
|
82 |
+
|
83 |
+
# Bring the channels to the first dimension, like in CNNs.
|
84 |
+
result = result.transpose(2, 3).transpose(1, 2)
|
85 |
+
return result
|
86 |
+
|
87 |
+
|
88 |
+
def vis_mask(image, mask, mask_color):
|
89 |
+
# switch the height and width of image
|
90 |
+
# image = image.transpose(1, 0, 2)
|
91 |
+
if mask.shape[0] != image.shape[0] or mask.shape[1] != image.shape[1]:
|
92 |
+
mask = cv2.resize(mask, (image.shape[1], image.shape[0]))
|
93 |
+
fg = mask > 0.5
|
94 |
+
rgb = np.copy(image)
|
95 |
+
rgb[fg] = (rgb[fg] * 0.3 + np.array(mask_color) * 0.7).astype(np.uint8)
|
96 |
+
return Image.fromarray(rgb)
|
97 |
+
|
98 |
+
|
99 |
+
def scoremap2bbox(scoremap, threshold, multi_contour_eval=False):
|
100 |
+
"""Get bounding boxes from scoremap."""
|
101 |
+
height, width = scoremap.shape
|
102 |
+
scoremap_image = np.expand_dims((scoremap * 255).astype(np.uint8), 2)
|
103 |
+
while True:
|
104 |
+
_, thr_gray_heatmap = cv2.threshold(
|
105 |
+
src=scoremap_image,
|
106 |
+
thresh=int(threshold * np.max(scoremap_image)),
|
107 |
+
maxval=255,
|
108 |
+
type=cv2.THRESH_BINARY,
|
109 |
+
)
|
110 |
+
if thr_gray_heatmap.max() > 0 or threshold <= 0:
|
111 |
+
break
|
112 |
+
threshold -= 0.1
|
113 |
+
contours = cv2.findContours(
|
114 |
+
image=thr_gray_heatmap, mode=cv2.RETR_TREE, method=cv2.CHAIN_APPROX_SIMPLE
|
115 |
+
)[_CONTOUR_INDEX]
|
116 |
+
|
117 |
+
# if len(contours) == 0:
|
118 |
+
if not contours:
|
119 |
+
return np.asarray([[0, 0, 0, 0]]), 1
|
120 |
+
|
121 |
+
if not multi_contour_eval:
|
122 |
+
contours = [max(contours, key=cv2.contourArea)]
|
123 |
+
|
124 |
+
estimated_boxes = []
|
125 |
+
for contour in contours:
|
126 |
+
x, y, w, h = cv2.boundingRect(contour)
|
127 |
+
x0, y0, x1, y1 = x, y, x + w, y + h
|
128 |
+
x1 = min(x1, width - 1)
|
129 |
+
y1 = min(y1, height - 1)
|
130 |
+
estimated_boxes.append([x0, y0, x1, y1])
|
131 |
+
|
132 |
+
return np.asarray(estimated_boxes), len(contours)
|
133 |
+
|
134 |
+
|
135 |
+
def mask2chw(arr):
|
136 |
+
# Find the row and column indices where the array is 1
|
137 |
+
rows, cols = np.where(arr == 1)
|
138 |
+
# Calculate center of the mask
|
139 |
+
center_y = int(np.mean(rows))
|
140 |
+
center_x = int(np.mean(cols))
|
141 |
+
# Calculate height and width of the mask
|
142 |
+
height = rows.max() - rows.min() + 1
|
143 |
+
width = cols.max() - cols.min() + 1
|
144 |
+
return (center_y, center_x), height, width
|
145 |
+
|
146 |
+
|
147 |
+
def unpad(image_array, pad=None):
|
148 |
+
if pad is not None:
|
149 |
+
left, top, width, height = pad
|
150 |
+
image_array = image_array[top : top + height, left : left + width, :]
|
151 |
+
return image_array
|
152 |
+
|
153 |
+
|
154 |
+
def apply_visual_prompts(
|
155 |
+
image_array,
|
156 |
+
mask,
|
157 |
+
visual_prompt_type=('circle',),
|
158 |
+
visualize=False,
|
159 |
+
color=(255, 0, 0),
|
160 |
+
thickness=1,
|
161 |
+
blur_strength=(15, 15),
|
162 |
+
):
|
163 |
+
"""Applies visual prompts to the image."""
|
164 |
+
prompted_image = image_array.copy()
|
165 |
+
if 'blur' in visual_prompt_type:
|
166 |
+
# blur the part out side the mask
|
167 |
+
# Blur the entire image
|
168 |
+
blurred = cv2.GaussianBlur(prompted_image.copy(), blur_strength, 0)
|
169 |
+
# Get the sharp region using the mask
|
170 |
+
sharp_region = cv2.bitwise_and(
|
171 |
+
prompted_image.copy(),
|
172 |
+
prompted_image.copy(),
|
173 |
+
mask=np.clip(mask, 0, 255).astype(np.uint8),
|
174 |
+
)
|
175 |
+
# Get the blurred region using the inverted mask
|
176 |
+
inv_mask = 1 - mask
|
177 |
+
blurred_region = (blurred * inv_mask[:, :, None]).astype(np.uint8)
|
178 |
+
# Combine the sharp and blurred regions
|
179 |
+
prompted_image = cv2.add(sharp_region, blurred_region)
|
180 |
+
if 'gray' in visual_prompt_type:
|
181 |
+
gray = cv2.cvtColor(prompted_image.copy(), cv2.COLOR_BGR2GRAY)
|
182 |
+
# make gray part 3 channel
|
183 |
+
gray = np.stack([gray, gray, gray], axis=-1)
|
184 |
+
# Get the sharp region using the mask
|
185 |
+
color_region = cv2.bitwise_and(
|
186 |
+
prompted_image.copy(),
|
187 |
+
prompted_image.copy(),
|
188 |
+
mask=np.clip(mask, 0, 255).astype(np.uint8),
|
189 |
+
)
|
190 |
+
# Get the blurred region using the inverted mask
|
191 |
+
inv_mask = 1 - mask
|
192 |
+
gray_region = (gray * inv_mask[:, :, None]).astype(np.uint8)
|
193 |
+
# Combine the sharp and blurred regions
|
194 |
+
prompted_image = cv2.add(color_region, gray_region)
|
195 |
+
if 'black' in visual_prompt_type:
|
196 |
+
prompted_image = cv2.bitwise_and(
|
197 |
+
prompted_image.copy(),
|
198 |
+
prompted_image.copy(),
|
199 |
+
mask=np.clip(mask, 0, 255).astype(np.uint8),
|
200 |
+
)
|
201 |
+
if 'circle' in visual_prompt_type:
|
202 |
+
mask_center, mask_height, mask_width = mask2chw(mask)
|
203 |
+
center_coordinates = (mask_center[1], mask_center[0])
|
204 |
+
axes_length = (mask_width // 2, mask_height // 2)
|
205 |
+
prompted_image = cv2.ellipse(
|
206 |
+
prompted_image,
|
207 |
+
center_coordinates,
|
208 |
+
axes_length,
|
209 |
+
0,
|
210 |
+
0,
|
211 |
+
360,
|
212 |
+
color,
|
213 |
+
thickness,
|
214 |
+
)
|
215 |
+
if 'rectangle' in visual_prompt_type:
|
216 |
+
mask_center, mask_height, mask_width = mask2chw(mask)
|
217 |
+
# center_coordinates = (mask_center[1], mask_center[0])
|
218 |
+
# axes_length = (mask_width // 2, mask_height // 2)
|
219 |
+
start_point = (
|
220 |
+
mask_center[1] - mask_width // 2,
|
221 |
+
mask_center[0] - mask_height // 2,
|
222 |
+
)
|
223 |
+
end_point = (
|
224 |
+
mask_center[1] + mask_width // 2,
|
225 |
+
mask_center[0] + mask_height // 2,
|
226 |
+
)
|
227 |
+
prompted_image = cv2.rectangle(
|
228 |
+
prompted_image, start_point, end_point, color, thickness
|
229 |
+
)
|
230 |
+
if 'contour' in visual_prompt_type:
|
231 |
+
# Find the contours of the mask
|
232 |
+
# fill holes for the mask
|
233 |
+
mask = binary_fill_holes(mask)
|
234 |
+
contours, _ = cv2.findContours(
|
235 |
+
mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
|
236 |
+
)
|
237 |
+
# Draw the contours on the image
|
238 |
+
prompted_image = cv2.drawContours(
|
239 |
+
prompted_image.copy(), contours, -1, color, thickness
|
240 |
+
)
|
241 |
+
|
242 |
+
if visualize:
|
243 |
+
cv2.imwrite(os.path.join('masked_img.png'), prompted_image)
|
244 |
+
prompted_image = Image.fromarray(prompted_image.astype(np.uint8))
|
245 |
+
return prompted_image
|
modeling/model/utils_test.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""This file contains the unit tests for the utils.py file."""
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
from PIL import Image
|
20 |
+
import torch
|
21 |
+
|
22 |
+
# pylint: disable=g-bad-import-order
|
23 |
+
from modeling.model import utils
|
24 |
+
|
25 |
+
|
26 |
+
def test_scoremap2bbox():
|
27 |
+
"""Test the scoremap2bbox function."""
|
28 |
+
scoremap = np.zeros((10, 10))
|
29 |
+
scoremap[1:5, 1:5] = 1
|
30 |
+
scoremap[5:9, 5:9] = 2
|
31 |
+
scoremap[5:9, 1:5] = 3
|
32 |
+
scoremap[1:5, 5:9] = 4
|
33 |
+
bbox, len_bboxes = utils.scoremap2bbox(scoremap, 0.5)
|
34 |
+
assert len_bboxes == 1
|
35 |
+
assert bbox[0, 0] == 1
|
36 |
+
assert bbox[0, 1] == 1
|
37 |
+
assert bbox[0, 2] == 9
|
38 |
+
assert bbox[0, 3] == 9
|
39 |
+
|
40 |
+
|
41 |
+
def test_mask2chw():
|
42 |
+
"""Test the mask2chw function."""
|
43 |
+
mask = np.zeros((10, 10))
|
44 |
+
mask[1:5, 1:5] = 1
|
45 |
+
mask[5:9, 5:9] = 2
|
46 |
+
mask[5:9, 1:5] = 3
|
47 |
+
mask[1:5, 5:9] = 4
|
48 |
+
mask = torch.tensor(mask)
|
49 |
+
mask_center, mask_height, mask_width = utils.mask2chw(mask)
|
50 |
+
assert len(mask_center) == 2
|
51 |
+
assert mask_center[0] == 2
|
52 |
+
assert mask_center[1] == 2
|
53 |
+
assert mask_height == 4
|
54 |
+
assert mask_width == 4
|
55 |
+
|
56 |
+
|
57 |
+
def test_unpad():
|
58 |
+
"""Test the unpad function."""
|
59 |
+
image = np.zeros((10, 10, 1))
|
60 |
+
image[1:5, 1:5] = 1
|
61 |
+
image[5:9, 5:9] = 2
|
62 |
+
image[5:9, 1:5] = 3
|
63 |
+
image[1:5, 5:9] = 4
|
64 |
+
unpad_image = utils.unpad(image, pad=(1, 1, 8, 8))
|
65 |
+
assert len(unpad_image[0]) == 8, 'The width of the image is not 8.'
|
66 |
+
assert len(unpad_image[1]) == 8, 'The height of the image is not 8.'
|
67 |
+
unpad_image = utils.unpad(image, None)
|
68 |
+
assert (unpad_image == image).sum() == 100
|
69 |
+
|
70 |
+
|
71 |
+
def test_apply_visual_prompts():
|
72 |
+
"""Test the apply_visual_prompts function."""
|
73 |
+
image = np.ones((5, 5))
|
74 |
+
mask = np.array([
|
75 |
+
[0, 0, 0, 0, 0],
|
76 |
+
[0, 0, 0, 0, 0],
|
77 |
+
[0, 0, 1.0, 0, 0],
|
78 |
+
[0, 0, 0, 0, 0],
|
79 |
+
[0, 0, 0, 0, 0],
|
80 |
+
])
|
81 |
+
|
82 |
+
target = np.array([
|
83 |
+
[1, 1, 255, 1, 1],
|
84 |
+
[1, 255, 1, 255, 1],
|
85 |
+
[255, 1, 1, 1, 255],
|
86 |
+
[1, 255, 1, 255, 1],
|
87 |
+
[1, 1, 255, 1, 1],
|
88 |
+
])
|
89 |
+
mask[1:5, 1:5] = 1
|
90 |
+
prompted_image = utils.apply_visual_prompts(
|
91 |
+
image, mask, visual_prompt_type='circle', thickness=1
|
92 |
+
)
|
93 |
+
prompted_array = np.array(prompted_image)
|
94 |
+
assert (prompted_array == target).sum() == 25
|
95 |
+
|
96 |
+
|
97 |
+
def test_reshape_transform():
|
98 |
+
"""Test the reshape_transform function."""
|
99 |
+
image = torch.zeros((101, 10, 32))
|
100 |
+
image = utils.reshape_transform(image, height=10, width=10)
|
101 |
+
b, c, h, w = image.shape
|
102 |
+
assert b == 10
|
103 |
+
assert c == 32
|
104 |
+
assert h == 10
|
105 |
+
assert w == 10
|
106 |
+
|
107 |
+
|
108 |
+
def test_img_ms_and_flip():
|
109 |
+
"""Test the img_ms_and_flip function."""
|
110 |
+
image = np.zeros((120, 150))
|
111 |
+
image[1:5, 1:5] = 1
|
112 |
+
image[5:9, 5:9] = 2
|
113 |
+
image[5:9, 1:5] = 3
|
114 |
+
image[1:5, 5:9] = 4
|
115 |
+
image = Image.fromarray(image)
|
116 |
+
image = utils.img_ms_and_flip(image, 120, 150, scales=[1.2], patch_size=16)
|
117 |
+
image = image[0]
|
118 |
+
h, w = image.shape[-2:]
|
119 |
+
assert h == int(np.ceil(1.2 * 120 / 16) * 16)
|
120 |
+
assert w == int(np.ceil(1.2 * 150 / 16) * 16)
|
121 |
+
|
122 |
+
|
123 |
+
if __name__ == '__main__':
|
124 |
+
test_scoremap2bbox()
|
125 |
+
test_mask2chw()
|
126 |
+
test_unpad()
|
127 |
+
test_apply_visual_prompts()
|
128 |
+
test_reshape_transform()
|
129 |
+
test_img_ms_and_flip()
|
modeling/post_process/object_discovery.py
ADDED
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Find objects."""
|
17 |
+
|
18 |
+
# pylint: disable=g-importing-member
|
19 |
+
import numpy as np
|
20 |
+
import scipy
|
21 |
+
from scipy import ndimage
|
22 |
+
from scipy.linalg import eigh
|
23 |
+
from scipy.ndimage import label
|
24 |
+
import torch
|
25 |
+
import torch.nn.functional as F
|
26 |
+
|
27 |
+
|
28 |
+
def ncut(
|
29 |
+
feats,
|
30 |
+
dims,
|
31 |
+
scales,
|
32 |
+
init_image_size,
|
33 |
+
tau=0,
|
34 |
+
eps=1e-5,
|
35 |
+
no_binary_graph=False,
|
36 |
+
):
|
37 |
+
"""Implementation of NCut Method.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
feats: the pixel/patche features of an image
|
41 |
+
dims: dimension of the map from which the features are used
|
42 |
+
scales: from image to map scale
|
43 |
+
init_image_size: size of the image
|
44 |
+
tau: thresold for graph construction
|
45 |
+
eps: graph edge weight
|
46 |
+
no_binary_graph: ablation study for using similarity score as graph
|
47 |
+
edge weight
|
48 |
+
Returns:
|
49 |
+
TODO
|
50 |
+
"""
|
51 |
+
feats = feats[0, 1:, :]
|
52 |
+
feats = F.normalize(feats, p=2)
|
53 |
+
a = feats @ feats.transpose(1, 0)
|
54 |
+
a = a.cpu().numpy()
|
55 |
+
if no_binary_graph:
|
56 |
+
a[a < tau] = eps
|
57 |
+
else:
|
58 |
+
a = a > tau
|
59 |
+
a = np.where(a.astype(float) == 0, eps, a)
|
60 |
+
d_i = np.sum(a, axis=1)
|
61 |
+
d = np.diag(d_i)
|
62 |
+
|
63 |
+
# Print second and third smallest eigenvector
|
64 |
+
_, eigenvectors = eigh(d - a, d, subset_by_index=[1, 2])
|
65 |
+
eigenvec = np.copy(eigenvectors[:, 0])
|
66 |
+
|
67 |
+
# Using average point to compute bipartition
|
68 |
+
second_smallest_vec = eigenvectors[:, 0]
|
69 |
+
avg = np.sum(second_smallest_vec) / len(second_smallest_vec)
|
70 |
+
bipartition = second_smallest_vec > avg
|
71 |
+
|
72 |
+
seed = np.argmax(np.abs(second_smallest_vec))
|
73 |
+
|
74 |
+
if bipartition[seed] != 1:
|
75 |
+
eigenvec = eigenvec * -1
|
76 |
+
bipartition = np.logical_not(bipartition)
|
77 |
+
bipartition = bipartition.reshape(dims).astype(float)
|
78 |
+
|
79 |
+
# predict BBox
|
80 |
+
# We only extract the principal object BBox
|
81 |
+
pred, _, objects, cc = detect_box(
|
82 |
+
bipartition,
|
83 |
+
seed,
|
84 |
+
dims,
|
85 |
+
scales=scales,
|
86 |
+
initial_im_size=init_image_size[1:],
|
87 |
+
)
|
88 |
+
mask = np.zeros(dims)
|
89 |
+
mask[cc[0], cc[1]] = 1
|
90 |
+
|
91 |
+
return np.asarray(pred), objects, mask, seed, None, eigenvec.reshape(dims)
|
92 |
+
|
93 |
+
|
94 |
+
def grad_obj_discover_on_attn(attn, gradcam, dims, topk=1, threshold=0.6):
|
95 |
+
"""Get the gradcam and attn map, then find the seed, then use LOST algorithm to find the potential points.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
attn: attention map from ViT averaged across all heads, shape: [1,
|
99 |
+
(1+num_patches), (1+num_patches)].
|
100 |
+
gradcam: gradcam map from ViT, shape: [1, 1, H, W].
|
101 |
+
dims:
|
102 |
+
topk:
|
103 |
+
threshold:
|
104 |
+
Returns:
|
105 |
+
th_attn:
|
106 |
+
"""
|
107 |
+
|
108 |
+
w_featmap, h_featmap = dims
|
109 |
+
# nh = attn.shape[1]
|
110 |
+
attn = attn.squeeze()
|
111 |
+
|
112 |
+
seeds = torch.argsort(gradcam.flatten(), descending=True)[:topk]
|
113 |
+
|
114 |
+
# We keep only the output patch attention
|
115 |
+
# Get the attentions corresponding to [CLS] token
|
116 |
+
patch_attn = attn[1:, 1:]
|
117 |
+
topk_attn = patch_attn[seeds]
|
118 |
+
nh = topk_attn.shape[0]
|
119 |
+
# attentions = attn[0, :, 0, 1:].reshape(nh, -1)
|
120 |
+
|
121 |
+
# we keep only a certain percentage of the mass
|
122 |
+
val, idx = torch.sort(topk_attn)
|
123 |
+
val /= torch.sum(val, dim=1, keepdim=True)
|
124 |
+
cumval = torch.cumsum(val, dim=1)
|
125 |
+
th_attn = cumval > (1 - threshold)
|
126 |
+
idx2 = torch.argsort(idx)
|
127 |
+
for h in range(nh):
|
128 |
+
th_attn[h] = th_attn[h][idx2[h]]
|
129 |
+
th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
|
130 |
+
th_attn = th_attn.sum(0)
|
131 |
+
th_attn[th_attn > 1] = 1
|
132 |
+
return th_attn[None, None]
|
133 |
+
|
134 |
+
|
135 |
+
def grad_obj_discover(feats, gradcam, dims):
|
136 |
+
"""Using gradient heatmap to find the seed, then use LOST algorithm to find the potential points.
|
137 |
+
|
138 |
+
Args:
|
139 |
+
feats: the pixel/patche features of an image. Shape: [1, HW, C]
|
140 |
+
gradcam: the grad cam map
|
141 |
+
dims: dimension of the map from which the features are used
|
142 |
+
|
143 |
+
Returns:
|
144 |
+
pred: box predictions
|
145 |
+
A: binary affinity matrix
|
146 |
+
scores: lowest degree scores for all patches
|
147 |
+
seed: selected patch corresponding to an object
|
148 |
+
"""
|
149 |
+
# Compute the similarity
|
150 |
+
a = (feats @ feats.transpose(1, 2)).squeeze()
|
151 |
+
|
152 |
+
# Compute the inverse degree centrality measure per patch
|
153 |
+
# sorted_patches, scores = patch_scoring(a)
|
154 |
+
|
155 |
+
# Select the initial seed
|
156 |
+
# seed = sorted_patches[0]
|
157 |
+
seed = gradcam.argmax()
|
158 |
+
mask = a[seed]
|
159 |
+
mask = mask.view(1, 1, *dims)
|
160 |
+
|
161 |
+
return mask
|
162 |
+
|
163 |
+
|
164 |
+
def lost(feats, dims, scales, init_image_size, k_patches=100):
|
165 |
+
"""Implementation of LOST method.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
feats: the pixel/patche features of an image. Shape: [1, C, H, W]
|
169 |
+
dims: dimension of the map from which the features are used
|
170 |
+
scales: from image to map scale
|
171 |
+
init_image_size: size of the image
|
172 |
+
k_patches: number of k patches retrieved that are compared to the seed
|
173 |
+
at seed expansion.
|
174 |
+
Returns:
|
175 |
+
pred: box predictions
|
176 |
+
A: binary affinity matrix
|
177 |
+
scores: lowest degree scores for all patches
|
178 |
+
seed: selected patch corresponding to an object
|
179 |
+
"""
|
180 |
+
# Compute the similarity
|
181 |
+
feats = feats.flatten(2).transpose(1, 2)
|
182 |
+
a = (feats @ feats.transpose(1, 2)).squeeze()
|
183 |
+
|
184 |
+
# Compute the inverse degree centrality measure per patch
|
185 |
+
sorted_patches, _ = patch_scoring(a)
|
186 |
+
|
187 |
+
# Select the initial seed
|
188 |
+
seed = sorted_patches[0]
|
189 |
+
|
190 |
+
# Seed expansion
|
191 |
+
potentials = sorted_patches[:k_patches]
|
192 |
+
similars = potentials[a[seed, potentials] > 0.0]
|
193 |
+
m = torch.sum(a[similars, :], dim=0)
|
194 |
+
|
195 |
+
# Box extraction
|
196 |
+
_, _, _, mask = detect_box(
|
197 |
+
m, seed, dims, scales=scales, initial_im_size=init_image_size[1:]
|
198 |
+
)
|
199 |
+
|
200 |
+
return mask
|
201 |
+
# return np.asarray(bbox), A, scores, seed
|
202 |
+
|
203 |
+
|
204 |
+
def patch_scoring(m, threshold=0.0):
|
205 |
+
"""Patch scoring based on the inverse degree."""
|
206 |
+
# Cloning important
|
207 |
+
a = m.clone()
|
208 |
+
|
209 |
+
# Zero diagonal
|
210 |
+
a.fill_diagonal_(0)
|
211 |
+
|
212 |
+
# Make sure symmetric and non nul
|
213 |
+
a[a < 0] = 0
|
214 |
+
# C = A + A.t()
|
215 |
+
|
216 |
+
# Sort pixels by inverse degree
|
217 |
+
cent = -torch.sum(a > threshold, dim=1).type(torch.float32)
|
218 |
+
sel = torch.argsort(cent, descending=True)
|
219 |
+
|
220 |
+
return sel, cent
|
221 |
+
|
222 |
+
|
223 |
+
def detect_box(
|
224 |
+
bipartition,
|
225 |
+
seed,
|
226 |
+
dims,
|
227 |
+
initial_im_size=None,
|
228 |
+
scales=None,
|
229 |
+
principle_object=True,
|
230 |
+
):
|
231 |
+
"""Extract a box corresponding to the seed patch."""
|
232 |
+
|
233 |
+
# Among connected components extract from the affinity matrix, select the one
|
234 |
+
# corresponding to the seed patch.
|
235 |
+
|
236 |
+
# w_featmap, h_featmap = dims
|
237 |
+
objects, _ = ndimage.label(bipartition)
|
238 |
+
cc = objects[np.unravel_index(seed, dims)]
|
239 |
+
|
240 |
+
if principle_object:
|
241 |
+
mask = np.where(objects == cc)
|
242 |
+
# Add +1 because excluded max
|
243 |
+
ymin, ymax = min(mask[0]), max(mask[0]) + 1
|
244 |
+
xmin, xmax = min(mask[1]), max(mask[1]) + 1
|
245 |
+
# Rescale to image size
|
246 |
+
r_xmin, r_xmax = scales[1] * xmin, scales[1] * xmax
|
247 |
+
r_ymin, r_ymax = scales[0] * ymin, scales[0] * ymax
|
248 |
+
pred = [r_xmin, r_ymin, r_xmax, r_ymax]
|
249 |
+
|
250 |
+
# Check not out of image size (used when padding)
|
251 |
+
if initial_im_size:
|
252 |
+
pred[2] = min(pred[2], initial_im_size[1])
|
253 |
+
pred[3] = min(pred[3], initial_im_size[0])
|
254 |
+
|
255 |
+
# Coordinate predictions for the feature space
|
256 |
+
# Axis different then in image space
|
257 |
+
pred_feats = [ymin, xmin, ymax, xmax]
|
258 |
+
|
259 |
+
return pred, pred_feats, objects, mask
|
260 |
+
else:
|
261 |
+
raise NotImplementedError
|
262 |
+
|
263 |
+
|
264 |
+
# This function is modified from
|
265 |
+
# https://github.com/facebookresearch/dino/blob/main/visualize_attention.py
|
266 |
+
# Ref: https://github.com/facebookresearch/dino.
|
267 |
+
def dino_seg(attn, dims, patch_size, head=0):
|
268 |
+
"""Extraction of boxes based on the DINO segmentation method proposed in DINO."""
|
269 |
+
w_featmap, h_featmap = dims
|
270 |
+
nh = attn.shape[1]
|
271 |
+
official_th = 0.6
|
272 |
+
|
273 |
+
# We keep only the output patch attention
|
274 |
+
# Get the attentions corresponding to [CLS] token
|
275 |
+
attentions = attn[0, :, 0, 1:].reshape(nh, -1)
|
276 |
+
|
277 |
+
# we keep only a certain percentage of the mass
|
278 |
+
val, idx = torch.sort(attentions)
|
279 |
+
val /= torch.sum(val, dim=1, keepdim=True)
|
280 |
+
cumval = torch.cumsum(val, dim=1)
|
281 |
+
th_attn = cumval > (1 - official_th)
|
282 |
+
idx2 = torch.argsort(idx)
|
283 |
+
for h in range(nh):
|
284 |
+
th_attn[h] = th_attn[h][idx2[h]]
|
285 |
+
th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
|
286 |
+
|
287 |
+
# Connected components
|
288 |
+
labeled_array, _ = scipy.ndimage.label(th_attn[head].cpu().numpy())
|
289 |
+
|
290 |
+
# Find the biggest component
|
291 |
+
size_components = [
|
292 |
+
np.sum(labeled_array == c) for c in range(np.max(labeled_array))
|
293 |
+
]
|
294 |
+
|
295 |
+
if len(size_components) > 1:
|
296 |
+
# Select the biggest component avoiding component 0 corresponding
|
297 |
+
# to background
|
298 |
+
biggest_component = np.argmax(size_components[1:]) + 1
|
299 |
+
else:
|
300 |
+
# Cases of a single component
|
301 |
+
biggest_component = 0
|
302 |
+
|
303 |
+
# Mask corresponding to connected component
|
304 |
+
mask = np.where(labeled_array == biggest_component)
|
305 |
+
|
306 |
+
# Add +1 because excluded max
|
307 |
+
ymin, ymax = min(mask[0]), max(mask[0]) + 1
|
308 |
+
xmin, xmax = min(mask[1]), max(mask[1]) + 1
|
309 |
+
|
310 |
+
# Rescale to image
|
311 |
+
r_xmin, r_xmax = xmin * patch_size, xmax * patch_size
|
312 |
+
r_ymin, r_ymax = ymin * patch_size, ymax * patch_size
|
313 |
+
pred = [r_xmin, r_ymin, r_xmax, r_ymax]
|
314 |
+
|
315 |
+
return pred
|
316 |
+
|
317 |
+
|
318 |
+
def get_feats(feat_out, shape):
|
319 |
+
# Batch size, Number of heads, Number of tokens
|
320 |
+
nb_im, nh, nb_tokens = shape[0:3]
|
321 |
+
qkv = (
|
322 |
+
feat_out["qkv"]
|
323 |
+
.reshape(nb_im, nb_tokens, 3, nh, -1 // nh)
|
324 |
+
.permute(2, 0, 3, 1, 4)
|
325 |
+
)
|
326 |
+
k = qkv[1]
|
327 |
+
k = k.transpose(1, 2).reshape(nb_im, nb_tokens, -1)
|
328 |
+
return k
|
329 |
+
|
330 |
+
|
331 |
+
def get_instances(masks, return_largest=False):
|
332 |
+
return [
|
333 |
+
get_instances_single(m[None], return_largest=return_largest)
|
334 |
+
for m in masks
|
335 |
+
]
|
336 |
+
|
337 |
+
|
338 |
+
def get_instances_single(mask, return_largest=False):
|
339 |
+
"""Get the mask of a single instance."""
|
340 |
+
labeled_array, _ = label(mask.cpu().numpy())
|
341 |
+
instances = np.concatenate(
|
342 |
+
[labeled_array == c for c in range(np.max(labeled_array) + 1)], axis=0
|
343 |
+
)
|
344 |
+
if return_largest:
|
345 |
+
size_components = np.sum(instances, axis=(1, 2))
|
346 |
+
if len(size_components) > 1:
|
347 |
+
# Select the biggest component avoiding component 0 corresponding
|
348 |
+
# to background
|
349 |
+
biggest_component = np.argmax(size_components[1:]) + 1
|
350 |
+
else:
|
351 |
+
# Cases of a single component
|
352 |
+
biggest_component = 0
|
353 |
+
# Mask corresponding to connected component
|
354 |
+
return torch.from_numpy(labeled_array == biggest_component).float()
|
355 |
+
return torch.from_numpy(instances[1:]).float()
|
modeling/post_process/post_process.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Post processing."""
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn.functional as F
|
20 |
+
|
21 |
+
# pylint: disable=g-bad-import-order
|
22 |
+
# pylint: disable=g-importing-member
|
23 |
+
from modeling.post_process.object_discovery import get_instances
|
24 |
+
from utils.metrics import IoM
|
25 |
+
|
26 |
+
|
27 |
+
# This should be a abstract function to generate masks for the input image.
|
28 |
+
# However, we first hack it due to the time limit.
|
29 |
+
def generate_masks_from_sam(
|
30 |
+
image_path, save_path, pipeline, img_sam=None, visualize=True
|
31 |
+
):
|
32 |
+
"""Generate masks from SAM."""
|
33 |
+
masks, _, mask_list = pipeline.segment_automask(
|
34 |
+
image_path=image_path,
|
35 |
+
visualize=visualize,
|
36 |
+
save_path=save_path,
|
37 |
+
image=img_sam,
|
38 |
+
)
|
39 |
+
mask_tensor = torch.from_numpy(masks)
|
40 |
+
mask_tensor = mask_tensor.float()
|
41 |
+
return mask_tensor, mask_list
|
42 |
+
|
43 |
+
|
44 |
+
def match_masks(
|
45 |
+
mask_tensor, attn_map, mask_list, iom_thres=0.0, min_pred_threshold=0.2
|
46 |
+
):
|
47 |
+
"""Match masks with the attention map according to the IoU.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
mask_tensor: A torch.Tensor for the masks with shape [num_masks, height,
|
51 |
+
width].
|
52 |
+
attn_map: A torch.Tensor for the attention map with shape [1, 1, height,
|
53 |
+
width].
|
54 |
+
mask_list: A list of masks with shape [num_masks, height, width]
|
55 |
+
iom_thres: A float for the threshold to apply to the attention map.
|
56 |
+
min_pred_threshold: The prediction score threshold.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
A list of matched_masks with shape [num_masks, height, width],
|
60 |
+
len(matched_masks) = number of captions
|
61 |
+
"""
|
62 |
+
predictions = attn_map.squeeze(1).detach()
|
63 |
+
iom = IoM(predictions, mask_tensor, min_pred_threshold=min_pred_threshold)
|
64 |
+
keep_mask = iom > iom_thres
|
65 |
+
# mask_tensor = mask_tensor[keep_mask]
|
66 |
+
new_list = []
|
67 |
+
for mid, m_dict in enumerate(mask_list):
|
68 |
+
if keep_mask[mid]:
|
69 |
+
new_list.append(m_dict)
|
70 |
+
# if not len(new_list):
|
71 |
+
if not new_list:
|
72 |
+
max_id = torch.argmax(iom)
|
73 |
+
new_list.append(mask_list[max_id])
|
74 |
+
return new_list
|
75 |
+
|
76 |
+
|
77 |
+
def post_process_mask(attn_masks, pad=None, min_area_ratio=0.15):
|
78 |
+
"""Post process attention masks."""
|
79 |
+
if pad is not None:
|
80 |
+
left, top, width, height = pad
|
81 |
+
attn_masks = attn_masks[Ellipsis, top : top + height, left : left + width]
|
82 |
+
else:
|
83 |
+
height = None
|
84 |
+
width = None
|
85 |
+
mask_area = attn_masks.sum(dim=(1, 2))
|
86 |
+
total_area = mask_area.sum()
|
87 |
+
keep_mask = mask_area / total_area > min_area_ratio
|
88 |
+
if torch.sum(keep_mask) == 0:
|
89 |
+
if keep_mask.shape[0] == 0:
|
90 |
+
return torch.zeros(
|
91 |
+
(1, height, width), device=attn_masks.device, dtype=attn_masks.dtype
|
92 |
+
)
|
93 |
+
keep_mask[torch.argmax(mask_area)] = True
|
94 |
+
attn_masks = attn_masks[keep_mask]
|
95 |
+
return attn_masks
|
96 |
+
|
97 |
+
|
98 |
+
def filter_masks(
|
99 |
+
attn_masks,
|
100 |
+
pad=None,
|
101 |
+
mask_threshold=0.3,
|
102 |
+
min_area_ratio=0.15,
|
103 |
+
return_largest=False,
|
104 |
+
device=None,
|
105 |
+
return_instances=False,
|
106 |
+
):
|
107 |
+
"""Filter attention mask below the threshold."""
|
108 |
+
attn_masks[attn_masks < mask_threshold] = 0
|
109 |
+
# get_instances will be operated on cpu
|
110 |
+
ins_masks = get_instances(attn_masks, return_largest=return_largest)
|
111 |
+
ins_masks = [post_process_mask(m, pad, min_area_ratio) for m in ins_masks]
|
112 |
+
ins_masks = list(filter(lambda x: x is not None, ins_masks))
|
113 |
+
ins_masks = [m.to(device) for m in ins_masks]
|
114 |
+
if not return_instances:
|
115 |
+
return [torch.any(m, dim=0, keepdim=True).to(m.dtype) for m in ins_masks]
|
116 |
+
return ins_masks
|
117 |
+
|
118 |
+
|
119 |
+
def post_process(
|
120 |
+
input_array,
|
121 |
+
attn_masks,
|
122 |
+
pad=None,
|
123 |
+
mask_threshold=0.3,
|
124 |
+
return_largest=False,
|
125 |
+
min_area_ratio=0.15,
|
126 |
+
return_instances=False,
|
127 |
+
):
|
128 |
+
"""post process the input tensor with the attention masks.
|
129 |
+
|
130 |
+
Args:
|
131 |
+
input_array: A np.ndarray input array to be post processed with shape
|
132 |
+
[width, height, 3, batch_size]
|
133 |
+
attn_masks: A torch.Tensor for the attention masks with shape [1,
|
134 |
+
num_texts, width, height]
|
135 |
+
pad: A list of padding: [pad_left, pad_top, width, height], where
|
136 |
+
pad_left, pad_top and width, height are int values.
|
137 |
+
mask_threshold: The threshold to binarize the mask.
|
138 |
+
return_largest: If true, return the largest connected component.
|
139 |
+
min_area_ratio: Keep the mask if its area is larger than this threshold.
|
140 |
+
return_instances: Whether to return instances or not.
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
attn_masks: A list of tensors with shape [num_instances, height, width]
|
144 |
+
x num_texts, where len(attn_masks) = num_texts.
|
145 |
+
NOTE: the number_instances for each text (class) may vary.
|
146 |
+
The output is a binary tensor.
|
147 |
+
"""
|
148 |
+
if len(attn_masks.shape) == 3:
|
149 |
+
attn_masks = attn_masks[None]
|
150 |
+
img_width, img_height = input_array.shape[:2]
|
151 |
+
attn_masks = F.interpolate(
|
152 |
+
attn_masks, size=(img_height, img_width), mode='bicubic'
|
153 |
+
).squeeze(0)
|
154 |
+
device = attn_masks.device
|
155 |
+
output_masks = filter_masks(
|
156 |
+
attn_masks,
|
157 |
+
pad=pad,
|
158 |
+
mask_threshold=mask_threshold,
|
159 |
+
min_area_ratio=min_area_ratio,
|
160 |
+
return_largest=return_largest,
|
161 |
+
device=device,
|
162 |
+
return_instances=return_instances,
|
163 |
+
)
|
164 |
+
if pad is not None:
|
165 |
+
left, top, width, height = pad
|
166 |
+
input_array = input_array[top : top + height, left : left + width]
|
167 |
+
return input_array, output_masks
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
tensorflow>=2.14.0
|
2 |
+
numpy>=1.16.4
|
3 |
+
torch>=2.0.0
|
4 |
+
torchvision>=0.15.1
|
sam/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""SAM(Segment Anything Model)."""
|
17 |
+
|
18 |
+
from .sam import *
|
19 |
+
from .utils import *
|
sam/sam.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""A pipeline for segmenting objects using the SAM model."""
|
17 |
+
|
18 |
+
# Copyright 2024 The Google Research Authors.
|
19 |
+
# This file is based on the SAM (Segment Anything) and HQ-SAM.
|
20 |
+
#
|
21 |
+
# https://github.com/facebookresearch/segment-anything
|
22 |
+
# https://github.com/SysCV/sam-hq/tree/main
|
23 |
+
#
|
24 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
25 |
+
# you may not use this file except in compliance with the License.
|
26 |
+
# You may obtain a copy of the License at
|
27 |
+
#
|
28 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
29 |
+
#
|
30 |
+
# Unless required by applicable law or agreed to in writing, software
|
31 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
32 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
33 |
+
# See the License for the specific language governing permissions and
|
34 |
+
# limitations under the License.
|
35 |
+
|
36 |
+
|
37 |
+
# pylint: disable=all
|
38 |
+
# pylint: disable=g-importing-member
|
39 |
+
import os
|
40 |
+
import cv2
|
41 |
+
import matplotlib.pyplot as plt
|
42 |
+
import numpy as np
|
43 |
+
from sam.utils import show_anns
|
44 |
+
from sam.utils import show_box
|
45 |
+
from sam.utils import show_mask
|
46 |
+
from sam.utils import show_points
|
47 |
+
from segment_anything import sam_model_registry
|
48 |
+
from segment_anything import SamAutomaticMaskGenerator
|
49 |
+
from segment_anything import SamPredictor
|
50 |
+
|
51 |
+
|
52 |
+
class SAMPipeline:
|
53 |
+
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
checkpoint,
|
57 |
+
model_type,
|
58 |
+
device="cuda:0",
|
59 |
+
points_per_side=32,
|
60 |
+
pred_iou_thresh=0.88,
|
61 |
+
stability_score_thresh=0.95,
|
62 |
+
box_nms_thresh=0.7,
|
63 |
+
):
|
64 |
+
self.checkpoint = checkpoint
|
65 |
+
self.model_type = model_type
|
66 |
+
self.device = device
|
67 |
+
self.sam = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
|
68 |
+
self.sam.to(device=self.device)
|
69 |
+
self.load_mask_generator(
|
70 |
+
points_per_side=points_per_side,
|
71 |
+
pred_iou_thresh=pred_iou_thresh,
|
72 |
+
stability_score_thresh=stability_score_thresh,
|
73 |
+
box_nms_thresh=box_nms_thresh,
|
74 |
+
)
|
75 |
+
|
76 |
+
# Default Prompt Args
|
77 |
+
self.click_args = {"k": 5, "order": "max", "how_filter": "median"}
|
78 |
+
self.box_args = None
|
79 |
+
|
80 |
+
def load_sam(self):
|
81 |
+
print("Loading SAM")
|
82 |
+
sam = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
|
83 |
+
sam.to(device=self.device)
|
84 |
+
self.predictor = SamPredictor(sam)
|
85 |
+
print("Loading Done")
|
86 |
+
|
87 |
+
def load_mask_generator(
|
88 |
+
self,
|
89 |
+
points_per_side,
|
90 |
+
pred_iou_thresh,
|
91 |
+
stability_score_thresh,
|
92 |
+
box_nms_thresh,
|
93 |
+
):
|
94 |
+
print("Loading SAM")
|
95 |
+
self.mask_generator = SamAutomaticMaskGenerator(
|
96 |
+
model=self.sam,
|
97 |
+
points_per_side=points_per_side,
|
98 |
+
pred_iou_thresh=pred_iou_thresh,
|
99 |
+
stability_score_thresh=stability_score_thresh,
|
100 |
+
box_nms_thresh=box_nms_thresh,
|
101 |
+
crop_n_layers=0,
|
102 |
+
crop_n_points_downscale_factor=1,
|
103 |
+
)
|
104 |
+
print("Loading Done")
|
105 |
+
|
106 |
+
# segment single object
|
107 |
+
def segment_image_single(
|
108 |
+
self,
|
109 |
+
image_path,
|
110 |
+
input_point=None,
|
111 |
+
input_label=None,
|
112 |
+
input_box=None,
|
113 |
+
input_mask=None,
|
114 |
+
multimask_output=True,
|
115 |
+
visualize=False,
|
116 |
+
save_path=None,
|
117 |
+
fname="",
|
118 |
+
image=None,
|
119 |
+
):
|
120 |
+
if image is None:
|
121 |
+
image = cv2.imread(image_path)
|
122 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
123 |
+
self.predictor.set_image(image)
|
124 |
+
masks, scores, logits = self.predictor.predict(
|
125 |
+
point_coords=input_point,
|
126 |
+
point_labels=input_label,
|
127 |
+
box=input_box,
|
128 |
+
mask_input=None,
|
129 |
+
multimask_output=multimask_output,
|
130 |
+
)
|
131 |
+
|
132 |
+
if visualize:
|
133 |
+
self.visualize(
|
134 |
+
image,
|
135 |
+
masks,
|
136 |
+
scores,
|
137 |
+
save_path,
|
138 |
+
input_point=input_point,
|
139 |
+
input_label=input_label,
|
140 |
+
input_box=input_box,
|
141 |
+
input_mask=input_mask,
|
142 |
+
fname=fname,
|
143 |
+
)
|
144 |
+
|
145 |
+
return masks, scores, logits
|
146 |
+
|
147 |
+
def segment_automask(
|
148 |
+
self,
|
149 |
+
image_path,
|
150 |
+
visualize=False,
|
151 |
+
save_path=None,
|
152 |
+
image=None,
|
153 |
+
fname="automask.jpg",
|
154 |
+
):
|
155 |
+
if image is None:
|
156 |
+
image = cv2.imread(image_path)
|
157 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
158 |
+
|
159 |
+
mask_list, bbox_list = [], []
|
160 |
+
masks = self.mask_generator.generate(image)
|
161 |
+
mask_list.extend([mask["segmentation"] for mask in masks])
|
162 |
+
bbox_list.extend([mask["bbox"] for mask in masks])
|
163 |
+
|
164 |
+
if visualize:
|
165 |
+
self.visualize_automask(image, masks, save_path, fname=fname)
|
166 |
+
|
167 |
+
masks_arr, bbox_arr = np.array(mask_list), np.array(bbox_list)
|
168 |
+
return masks_arr, bbox_arr, masks
|
169 |
+
|
170 |
+
def visualize_automask(self, image, masks, save_path, fname="mask.jpg"):
|
171 |
+
if not os.path.exists(save_path):
|
172 |
+
os.makedirs(save_path)
|
173 |
+
plt.figure(figsize=(20, 20))
|
174 |
+
plt.imshow(image)
|
175 |
+
show_anns(masks)
|
176 |
+
plt.axis("off")
|
177 |
+
plt.savefig(os.path.join(save_path, fname))
|
178 |
+
|
179 |
+
def visualize(
|
180 |
+
self,
|
181 |
+
image,
|
182 |
+
masks,
|
183 |
+
scores,
|
184 |
+
save_path,
|
185 |
+
input_point=None,
|
186 |
+
input_label=None,
|
187 |
+
input_box=None,
|
188 |
+
input_mask=None,
|
189 |
+
fname="",
|
190 |
+
):
|
191 |
+
for i, (mask, score) in enumerate(zip(masks, scores)):
|
192 |
+
plt.figure(figsize=(10, 10))
|
193 |
+
plt.imshow(image)
|
194 |
+
show_mask(mask, plt.gca())
|
195 |
+
if input_point is not None:
|
196 |
+
show_points(input_point, input_label, plt.gca())
|
197 |
+
if input_box is not None:
|
198 |
+
show_box(input_box, plt.gca())
|
199 |
+
if input_mask is not None:
|
200 |
+
show_mask(input_mask[0], plt.gca(), True)
|
201 |
+
plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
|
202 |
+
plt.axis("off")
|
203 |
+
plt.savefig(os.path.join(save_path, f"{fname}{i}.jpg"))
|
204 |
+
|
205 |
+
return input_point, input_label, input_box, input_mask
|
sam/utils.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# Copyright 2024 The Google Research Authors.
|
17 |
+
# This file is based on the SAM (Segment Anything) and HQ-SAM.
|
18 |
+
#
|
19 |
+
# https://github.com/facebookresearch/segment-anything
|
20 |
+
# https://github.com/SysCV/sam-hq/tree/main
|
21 |
+
#
|
22 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
23 |
+
# you may not use this file except in compliance with the License.
|
24 |
+
# You may obtain a copy of the License at
|
25 |
+
#
|
26 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
27 |
+
#
|
28 |
+
# Unless required by applicable law or agreed to in writing, software
|
29 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
30 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
31 |
+
# See the License for the specific language governing permissions and
|
32 |
+
# limitations under the License.
|
33 |
+
|
34 |
+
"""SAM Utilities."""
|
35 |
+
# pylint: disable=all
|
36 |
+
# pylint: disable=g-importing-member
|
37 |
+
import json
|
38 |
+
import matplotlib.pyplot as plt
|
39 |
+
import numpy as np
|
40 |
+
from scipy.spatial.distance import cdist
|
41 |
+
|
42 |
+
|
43 |
+
def show_mask(mask, ax, random_color=False):
|
44 |
+
if random_color:
|
45 |
+
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
46 |
+
else:
|
47 |
+
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
|
48 |
+
h, w = mask.shape[-2:]
|
49 |
+
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
50 |
+
ax.imshow(mask_image)
|
51 |
+
|
52 |
+
|
53 |
+
def show_points(coords, labels, ax, marker_size=375):
|
54 |
+
pos_points = coords[labels == 1]
|
55 |
+
neg_points = coords[labels == 0]
|
56 |
+
ax.scatter(
|
57 |
+
pos_points[:, 0],
|
58 |
+
pos_points[:, 1],
|
59 |
+
color='green',
|
60 |
+
marker='*',
|
61 |
+
s=marker_size,
|
62 |
+
edgecolor='white',
|
63 |
+
linewidth=1.25,
|
64 |
+
)
|
65 |
+
ax.scatter(
|
66 |
+
neg_points[:, 0],
|
67 |
+
neg_points[:, 1],
|
68 |
+
color='red',
|
69 |
+
marker='*',
|
70 |
+
s=marker_size,
|
71 |
+
edgecolor='white',
|
72 |
+
linewidth=1.25,
|
73 |
+
)
|
74 |
+
|
75 |
+
|
76 |
+
def show_box(box, ax):
|
77 |
+
x0, y0, x1, y1 = box
|
78 |
+
w, h = x1 - x0, y1 - y0
|
79 |
+
ax.add_patch(
|
80 |
+
plt.Rectangle(
|
81 |
+
(x0, y0), w, h, edgecolor='red', facecolor=(0, 0, 0, 0), lw=2
|
82 |
+
)
|
83 |
+
)
|
84 |
+
|
85 |
+
|
86 |
+
def show_anns(anns):
|
87 |
+
if len(anns) == 0:
|
88 |
+
return
|
89 |
+
for index, dictionary in enumerate(anns):
|
90 |
+
dictionary['id'] = index
|
91 |
+
|
92 |
+
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
|
93 |
+
ax = plt.gca()
|
94 |
+
ax.set_autoscale_on(False)
|
95 |
+
# polygons = []
|
96 |
+
# color = []
|
97 |
+
for ann in sorted_anns:
|
98 |
+
m = ann['segmentation']
|
99 |
+
img = np.ones((m.shape[0], m.shape[1], 3))
|
100 |
+
color_mask = np.random.random((1, 3)).tolist()[0]
|
101 |
+
for i in range(3):
|
102 |
+
img[:, :, i] = color_mask[i]
|
103 |
+
ax.imshow(np.dstack((img, m * 0.35)))
|
104 |
+
|
105 |
+
# Get the centroid of the mask
|
106 |
+
mask_y, mask_x = np.nonzero(m)
|
107 |
+
centroid_x, centroid_y = np.mean(mask_x), np.mean(mask_y)
|
108 |
+
|
109 |
+
# Display the mask ID
|
110 |
+
mask_id = ann['id']
|
111 |
+
ax.text(
|
112 |
+
centroid_x,
|
113 |
+
centroid_y,
|
114 |
+
str(mask_id),
|
115 |
+
color='black',
|
116 |
+
fontsize=48,
|
117 |
+
weight='bold',
|
118 |
+
)
|
119 |
+
|
120 |
+
|
121 |
+
# Turn CAM result to SAM prompt
|
122 |
+
def aggregate_RGB_channel(activation_mask, how='max'):
|
123 |
+
B, C, H, W = activation_mask.shape
|
124 |
+
if how == 'max':
|
125 |
+
res_activation_mask = np.amax(activation_mask, axis=1, keepdims=True)
|
126 |
+
elif how == 'avr':
|
127 |
+
res_activation_mask = np.mean(activation_mask, axis=1, keepdims=True)
|
128 |
+
res_activation_mask = res_activation_mask.reshape(B, 1, H * W)
|
129 |
+
|
130 |
+
res_activation_mask = np.squeeze(res_activation_mask, axis=1)
|
131 |
+
return res_activation_mask
|
132 |
+
|
133 |
+
|
134 |
+
def find_k_points(arr, k, order='max', how_filter='median'):
|
135 |
+
arr = arr.squeeze(0)
|
136 |
+
flat_indices = np.argpartition(arr.flatten(), -k)[-k:]
|
137 |
+
unravel_topk_idx = np.unravel_index(flat_indices, arr.shape)
|
138 |
+
topk_indices = np.array(unravel_topk_idx).transpose()[:, ::-1]
|
139 |
+
# print(topk_indices.shape)
|
140 |
+
|
141 |
+
if how_filter == 'random':
|
142 |
+
random_rows = np.random.choice(
|
143 |
+
topk_indices.shape[0], size=int(round(k / 16)), replace=False
|
144 |
+
)
|
145 |
+
topk_indices = topk_indices[random_rows]
|
146 |
+
elif how_filter == 'median':
|
147 |
+
distances = cdist(topk_indices, topk_indices)
|
148 |
+
distances = np.sum(distances, axis=1)
|
149 |
+
median_distance = np.median(distances)
|
150 |
+
filtered_idx = [
|
151 |
+
i for i in range(len(distances)) if distances[i] < median_distance
|
152 |
+
]
|
153 |
+
topk_indices = topk_indices[filtered_idx]
|
154 |
+
return topk_indices
|
155 |
+
|
156 |
+
|
157 |
+
def max_sum_submatrix(matrix):
|
158 |
+
matrix = np.array(matrix)
|
159 |
+
H, W = matrix.shape
|
160 |
+
# Preprocess cumulative sums for rows
|
161 |
+
matrix[:, 1:] += matrix[:, :-1]
|
162 |
+
max_sum = float('-inf')
|
163 |
+
max_rect = (0, 0, 0, 0) # (top, left, bottom, right)
|
164 |
+
|
165 |
+
for left in range(W):
|
166 |
+
for right in range(left, W):
|
167 |
+
# Apply 1D Kadane's algorithm for the current pair of columns
|
168 |
+
column_sum = matrix[:, right] - (matrix[:, left - 1] if left > 0 else 0)
|
169 |
+
max_ending_here = max_so_far = column_sum[0]
|
170 |
+
start, end = 0, 0
|
171 |
+
|
172 |
+
for i in range(1, H):
|
173 |
+
val = column_sum[i]
|
174 |
+
if max_ending_here > 0:
|
175 |
+
max_ending_here += val
|
176 |
+
else:
|
177 |
+
max_ending_here = val
|
178 |
+
start = i
|
179 |
+
|
180 |
+
if max_ending_here > max_so_far:
|
181 |
+
max_so_far = max_ending_here
|
182 |
+
end = i
|
183 |
+
|
184 |
+
if max_so_far > max_sum:
|
185 |
+
max_sum = max_so_far
|
186 |
+
max_rect = (start, left, end, right)
|
187 |
+
|
188 |
+
return max_sum, max_rect
|
189 |
+
|
190 |
+
|
191 |
+
def CAM2SAMClick(activation_map, k=5, order='max', how_filter='median'):
|
192 |
+
# activation_map = aggregate_RGB_channel(activation_map)
|
193 |
+
H, W, C = activation_map.shape
|
194 |
+
activation_map = activation_map.reshape((1, 1, H, W))
|
195 |
+
coords = []
|
196 |
+
for nrow in range(activation_map.shape[0]):
|
197 |
+
coord = find_k_points(activation_map[nrow], k, order, how_filter)
|
198 |
+
coords.append(coord)
|
199 |
+
return coords
|
200 |
+
|
201 |
+
|
202 |
+
def CAM2SAMBox(activation_map):
|
203 |
+
# print(activation_map.shape)
|
204 |
+
# activation_map = aggregate_RGB_channel(activation_map)
|
205 |
+
H, W, C = activation_map.shape
|
206 |
+
activation_map = activation_map.reshape((1, H, W))
|
207 |
+
box_coordinates = []
|
208 |
+
for nrow in range(activation_map.shape[0]):
|
209 |
+
# print(activation_map[nrow].shape)
|
210 |
+
arr = activation_map[nrow]
|
211 |
+
|
212 |
+
norm_arr = 2 * ((arr - np.min(arr)) / (np.max(arr) - np.min(arr))) - 1
|
213 |
+
# print(norm_arr.shape)
|
214 |
+
_, box_coordinate = max_sum_submatrix(norm_arr)
|
215 |
+
box_coordinates.append(box_coordinate)
|
216 |
+
return box_coordinates
|
217 |
+
|
218 |
+
|
219 |
+
# Visualize
|
220 |
+
def visualize_attention(arr, filename):
|
221 |
+
# Create a figure and axes object
|
222 |
+
fig, ax = plt.subplots()
|
223 |
+
# Display the array as an image
|
224 |
+
im = ax.imshow(arr)
|
225 |
+
# Add a colorbar
|
226 |
+
ax.figure.colorbar(im, ax=ax)
|
227 |
+
# cbar = ax.figure.colorbar(im, ax=ax)
|
228 |
+
# Save the figure as a PNG file
|
229 |
+
fig.savefig(filename)
|
230 |
+
|
231 |
+
|
232 |
+
# Build config
|
233 |
+
def build_sam_config(config_path):
|
234 |
+
with open(config_path, 'r') as infile:
|
235 |
+
config = json.load(infile)
|
236 |
+
|
237 |
+
sam_checkpoint = config['model']['sam_checkpoint']
|
238 |
+
model_type = config['model']['model_type']
|
239 |
+
return sam_checkpoint, model_type
|
utils/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
utils/inference_pipeline.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""The inference pipeline for the CaR model."""
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
from PIL import Image
|
20 |
+
import torch
|
21 |
+
|
22 |
+
# pylint: disable=g-importing-member
|
23 |
+
# pylint: disable=g-bad-import-order
|
24 |
+
from modeling.post_process.post_process import generate_masks_from_sam
|
25 |
+
from modeling.post_process.post_process import match_masks
|
26 |
+
from utils.utils import process_sentence
|
27 |
+
from utils.metrics import IoU
|
28 |
+
|
29 |
+
IMAGE_WIDTH = 512
|
30 |
+
IMAGE_HEIGHT = 512
|
31 |
+
|
32 |
+
|
33 |
+
def get_sam_masks(
|
34 |
+
config, image_path, masks, matching_thresh=0.9, img_sam=None, pipeline=None
|
35 |
+
):
|
36 |
+
"""Generate SAM masks."""
|
37 |
+
print("generating sam masks online")
|
38 |
+
mask_tensor, mask_list = generate_masks_from_sam(
|
39 |
+
image_path,
|
40 |
+
save_path="./",
|
41 |
+
pipeline=pipeline,
|
42 |
+
img_sam=img_sam,
|
43 |
+
visualize=False,
|
44 |
+
)
|
45 |
+
mask_tensor = mask_tensor.to(masks.device)
|
46 |
+
# only conduct sam on masks that is not all zero
|
47 |
+
attn_map, mask_ids = [], []
|
48 |
+
for mask_id, mask in enumerate(masks):
|
49 |
+
if torch.sum(mask) > 0:
|
50 |
+
attn_map.append(mask.unsqueeze(0))
|
51 |
+
mask_ids.append(mask_id)
|
52 |
+
matched_masks = [
|
53 |
+
match_masks(
|
54 |
+
mask_tensor,
|
55 |
+
attn,
|
56 |
+
mask_list,
|
57 |
+
iom_thres=config.car.iom_thres,
|
58 |
+
min_pred_threshold=config.sam.min_pred_threshold,
|
59 |
+
)
|
60 |
+
for attn in attn_map
|
61 |
+
]
|
62 |
+
for matched_mask, mask_id in zip(matched_masks, mask_ids):
|
63 |
+
sam_masks = np.array([item["segmentation"] for item in matched_mask])
|
64 |
+
sam_mask = np.any(sam_masks, axis=0)
|
65 |
+
cur_mask = masks[mask_id]
|
66 |
+
iou = IoU(torch.from_numpy(sam_mask).to(cur_mask.device), cur_mask)
|
67 |
+
if iou > matching_thresh:
|
68 |
+
masks[mask_id] = torch.from_numpy(sam_mask).to(masks.device)
|
69 |
+
return masks
|
70 |
+
|
71 |
+
|
72 |
+
def inference_car(cfg, car_model, image_path, sentences, sam_pipeline=None):
|
73 |
+
sentences = [process_sentence(sen, cfg.test.ds_name) for sen in sentences]
|
74 |
+
img = Image.open(image_path).convert("RGB")
|
75 |
+
if cfg.test.use_pseudo:
|
76 |
+
masks, scores = car_model(img, sentences)
|
77 |
+
return masks, scores
|
78 |
+
|
79 |
+
masks, scores = car_model(img, sentences, cfg.car.num_iteration)
|
80 |
+
sam_masks = get_sam_masks(
|
81 |
+
cfg, image_path, masks, cfg.sam.matching_thresh, pipeline=sam_pipeline
|
82 |
+
)
|
83 |
+
return sam_masks, scores
|
utils/merge_mask.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Mask merging functions for post-processing."""
|
17 |
+
|
18 |
+
import cv2
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
import torch.nn.functional as F
|
22 |
+
|
23 |
+
|
24 |
+
def merge_masks_simple(
|
25 |
+
all_masks, target_h, target_w, threshold=0.5, scores=None
|
26 |
+
):
|
27 |
+
"""Merge masks."""
|
28 |
+
merged_mask = None
|
29 |
+
if scores is not None:
|
30 |
+
merged_mask = torch.sum(all_masks * scores[:, None, None], dim=0)
|
31 |
+
merged_mask /= torch.sum(scores)
|
32 |
+
merged_mask = merged_mask.detach().cpu().numpy()
|
33 |
+
# resize the mask to the target size
|
34 |
+
merged_mask = cv2.resize(merged_mask, (target_w, target_h))
|
35 |
+
merged_mask = np.where(merged_mask >= threshold, 1, 0).astype(np.uint8)
|
36 |
+
if np.sum(merged_mask) <= 0.05 * (target_h * target_w):
|
37 |
+
merged_mask = torch.any(all_masks > 0, dim=0)
|
38 |
+
merged_mask = merged_mask.detach().cpu().numpy().astype(np.uint8)
|
39 |
+
# resize the mask to the target size
|
40 |
+
merged_mask = cv2.resize(merged_mask, (target_w, target_h))
|
41 |
+
merged_mask = merged_mask > threshold
|
42 |
+
merged_mask = torch.from_numpy(merged_mask).float()
|
43 |
+
return merged_mask[None]
|
44 |
+
|
45 |
+
|
46 |
+
def merge_masks(all_masks, target_h, target_w, threshold=0.5):
|
47 |
+
all_masks = torch.from_numpy(np.stack(all_masks)).float()
|
48 |
+
mask_tensor = F.interpolate(
|
49 |
+
all_masks[None], size=(target_h, target_w), mode='bilinear'
|
50 |
+
).squeeze(0)
|
51 |
+
bg_mask = threshold * torch.ones((1, target_h, target_w))
|
52 |
+
merged_mask = torch.cat([bg_mask, mask_tensor], dim=0)
|
53 |
+
mask_idx = torch.argmax(merged_mask, dim=0)
|
54 |
+
merged_mask = mask_idx > 0
|
55 |
+
if merged_mask.sum() <= 0.05 * (target_h * target_w):
|
56 |
+
merged_mask = torch.any(mask_tensor, dim=0)
|
57 |
+
return merged_mask.float()[None]
|
utils/metrics.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Metrics for evaluating the performance of the model."""
|
17 |
+
|
18 |
+
import torch
|
19 |
+
|
20 |
+
|
21 |
+
def IoU(mask1, mask2, threshold=0.5):
|
22 |
+
"""Calculate Intersection over Union (IoU) between prediction and GT masks.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
mask1: A torch.Tensor denoting the prediction, shape (N, H, W), where N is
|
26 |
+
the number of masks.
|
27 |
+
mask2: A torch.Tensor denoting the ground truth, shape (N, H, W), where N
|
28 |
+
is the number of masks.
|
29 |
+
threshold: The threshold to binarize masks.
|
30 |
+
Returns:
|
31 |
+
IoU of `mask1` and `mask2`.
|
32 |
+
"""
|
33 |
+
if threshold > 0:
|
34 |
+
mask1, mask2 = (mask1 > threshold).to(torch.bool), (mask2 > threshold).to(
|
35 |
+
torch.bool
|
36 |
+
)
|
37 |
+
intersection = torch.sum(mask1 * (mask1 == mask2), dim=[-1, -2]).squeeze()
|
38 |
+
union = torch.sum(mask1 + mask2, dim=[-1, -2]).squeeze()
|
39 |
+
if union.sum() == 0:
|
40 |
+
return 0
|
41 |
+
return (intersection.to(torch.float) / union).mean().item()
|
42 |
+
|
43 |
+
|
44 |
+
def IoM(pred, target, min_pred_threshold=0.2):
|
45 |
+
"""Calculate Intersection over the area of gt Mask and pred Mask (IoM).
|
46 |
+
|
47 |
+
between prediction and each ground truth masks.
|
48 |
+
Precaution:
|
49 |
+
this function works for prediction and target that are binary masks,
|
50 |
+
where 1 represents the mask and 0 represents the background.
|
51 |
+
Args:
|
52 |
+
pred: A torch.Tensor denoting the prediction, shape (N, H, W), where N is
|
53 |
+
the number of masks.
|
54 |
+
target: A torch.Tensor denoting the ground truth, shape (N, H, W), where N
|
55 |
+
is the number of masks.
|
56 |
+
min_pred_threshold: prediction threshold.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
ious: A torch.Tensor denoting the IoU, shape (N,).
|
60 |
+
"""
|
61 |
+
# calculate the intersection over all masks
|
62 |
+
intersection = torch.einsum("mij,nij->mn", pred.to(target.device), target)
|
63 |
+
area_pred = torch.einsum("mij->m", pred)
|
64 |
+
area_target = torch.einsum("nij->n", target)
|
65 |
+
# we calculate the IoM by dividing the intersection over the minimum area.
|
66 |
+
iom_target = torch.einsum("mn,n->mn", intersection, 1 / area_target)
|
67 |
+
iom_pred = torch.einsum("mn,m->mn", intersection, 1 / area_pred)
|
68 |
+
# if the intersection is smaller than a certain percentage of the area of
|
69 |
+
# the pred mask, we consider it as background.
|
70 |
+
iom_target[iom_pred < min_pred_threshold] = 0
|
71 |
+
# we consider the IoM as the maximum IoM between the pred mask and
|
72 |
+
# the target mask.
|
73 |
+
iom = torch.max(iom_target, iom_pred)
|
74 |
+
iom = iom.max(dim=0)[0]
|
75 |
+
return iom
|
utils/nlp.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Language processing utilities."""
|
17 |
+
|
18 |
+
import spacy
|
19 |
+
|
20 |
+
|
21 |
+
def load_spacy_model(model='en_core_web_trf'):
|
22 |
+
nlp = spacy.load(model)
|
23 |
+
return nlp
|
24 |
+
|
25 |
+
|
26 |
+
def process_sentence(sentence, nlp):
|
27 |
+
"""Process a sentence."""
|
28 |
+
doc = nlp(sentence)
|
29 |
+
sentence_for_spacy = []
|
30 |
+
|
31 |
+
for _, token in enumerate(doc):
|
32 |
+
if token.text == ' ':
|
33 |
+
continue
|
34 |
+
sentence_for_spacy.append(token.text)
|
35 |
+
|
36 |
+
sentence_for_spacy = ' '.join(sentence_for_spacy)
|
37 |
+
noun_phrase, _, _ = extract_noun_phrase(
|
38 |
+
sentence_for_spacy, nlp, need_index=True
|
39 |
+
)
|
40 |
+
return noun_phrase
|
41 |
+
|
42 |
+
|
43 |
+
def extract_noun_phrase(text, nlp, need_index=False):
|
44 |
+
"""Extract noun phrase from text. nlp is a spacy model.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
text: str, text to be processed.
|
48 |
+
nlp: spacy model.
|
49 |
+
need_index: bool, whether to return the index of the noun phrase.
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
noun_phrase: str, noun phrase of the text.
|
53 |
+
"""
|
54 |
+
# text = text.lower()
|
55 |
+
|
56 |
+
doc = nlp(text)
|
57 |
+
|
58 |
+
chunks = {}
|
59 |
+
chunks_index = {}
|
60 |
+
for chunk in doc.noun_chunks:
|
61 |
+
for i in range(chunk.start, chunk.end):
|
62 |
+
chunks[i] = chunk
|
63 |
+
chunks_index[i] = (chunk.start, chunk.end)
|
64 |
+
|
65 |
+
for token in doc:
|
66 |
+
if token.head.i == token.i:
|
67 |
+
head = token.head
|
68 |
+
|
69 |
+
if head.i not in chunks:
|
70 |
+
children = list(head.children)
|
71 |
+
if children and children[0].i in chunks:
|
72 |
+
head = children[0]
|
73 |
+
else:
|
74 |
+
if need_index:
|
75 |
+
return text, [], text
|
76 |
+
else:
|
77 |
+
return text
|
78 |
+
|
79 |
+
head_noun = head.text
|
80 |
+
head_index = chunks_index[head.i]
|
81 |
+
head_index = [i for i in range(head_index[0], head_index[1])]
|
82 |
+
|
83 |
+
sentence_index = [i for i in range(len(doc))]
|
84 |
+
not_phrase_index = []
|
85 |
+
for i in sentence_index:
|
86 |
+
# not_phrase_index.append(i) if i not in head_index else None
|
87 |
+
if i not in head_index:
|
88 |
+
not_phrase_index.append(i)
|
89 |
+
|
90 |
+
head = chunks[head.i]
|
91 |
+
if need_index:
|
92 |
+
return head.text, not_phrase_index, head_noun
|
93 |
+
else:
|
94 |
+
return head.text
|
utils/utils.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Utility functions for the project."""
|
17 |
+
|
18 |
+
from __future__ import print_function
|
19 |
+
# pylint: disable=g-importing-member
|
20 |
+
from collections import defaultdict
|
21 |
+
from collections import deque
|
22 |
+
from copy import deepcopy
|
23 |
+
import datetime
|
24 |
+
import errno
|
25 |
+
import os
|
26 |
+
import sys
|
27 |
+
import time
|
28 |
+
import numpy as np
|
29 |
+
from PIL import Image
|
30 |
+
import torch
|
31 |
+
from torchvision import transforms
|
32 |
+
import yaml
|
33 |
+
|
34 |
+
# pylint: disable=g-bad-import-order
|
35 |
+
from data.voc import CLASS2ID
|
36 |
+
from data.voc import VOC_CLASSES
|
37 |
+
|
38 |
+
|
39 |
+
_MB = 1024.0 * 1024.0
|
40 |
+
|
41 |
+
DINO_transform = transforms.Compose([
|
42 |
+
transforms.ToTensor(),
|
43 |
+
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
44 |
+
])
|
45 |
+
|
46 |
+
|
47 |
+
class Config:
|
48 |
+
|
49 |
+
def __init__(self, **kwargs):
|
50 |
+
for key, value in kwargs.items():
|
51 |
+
if isinstance(value, dict):
|
52 |
+
setattr(self, key, Config(**value))
|
53 |
+
else:
|
54 |
+
setattr(self, key, value)
|
55 |
+
|
56 |
+
|
57 |
+
def load_yaml(filename):
|
58 |
+
with open(filename) as file:
|
59 |
+
try:
|
60 |
+
data = yaml.safe_load(file)
|
61 |
+
return data
|
62 |
+
except yaml.YAMLError as e:
|
63 |
+
print(f"Error while loading YAML file: {e}")
|
64 |
+
|
65 |
+
|
66 |
+
def normalize(x, dim=None, eps=1e-15):
|
67 |
+
if dim is None:
|
68 |
+
return (x - x.min()) / (x.max() - x.min())
|
69 |
+
# Normalize to [0, 1].
|
70 |
+
numerator = x - x.min(axis=dim, keepdims=True)[0]
|
71 |
+
denominator = (
|
72 |
+
x.max(axis=dim, keepdims=True)[0]
|
73 |
+
- x.min(axis=dim, keepdims=True)[0]
|
74 |
+
+ eps
|
75 |
+
)
|
76 |
+
return numerator / denominator
|
77 |
+
|
78 |
+
|
79 |
+
class SmoothedValue(object):
|
80 |
+
"""Track a series of values and provide access to smoothed values over a window or the global series average."""
|
81 |
+
|
82 |
+
def __init__(self, window_size=20, fmt=None):
|
83 |
+
if fmt is None:
|
84 |
+
fmt = "{median:.4f} ({global_avg:.4f})"
|
85 |
+
self.deque = deque(maxlen=window_size)
|
86 |
+
self.total = 0.0
|
87 |
+
self.count = 0
|
88 |
+
self.fmt = fmt
|
89 |
+
|
90 |
+
def update(self, value, n=1):
|
91 |
+
self.deque.append(value)
|
92 |
+
self.count += n
|
93 |
+
self.total += value * n
|
94 |
+
|
95 |
+
# def synchronize_between_processes(self):
|
96 |
+
# """
|
97 |
+
# Warning: does not synchronize the deque!
|
98 |
+
# """
|
99 |
+
# if not is_dist_avail_and_initialized():
|
100 |
+
# return
|
101 |
+
# t = torch.tensor([self.count, self.total],
|
102 |
+
# dtype=torch.float64, device='cuda')
|
103 |
+
# dist.barrier()
|
104 |
+
# dist.all_reduce(t)
|
105 |
+
# t = t.tolist()
|
106 |
+
# self.count = int(t[0])
|
107 |
+
# self.total = t[1]
|
108 |
+
|
109 |
+
@property
|
110 |
+
def median(self):
|
111 |
+
d = torch.tensor(list(self.deque))
|
112 |
+
return d.median().item()
|
113 |
+
|
114 |
+
@property
|
115 |
+
def avg(self):
|
116 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
117 |
+
return d.mean().item()
|
118 |
+
|
119 |
+
@property
|
120 |
+
def global_avg(self):
|
121 |
+
return self.total / self.count
|
122 |
+
|
123 |
+
@property
|
124 |
+
def max(self):
|
125 |
+
return max(self.deque)
|
126 |
+
|
127 |
+
@property
|
128 |
+
def value(self):
|
129 |
+
return self.deque[-1]
|
130 |
+
|
131 |
+
def __str__(self):
|
132 |
+
return self.fmt.format(
|
133 |
+
median=self.median,
|
134 |
+
avg=self.avg,
|
135 |
+
global_avg=self.global_avg,
|
136 |
+
max=self.max,
|
137 |
+
value=self.value,
|
138 |
+
)
|
139 |
+
|
140 |
+
|
141 |
+
class MetricLogger(object):
|
142 |
+
"""Log the metrics."""
|
143 |
+
|
144 |
+
def __init__(self, delimiter="\t"):
|
145 |
+
self.meters = defaultdict(SmoothedValue)
|
146 |
+
self.delimiter = delimiter
|
147 |
+
|
148 |
+
def update(self, **kwargs):
|
149 |
+
for k, v in kwargs.items():
|
150 |
+
if isinstance(v, torch.Tensor):
|
151 |
+
v = v.item()
|
152 |
+
assert isinstance(v, (float, int))
|
153 |
+
self.meters[k].update(v)
|
154 |
+
|
155 |
+
def __getattr__(self, attr):
|
156 |
+
if attr in self.meters:
|
157 |
+
return self.meters[attr]
|
158 |
+
if attr in self.__dict__:
|
159 |
+
return self.__dict__[attr]
|
160 |
+
raise AttributeError(
|
161 |
+
"'{}' object has no attribute '{}'".format(type(self).__name__, attr)
|
162 |
+
)
|
163 |
+
|
164 |
+
def __str__(self):
|
165 |
+
loss_str = []
|
166 |
+
for name, meter in self.meters.items():
|
167 |
+
loss_str.append("{}: {}".format(name, str(meter)))
|
168 |
+
return self.delimiter.join(loss_str)
|
169 |
+
|
170 |
+
def synchronize_between_processes(self):
|
171 |
+
for meter in self.meters.values():
|
172 |
+
meter.synchronize_between_processes()
|
173 |
+
|
174 |
+
def add_meter(self, name, meter):
|
175 |
+
self.meters[name] = meter
|
176 |
+
|
177 |
+
def log_every(self, iterable, print_freq, header=None):
|
178 |
+
"""Log every `print_freq` times."""
|
179 |
+
i = 0
|
180 |
+
if not header:
|
181 |
+
header = ""
|
182 |
+
start_time = time.time()
|
183 |
+
end = time.time()
|
184 |
+
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
185 |
+
data_time = SmoothedValue(fmt="{avg:.4f}")
|
186 |
+
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
187 |
+
log_msg = self.delimiter.join([
|
188 |
+
header,
|
189 |
+
"[{0" + space_fmt + "}/{1}]",
|
190 |
+
"eta: {eta}",
|
191 |
+
"{meters}",
|
192 |
+
"time: {time}",
|
193 |
+
"data: {data}",
|
194 |
+
"max mem: {memory:.0f}",
|
195 |
+
])
|
196 |
+
for obj in iterable:
|
197 |
+
data_time.update(time.time() - end)
|
198 |
+
yield obj
|
199 |
+
iter_time.update(time.time() - end)
|
200 |
+
if i % print_freq == 0:
|
201 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
202 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
203 |
+
print(
|
204 |
+
log_msg.format(
|
205 |
+
i,
|
206 |
+
len(iterable),
|
207 |
+
eta=eta_string,
|
208 |
+
meters=str(self),
|
209 |
+
time=str(iter_time),
|
210 |
+
data=str(data_time),
|
211 |
+
memory=torch.cuda.max_memory_allocated() / _MB,
|
212 |
+
)
|
213 |
+
)
|
214 |
+
sys.stdout.flush()
|
215 |
+
|
216 |
+
i += 1
|
217 |
+
end = time.time()
|
218 |
+
total_time = time.time() - start_time
|
219 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
220 |
+
print("{} Total time: {}".format(header, total_time_str))
|
221 |
+
|
222 |
+
|
223 |
+
def mkdir(path):
|
224 |
+
try:
|
225 |
+
os.makedirs(path)
|
226 |
+
except OSError as e:
|
227 |
+
if e.errno != errno.EEXIST:
|
228 |
+
raise
|
229 |
+
|
230 |
+
|
231 |
+
def pad_to_square(im):
|
232 |
+
"""Pad the images to square shape."""
|
233 |
+
im = deepcopy(im)
|
234 |
+
width, height = im.size
|
235 |
+
top_pad = (max(width, height) - height) // 2
|
236 |
+
bot_pad = max(width, height) - height - top_pad
|
237 |
+
left_pad = (max(width, height) - width) // 2
|
238 |
+
right_pad = max(width, height) - width - left_pad
|
239 |
+
|
240 |
+
if len(im.mode) == 3:
|
241 |
+
color = (0, 0, 0)
|
242 |
+
elif len(im.mode) == 1:
|
243 |
+
color = 0
|
244 |
+
else:
|
245 |
+
raise ValueError(f"Image mode not supported. Image has {im.mode} channels.")
|
246 |
+
|
247 |
+
return add_margin(im, top_pad, right_pad, bot_pad, left_pad, color=color)
|
248 |
+
|
249 |
+
|
250 |
+
def add_margin(pil_img, top, right, bottom, left, color=(0, 0, 0)):
|
251 |
+
"""Ref: https://note.nkmk.me/en/python-pillow-add-margin-expand-canvas/."""
|
252 |
+
width, height = pil_img.size
|
253 |
+
new_width = width + right + left
|
254 |
+
new_height = height + top + bottom
|
255 |
+
result = Image.new(pil_img.mode, (new_width, new_height), color)
|
256 |
+
result.paste(pil_img, (left, top))
|
257 |
+
|
258 |
+
# 1 represents the image, 0 represents the padding
|
259 |
+
pad = [left, top, width, height]
|
260 |
+
return result, pad
|
261 |
+
|
262 |
+
|
263 |
+
def process_sentence(sentence, ds_name):
|
264 |
+
"""Dataset specific sentence processing."""
|
265 |
+
if "refcoco" in ds_name:
|
266 |
+
sentence = sentence[0].lower()
|
267 |
+
# get rid of special characters
|
268 |
+
sentence = sentence.replace('"', "")
|
269 |
+
sentence = sentence.replace("/", "")
|
270 |
+
if ds_name == "voc":
|
271 |
+
if sentence in list(CLASS2ID.keys()):
|
272 |
+
label_id = CLASS2ID[sentence] - 1
|
273 |
+
sentence = VOC_CLASSES[label_id]
|
274 |
+
|
275 |
+
if not isinstance(sentence, str):
|
276 |
+
sentence = sentence[0]
|
277 |
+
return sentence
|
utils/visualize.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Visualization functions."""
|
17 |
+
|
18 |
+
import os
|
19 |
+
|
20 |
+
import cv2
|
21 |
+
import matplotlib.pyplot as plt
|
22 |
+
import numpy as np
|
23 |
+
from PIL import Image
|
24 |
+
import torch
|
25 |
+
# pylint: disable=g-importing-member
|
26 |
+
from utils.utils import normalize
|
27 |
+
|
28 |
+
_VIS_HEIGHT = 512
|
29 |
+
_VIS_WIDTH = 512
|
30 |
+
|
31 |
+
|
32 |
+
def show_cam_on_image(img, mask):
|
33 |
+
if img.shape[1] != mask.shape[1]:
|
34 |
+
mask = cv2.resize(mask, (img.shape[1], img.shape[0]))
|
35 |
+
heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
|
36 |
+
heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
|
37 |
+
heatmap = np.float32(heatmap) / 255
|
38 |
+
cam = heatmap + np.float32(img)
|
39 |
+
cam = cam / np.max(cam)
|
40 |
+
cam = np.uint8(255 * cam)
|
41 |
+
return cam
|
42 |
+
|
43 |
+
|
44 |
+
def save_img(array, img_name):
|
45 |
+
numpy_array = array.astype(np.uint8)
|
46 |
+
image = Image.fromarray(numpy_array, mode="RGB")
|
47 |
+
image.save(f"{img_name}.png")
|
48 |
+
|
49 |
+
|
50 |
+
def viz_attn(img, attn_map, prefix="vis_results/clipcam_img", img_name="cam"):
|
51 |
+
"""Visualize attention map."""
|
52 |
+
num_masks = 1
|
53 |
+
if len(attn_map.shape) == 3:
|
54 |
+
num_masks = attn_map.shape[0]
|
55 |
+
attn_map = attn_map.float().squeeze(1).detach().cpu().numpy()
|
56 |
+
attn_map = normalize(attn_map)
|
57 |
+
img = normalize(img)
|
58 |
+
if num_masks == 1:
|
59 |
+
vis = show_cam_on_image(img, attn_map)
|
60 |
+
if not os.path.exists(prefix):
|
61 |
+
os.makedirs(prefix)
|
62 |
+
save_img(vis, os.path.join(prefix, f"{img_name}"))
|
63 |
+
return vis
|
64 |
+
for i in range(num_masks):
|
65 |
+
vis = show_cam_on_image(img, attn_map[i])
|
66 |
+
if not os.path.exists(prefix):
|
67 |
+
os.makedirs(prefix)
|
68 |
+
save_img(vis, os.path.join(prefix, f"{img_name}_{i}"))
|
69 |
+
|
70 |
+
|
71 |
+
def vis_mask(mask, gt_mask, img, output_dir, fname):
|
72 |
+
"""Visualize mask."""
|
73 |
+
mask_img = torch.zeros((_VIS_WIDTH, _VIS_HEIGHT))
|
74 |
+
mask_img[mask[0]] = 1
|
75 |
+
|
76 |
+
# print(gt_mask.shape, img.size())
|
77 |
+
# Assume img and gt_mask are also torch.Tensor with size (512, 512)
|
78 |
+
img = img[0].permute(1, 2, 0).numpy()
|
79 |
+
gt_mask_img = torch.zeros((_VIS_WIDTH, _VIS_HEIGHT))
|
80 |
+
gt_mask_img[gt_mask[0]] = 1
|
81 |
+
|
82 |
+
_, axs = plt.subplots(
|
83 |
+
1, 3, figsize=(15, 5)
|
84 |
+
) # change the figsize if necessary
|
85 |
+
|
86 |
+
axs[0].imshow(img) # if image is grayscale, otherwise remove cmap argument
|
87 |
+
axs[0].axis("off")
|
88 |
+
axs[0].set_title("Original Image")
|
89 |
+
|
90 |
+
axs[1].imshow(
|
91 |
+
mask_img.numpy(), cmap="jet", alpha=0.5
|
92 |
+
) # using alpha for transparency
|
93 |
+
axs[1].axis("off")
|
94 |
+
axs[1].set_title("Mask")
|
95 |
+
|
96 |
+
axs[2].imshow(
|
97 |
+
gt_mask_img.numpy(), cmap="jet", alpha=0.5
|
98 |
+
) # using alpha for transparency
|
99 |
+
axs[2].axis("off")
|
100 |
+
axs[2].set_title("Ground Truth Mask")
|
101 |
+
|
102 |
+
plt.savefig(
|
103 |
+
os.path.join(output_dir, f"{fname}.jpg"),
|
104 |
+
bbox_inches="tight",
|
105 |
+
dpi=300,
|
106 |
+
pad_inches=0.0,
|
107 |
+
)
|