Upload 114 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +1 -0
- app.py +419 -0
- examples/000000027620.jpg +0 -0
- examples/000000112634.jpg +0 -0
- examples/000000165713.jpg +0 -0
- examples/000000190756.jpg +0 -0
- examples/000000226058.jpg +0 -0
- examples/000000234779.jpg +0 -0
- examples/000000243075.jpg +0 -0
- examples/000000263860.jpg +0 -0
- examples/000000284762.jpg +0 -0
- examples/000000298738.jpg +0 -0
- examples/000000372819.jpg +0 -0
- examples/000000377814.jpg +0 -0
- examples/000000516143.jpg +0 -0
- examples/000000555050.jpg +0 -0
- examples/dogs.jpg +0 -0
- examples/flowers.jpg +0 -0
- examples/fruits.jpg +0 -0
- examples/image.jpg +0 -0
- examples/truck.jpg +0 -0
- regionspot/__init__.py +10 -0
- regionspot/__pycache__/__init__.cpython-38.pyc +0 -0
- regionspot/__pycache__/automatic_mask_generator.cpython-38.pyc +0 -0
- regionspot/__pycache__/build.cpython-38.pyc +0 -0
- regionspot/__pycache__/config.cpython-38.pyc +0 -0
- regionspot/__pycache__/detector.cpython-38.pyc +0 -0
- regionspot/__pycache__/predictor.cpython-38.pyc +0 -0
- regionspot/__pycache__/test_time_augmentation.cpython-38.pyc +0 -0
- regionspot/automatic_mask_generator.py +632 -0
- regionspot/build.py +307 -0
- regionspot/config.py +39 -0
- regionspot/data/__pycache__/custom_dataset_dataloader.cpython-38.pyc +0 -0
- regionspot/data/__pycache__/dataset_mapper.cpython-38.pyc +0 -0
- regionspot/data/__pycache__/v3det_categories.cpython-38.pyc +3 -0
- regionspot/data/custom_dataset_dataloader.py +331 -0
- regionspot/data/dataset_mapper.py +140 -0
- regionspot/data/objects365.py +391 -0
- regionspot/data/openimages.py +34 -0
- regionspot/data/openimages_categories.py +1 -0
- regionspot/data/v3det.py +34 -0
- regionspot/data/v3det_categories.py +0 -0
- regionspot/detector.py +174 -0
- regionspot/modeling/__pycache__/constants.cpython-38.pyc +0 -0
- regionspot/modeling/__pycache__/decoder.cpython-38.pyc +0 -0
- regionspot/modeling/__pycache__/regionspot.cpython-38.pyc +0 -0
- regionspot/modeling/clip/__init__.py +1 -0
- regionspot/modeling/clip/__pycache__/__init__.cpython-38.pyc +0 -0
- regionspot/modeling/clip/__pycache__/clip.cpython-38.pyc +0 -0
- regionspot/modeling/clip/__pycache__/model.cpython-38.pyc +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+regionspot/data/__pycache__/v3det_categories.cpython-38.pyc filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,419 @@
import gradio as gr
import torch
from PIL import ImageDraw
from PIL import Image
import numpy as np
from torchvision.transforms import ToPILImage

import matplotlib.pyplot as plt
import cv2
from regionspot.modeling.regionspot import build_regionspot_model
from regionspot import RegionSpot_Predictor
from regionspot import SamAutomaticMaskGenerator
import ast

# Font settings for box labels drawn with matplotlib.
fdic = {
    # "family": "Impact",
    # "style": "italic",
    "size": 15,
    # "color": "yellow",
    # "weight": "bold",
}

# Overlay a single mask on the given axes.
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

# Function to show points on an image
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

# Function to show bounding boxes on an image (box in XYXY format)
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - x0, box[3] - y0
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor='none', lw=2))

# Same, but for boxes in XYWH format, with a class label.
def auto_show_box(box, label, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2], box[3]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
    ax.text(x0, y0, f"{label}", fontdict=fdic)

def show_anns(image, anns, custom_vocabulary):
    # anns is False when the generator found no region above threshold;
    # in that case just render the input image back.
    if anns is False:
        plt.imshow(image)
        ax = plt.gca()
        ax.set_autoscale_on(False)
        ax.imshow(image)
        pic = plt.gcf()
        pic.canvas.draw()
        w, h = pic.canvas.get_width_height()
        image = Image.frombytes('RGB', (w, h), pic.canvas.tostring_rgb())
        return image

    plt.imshow(image)
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:, :, 3] = 0
    for ann in sorted_anns:
        l = custom_vocabulary[int(ann['pred_class'])]
        if l != 'background':
            m = ann['segmentation']
            color_mask = np.concatenate([np.random.random(3), [0.35]])
            img[m] = color_mask
            b = ann['bbox']
            auto_show_box(b, l, ax)
    ax.imshow(img)
    ax.axis('off')
    pic = plt.gcf()
    pic.canvas.draw()
    w, h = pic.canvas.get_width_height()
    image = Image.frombytes('RGB', (w, h), pic.canvas.tostring_rgb())
    return image

def process_box(image, input_box, masks, mask_iou_score, class_score, class_index, custom_vocabulary):
    # Extract class names and display the image with masks and boxes
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    for idx in range(len(input_box)):
        show_mask(masks[idx], ax)
        show_box(input_box[idx], ax)  # Assuming box_prompt contains all your boxes
        # You might want to modify the text display for multiple classes as well
        class_name = custom_vocabulary[int(class_index[idx])]
        ax.text(input_box[idx][0], input_box[idx][1] - 10, class_name, color='white', fontsize=14, bbox=dict(facecolor='green', edgecolor='green', alpha=0.6))

    ax.axis('off')
    pic = plt.gcf()
    pic.canvas.draw()
    w, h = pic.canvas.get_width_height()
    image = Image.frombytes('RGB', (w, h), pic.canvas.tostring_rgb())
    return image

device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

# Description
title = "<center><strong><font size='8'> RegionSpot: Recognize Any Regions </font></strong></center>"

description_e = """This is a demo of the GitHub project [Recognize Any Regions](https://github.com/Surrey-UPLab/Recognize-Any-Regions). You are welcome to give it a star.
"""

description_p = """This is a demo of the GitHub project [Recognize Any Regions](https://github.com/Surrey-UPLab/Recognize-Any-Regions). You are welcome to give it a star.
"""

description_b = """This is a demo of the GitHub project [Recognize Any Regions](https://github.com/Surrey-UPLab/Recognize-Any-Regions). You are welcome to give it a star.
"""

examples = [["examples/dogs.jpg"], ["examples/fruits.jpg"], ["examples/flowers.jpg"],
            ["examples/000000190756.jpg"], ["examples/image.jpg"], ["examples/000000263860.jpg"],
            ["examples/000000298738.jpg"], ["examples/000000027620.jpg"], ["examples/000000112634.jpg"],
            ["examples/000000377814.jpg"], ["examples/000000516143.jpg"]]

default_example = examples[0]

css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"

# Box mode: the user sketches a region; its bounding box becomes the box prompt.
def segment_sementic(image, text):
    mask_threshold = 0.0
    img = image
    # np.nonzero returns (row, col) index arrays: coor[0] runs along y and
    # coor[1] along x, so the xmin/ymin names below are swapped, but the
    # assembled box still comes out in XYXY order.
    coor = np.nonzero(img["mask"])
    coor[0].sort()
    xmin = coor[0][0]
    xmax = coor[0][-1]
    coor[1].sort()
    ymin = coor[1][0]
    ymax = coor[1][-1]
    input_box = np.array([[ymin, xmin, ymax, xmax]])

    image = img["image"]
    input_image = np.asarray(image)

    ckpt_path = 'regionspot_bl_336.pth'
    clip_type = 'CLIP_400M_Large_336'
    # clip_input_size = 336
    clip_input_size = 224
    text = text.split(',')
    custom_vocabulary = text
    # Build and initialize the model
    model, msg = build_regionspot_model(is_training=False, image_size=clip_input_size, clip_type=clip_type, pretrain_ckpt=ckpt_path,
                                        custom_vocabulary=custom_vocabulary)
    # Create predictor and set image
    predictor = RegionSpot_Predictor(model.cuda())
    predictor.set_image(input_image, clip_input_size=clip_input_size)

    masks, mask_iou_score, class_score, class_index = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box,
        multimask_output=False,
        mask_threshold=mask_threshold,
    )
    fig = process_box(input_image, input_box, masks, mask_iou_score, class_score, class_index, custom_vocabulary)

    torch.cuda.empty_cache()

    return fig

# Text mode: run the automatic mask generator over the whole image and keep
# regions whose score for the given vocabulary clears the threshold.
def text_segment_sementic(image, text, conf_threshold, box_threshold, crop_n_layers, crop_nms_threshold):
    mask_threshold = 0.0
    input_image = np.asarray(image)
    text = text.split(',')

    # Prepend a 'background' class so unmatched regions can be dropped.
    textP = ['background']
    text = textP + text

    custom_vocabulary = text
    ckpt_path = 'regionspot_bl_336.pth'
    clip_type = 'CLIP_400M_Large_336'
    clip_input_size = 336
    # clip_input_size = 224
    model, msg = build_regionspot_model(is_training=False, image_size=clip_input_size, clip_type=clip_type, pretrain_ckpt=ckpt_path,
                                        custom_vocabulary=custom_vocabulary)
    mask_generator = SamAutomaticMaskGenerator(model.cuda(),
                                               # crop_thresh=iou_threshold,
                                               box_thresh=conf_threshold, mask_threshold=mask_threshold,
                                               box_nms_thresh=box_threshold, crop_n_layers=crop_n_layers, crop_nms_thresh=crop_nms_threshold)
    masks = mask_generator.generate(input_image)

    fig = show_anns(input_image, masks, custom_vocabulary)

    torch.cuda.empty_cache()

    return fig

# Points mode: segment from user-clicked foreground/background points.
def point_segment_sementic(image, text, box_threshold, crop_nms_threshold):
    global global_points
    global global_point_label
    global image_temp

    mask_threshold = 0.0
    input_image = image_temp
    output_image = np.asarray(image)
    ckpt_path = 'regionspot_bl_336.pth'
    clip_type = 'CLIP_400M_Large_336'
    clip_input_size = 336
    # clip_input_size = 224
    text = text.split(',')

    textP = ['background']
    text = textP + text

    custom_vocabulary = text
    model, msg = build_regionspot_model(is_training=False, image_size=clip_input_size, clip_type=clip_type, pretrain_ckpt=ckpt_path,
                                        custom_vocabulary=custom_vocabulary)
    mask_generator = SamAutomaticMaskGenerator(model.cuda(),
                                               crop_thresh=0.0,
                                               box_thresh=0.0,
                                               mask_threshold=mask_threshold,
                                               box_nms_thresh=box_threshold, crop_nms_thresh=crop_nms_threshold)
    masks = mask_generator.generate_point(input_image, point=np.asarray(global_points), label=np.asarray(global_point_label))

    fig = show_anns(output_image, masks, custom_vocabulary)

    torch.cuda.empty_cache()

    return fig

# Record a clicked point and draw it onto the displayed image.
def get_points_with_draw(image, label, evt: gr.SelectData):
    global global_points
    global global_point_label
    global image_temp

    if global_point_label == []:
        image_temp = np.asarray(image)

    x, y = evt.index[0], evt.index[1]
    point_radius, point_color = 15, (255, 255, 0) if label == 'Mask' else (255, 0, 255)
    global_points.append([x, y])
    global_point_label.append(1 if label == 'Mask' else 0)

    draw = ImageDraw.Draw(image)
    draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)
    return image


cond_img_p = gr.Image(label="Input with points", value="examples/dogs.jpg", type='pil')
cond_img_t = gr.Image(label="Input with text", value="examples/dogs.jpg", type='pil')
cond_img_b = gr.Image(label="Input with box", type="pil", tool='sketch')
img_p = gr.Image(label="Input with points P", type='pil')

segm_img_p = gr.Image(label="Segmented Image with points", interactive=False, type='pil')
segm_img_t = gr.Image(label="Segmented Image with text", interactive=False, type='pil')
segm_img_b = gr.Image(label="Segmented Image with box", interactive=False, type='pil')

global_points = []
global_point_label = []
image_temp = np.array([])

with gr.Blocks(css=css, title='Region Spot') as demo:
    with gr.Row():
        with gr.Column(scale=1):
            # Title
            gr.Markdown(title)

    with gr.Tab("Points mode"):
        # Images
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                cond_img_p.render()

            with gr.Column(scale=1):
                segm_img_p.render()

        # Submit & Clear
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        add_or_remove = gr.Radio(["Mask", "Background"], value="Mask", label="Point_label (foreground/background)")
                        text_box_p = gr.Textbox(label="vocabulary", value="dog,cat")
                    with gr.Column():
                        segment_btn_p = gr.Button("Segment with points prompt", variant='primary')
                        clear_btn_p = gr.Button("Clear", variant='secondary')

                gr.Markdown("Try some of the examples below")
                gr.Examples(examples=examples,
                            inputs=[cond_img_t],
                            examples_per_page=18)

            with gr.Column():
                with gr.Accordion("Advanced options", open=True):
                    box_threshold_p = gr.Slider(0.0, 0.9, 0.7, step=0.05, label='box threshold', info='box nms threshold')
                    crop_threshold_p = gr.Slider(0.0, 0.9, 0.7, step=0.05, label='crop threshold', info='crop nms threshold')
                # Description
                gr.Markdown(description_p)
        cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)
        segment_btn_p.click(point_segment_sementic,
                            inputs=[
                                cond_img_p,
                                text_box_p,
                                box_threshold_p,
                                crop_threshold_p,
                            ],
                            outputs=[segm_img_p])

    with gr.Tab("Text mode"):
        # Images
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                cond_img_t.render()

            with gr.Column(scale=1):
                segm_img_t.render()

        # Submit & Clear
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
                        text_box_t = gr.Textbox(label="text prompt", value="dog,cat")

                    with gr.Column():
                        segment_btn_t = gr.Button("Segment with text", variant='primary')
                        clear_btn_t = gr.Button("Clear", variant="secondary")

                gr.Markdown("Try some of the examples below")
                gr.Examples(examples=examples,
                            inputs=[cond_img_t],
                            examples_per_page=18)

            with gr.Column():
                with gr.Accordion("Advanced options", open=True):
                    conf_threshold_t = gr.Slider(0.0, 0.9, 0.8, step=0.05, label='clip threshold', info='object confidence threshold of clip')
                    box_threshold_t = gr.Slider(0.0, 0.9, 0.5, step=0.05, label='box threshold', info='box nms threshold')
                    crop_n_layers_t = gr.Slider(0, 3, 0, step=1, label='crop_n_layers', info='crop_n_layers of auto generator')
                    crop_threshold_t = gr.Slider(0.0, 0.9, 0.5, step=0.05, label='crop threshold', info='crop nms threshold')

                # Description
                gr.Markdown(description_e)
        segment_btn_t.click(text_segment_sementic,
                            inputs=[
                                cond_img_t,
                                text_box_t,
                                conf_threshold_t,
                                box_threshold_t,
                                crop_n_layers_t,
                                crop_threshold_t
                            ],
                            outputs=[segm_img_t])

    with gr.Tab("Box mode"):
        # Images
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                cond_img_b.render()

            with gr.Column(scale=1):
                segm_img_b.render()

        # Submit & Clear
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
                        text_box_b = gr.Textbox(label="vocabulary", value="dog,cat")
                    with gr.Column():
                        segment_btn_b = gr.Button("Segment with box", variant='primary')
                        clear_btn_b = gr.Button("Clear", variant="secondary")

                gr.Markdown("Try some of the examples below")
                gr.Examples(examples=examples,
                            inputs=[cond_img_t],
                            examples_per_page=18)

            with gr.Column():
                # Description
                gr.Markdown(description_b)

        segment_btn_b.click(segment_sementic,
                            inputs=[
                                cond_img_b,
                                text_box_b,
                            ],
                            outputs=[segm_img_b])

    def clear():
        return None, None, None

    def clear_text():
        return None, None, None

    clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p, text_box_p])
    clear_btn_t.click(clear_text, outputs=[cond_img_t, segm_img_t, text_box_t])
    clear_btn_b.click(clear_text, outputs=[cond_img_b, segm_img_b, text_box_b])

demo.queue()
demo.launch()
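The box-mode handler above converts the user's sketch into a box prompt with np.nonzero, which returns (row, column) index arrays: the first array runs along y and the second along x, so the xmin/ymin names in segment_sementic are swapped even though the assembled box comes out in the XYXY order the predictor expects. A minimal standalone sketch of that step, with toy values and no dependency on the demo:

import numpy as np

mask = np.zeros((8, 8), dtype=np.uint8)
mask[2:5, 3:7] = 1                      # a scribble covering rows 2-4, columns 3-6

rows, cols = np.nonzero(mask)           # row indices (y) first, column indices (x) second
input_box = np.array([[cols.min(), rows.min(), cols.max(), rows.max()]])
print(input_box)                        # [[3 2 6 4]]  i.e. x0, y0, x1, y1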
examples/000000027620.jpg
ADDED
examples/000000112634.jpg
ADDED
examples/000000165713.jpg
ADDED
examples/000000190756.jpg
ADDED
examples/000000226058.jpg
ADDED
examples/000000234779.jpg
ADDED
examples/000000243075.jpg
ADDED
examples/000000263860.jpg
ADDED
examples/000000284762.jpg
ADDED
examples/000000298738.jpg
ADDED
examples/000000372819.jpg
ADDED
examples/000000377814.jpg
ADDED
examples/000000516143.jpg
ADDED
examples/000000555050.jpg
ADDED
examples/dogs.jpg
ADDED
examples/flowers.jpg
ADDED
examples/fruits.jpg
ADDED
examples/image.jpg
ADDED
examples/truck.jpg
ADDED
regionspot/__init__.py
ADDED
@@ -0,0 +1,10 @@
from .config import add_regionspot_config
from .detector import RegionSpot
from .data.dataset_mapper import RegionSpotDatasetMapper
from .test_time_augmentation import RegionSpotWithTTA
from .build import *
from .data.custom_dataset_dataloader import *
from .predictor import RegionSpot_Predictor
from .automatic_mask_generator import SamAutomaticMaskGenerator
regionspot/__pycache__/__init__.cpython-38.pyc
ADDED — skipped in this list below; see note.
regionspot/__pycache__/__init__.cpython-38.pyc ADDED (binary file, 610 Bytes)
regionspot/__pycache__/automatic_mask_generator.cpython-38.pyc ADDED (binary file, 15.4 kB)
regionspot/__pycache__/build.cpython-38.pyc ADDED (binary file, 10.8 kB)
regionspot/__pycache__/config.cpython-38.pyc ADDED (binary file, 1.14 kB)
regionspot/__pycache__/detector.cpython-38.pyc ADDED (binary file, 5.17 kB)
regionspot/__pycache__/predictor.cpython-38.pyc ADDED (binary file, 11 kB)
regionspot/__pycache__/test_time_augmentation.cpython-38.pyc ADDED (binary file, 7.79 kB)
regionspot/automatic_mask_generator.py
ADDED
@@ -0,0 +1,632 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from torchvision.ops.boxes import batched_nms, box_area  # type: ignore
import json

from typing import Any, Dict, List, Optional, Tuple

from .modeling.segment_anything.utils.transforms import ResizeLongestSide
# from .modeling import Sam
# from .predictor import SamPredictor
from .predictor import RegionSpot_Predictor
from .modeling.segment_anything.utils.amg import (
    MaskData,
    area_from_rle,
    batch_iterator,
    batched_mask_to_box,
    box_xyxy_to_xywh,
    build_all_layer_point_grids,
    calculate_stability_score,
    coco_encode_rle,
    generate_crop_boxes,
    is_box_near_crop_edge,
    mask_to_rle_pytorch,
    remove_small_regions,
    rle_to_mask,
    uncrop_boxes_xyxy,
    uncrop_masks,
    uncrop_points,
)


class SamAutomaticMaskGenerator:
    def __init__(
        self,
        model,
        points_per_side: Optional[int] = 32,
        points_per_batch: int = 64,
        pred_iou_thresh: float = 0.88,
        stability_score_thresh: float = 0.95,
        stability_score_offset: float = 1.0,
        box_nms_thresh: float = 0.7,
        crop_n_layers: int = 0,
        crop_nms_thresh: float = 0.7,
        crop_overlap_ratio: float = 512 / 1500,
        crop_n_points_downscale_factor: int = 1,
        point_grids: Optional[List[np.ndarray]] = None,
        min_mask_region_area: int = 0,
        output_mode: str = "binary_mask",
        crop_thresh=0.0,
        box_thresh=0.6,
        mask_threshold=0.0,
    ) -> None:
        """
        Using a SAM model, generates masks for the entire image.
        Generates a grid of point prompts over the image, then filters
        low quality and duplicate masks. The default settings are chosen
        for SAM with a ViT-H backbone.

        Arguments:
          model (Sam): The SAM model to use for mask prediction.
          points_per_side (int or None): The number of points to be sampled
            along one side of the image. The total number of points is
            points_per_side**2. If None, 'point_grids' must provide explicit
            point sampling.
          points_per_batch (int): Sets the number of points run simultaneously
            by the model. Higher numbers may be faster but use more GPU memory.
          pred_iou_thresh (float): A filtering threshold in [0,1], using the
            model's predicted mask quality.
          stability_score_thresh (float): A filtering threshold in [0,1], using
            the stability of the mask under changes to the cutoff used to binarize
            the model's mask predictions.
          stability_score_offset (float): The amount to shift the cutoff when
            calculating the stability score.
          box_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks.
          crop_n_layers (int): If >0, mask prediction will be run again on
            crops of the image. Sets the number of layers to run, where each
            layer has 2**i_layer number of image crops.
          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks between different crops.
          crop_overlap_ratio (float): Sets the degree to which crops overlap.
            In the first crop layer, crops will overlap by this fraction of
            the image length. Later layers with more crops scale down this overlap.
          crop_n_points_downscale_factor (int): The number of points-per-side
            sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
          point_grids (list(np.ndarray) or None): A list over explicit grids
            of points used for sampling, normalized to [0,1]. The nth grid in the
            list is used in the nth crop layer. Exclusive with points_per_side.
          min_mask_region_area (int): If >0, postprocessing will be applied
            to remove disconnected regions and holes in masks with area smaller
            than min_mask_region_area. Requires opencv.
          output_mode (str): The form masks are returned in. Can be 'binary_mask',
            'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
            For large resolutions, 'binary_mask' may consume large amounts of
            memory.
        """

        assert (points_per_side is None) != (
            point_grids is None
        ), "Exactly one of points_per_side or point_grid must be provided."
        if points_per_side is not None:
            self.point_grids = build_all_layer_point_grids(
                points_per_side,
                crop_n_layers,
                crop_n_points_downscale_factor,
            )
        elif point_grids is not None:
            self.point_grids = point_grids
        else:
            raise ValueError("Can't have both points_per_side and point_grid be None.")

        assert output_mode in [
            "binary_mask",
            "uncompressed_rle",
            "coco_rle",
        ], f"Unknown output_mode {output_mode}."
        if output_mode == "coco_rle":
            from pycocotools import mask as mask_utils  # type: ignore # noqa: F401

        if min_mask_region_area > 0:
            import cv2  # type: ignore # noqa: F401

        # self.sam_clip = model
        # self.model = self.sam_clip.sam
        self.predictor = RegionSpot_Predictor(model)
        self.points_per_batch = points_per_batch
        self.pred_iou_thresh = pred_iou_thresh
        self.stability_score_thresh = stability_score_thresh
        self.stability_score_offset = stability_score_offset
        self.box_nms_thresh = box_nms_thresh
        self.crop_n_layers = crop_n_layers
        self.crop_nms_thresh = crop_nms_thresh
        self.crop_overlap_ratio = crop_overlap_ratio
        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
        self.min_mask_region_area = min_mask_region_area
        self.output_mode = output_mode
        self.crop_thresh = crop_thresh
        self.box_thresh = box_thresh
        self.mask_threshold = mask_threshold

    @torch.no_grad()
    def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """
        Generates masks for the given image.

        Arguments:
          image (np.ndarray): The image to generate masks for, in HWC uint8 format.

        Returns:
          list(dict(str, any)): A list over records for masks. Each record is
            a dict containing the following keys:
              segmentation (dict(str, any) or np.ndarray): The mask. If
                output_mode='binary_mask', is an array of shape HW. Otherwise,
                is a dictionary containing the RLE.
              bbox (list(float)): The box around the mask, in XYWH format.
              area (int): The area in pixels of the mask.
              predicted_iou (float): The model's own prediction of the mask's
                quality. This is filtered by the pred_iou_thresh parameter.
              point_coords (list(list(float))): The point coordinates input
                to the model to generate this mask.
              stability_score (float): A measure of the mask's quality. This
                is filtered on using the stability_score_thresh parameter.
              crop_box (list(float)): The crop of the image used to generate
                the mask, given in XYWH format.
        """

        # Generate masks
        mask_data = self._generate_masks(image)
        # Filter small disconnected regions and holes in masks
        if self.min_mask_region_area > 0:
            mask_data = self.postprocess_small_regions(
                mask_data,
                self.min_mask_region_area,
                max(self.box_nms_thresh, self.crop_nms_thresh),
            )

        # transform = ResizeLongestSide(self.model.image_encoder.img_size)
        self.predictor.set_image(image, clip_input_size=336)
        total_data = MaskData()
        total_data["pred_class"] = []

        maxvalue_box = 0

        for box in mask_data["boxes"]:
            box = self.predictor.transform.apply_boxes(box, self.predictor.original_size)
            box_torch = torch.as_tensor(box, dtype=torch.float, device=self.predictor.device)
            box_torch = box_torch[None, :]
            masks, iou_preds, _, max_values, max_index = self.predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=box_torch,
                mask_input=None,
                multimask_output=False,
                mask_threshold=self.mask_threshold,
            )
            bmax_values = max_values.detach().cpu().numpy()
            bmax_index = max_index.detach().cpu().numpy()

            pred_class = []
            for i in range(bmax_index.shape[0]):
                if bmax_values[i] > self.box_thresh:
                    pred_class.append(bmax_index[i])
                else:
                    pred_class.append(-1)
            # Serialize predictions and store in MaskData
            data = MaskData(
                masks=masks.flatten(0, 1),
                iou_preds=iou_preds.flatten(0, 1),
                pred_class=pred_class,
            )
            data["boxes"] = batched_mask_to_box(data["masks"])
            data["rles"] = mask_to_rle_pytorch(data["masks"])
            del data["masks"]
            total_data.cat(data)
        if total_data["pred_class"] == []:
            return False
        if total_data["pred_class"] != []:
            keep_mask = []
            for i in total_data["pred_class"]:
                if i != -1:
                    keep_mask.append(True)
                else:
                    keep_mask.append(False)

            keep_mask = torch.tensor(keep_mask)
            total_data.filter(keep_mask)
        mask_data = total_data

        # Encode masks
        if self.output_mode == "coco_rle":
            mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
        elif self.output_mode == "binary_mask":
            mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
        else:
            mask_data["segmentations"] = mask_data["rles"]

        # Write mask records
        curr_anns = []

        for idx in range(len(mask_data["segmentations"])):
            ann = {
                "segmentation": mask_data["segmentations"][idx],
                "area": area_from_rle(mask_data["rles"][idx]),
                "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
                "predicted_iou": mask_data["iou_preds"][idx].item(),
                "pred_class": mask_data["pred_class"][idx],
            }
            curr_anns.append(ann)

        return curr_anns

    ###################
    @torch.no_grad()
    def generate_point(self, image: np.ndarray, point: np.ndarray, label: np.ndarray) -> List[Dict[str, Any]]:
        """
        Generates masks for the given image.

        Arguments:
          image (np.ndarray): The image to generate masks for, in HWC uint8 format.

        Returns:
          list(dict(str, any)): A list over records for masks. Each record is
            a dict containing the following keys:
              segmentation (dict(str, any) or np.ndarray): The mask. If
                output_mode='binary_mask', is an array of shape HW. Otherwise,
                is a dictionary containing the RLE.
              bbox (list(float)): The box around the mask, in XYWH format.
              area (int): The area in pixels of the mask.
              predicted_iou (float): The model's own prediction of the mask's
                quality. This is filtered by the pred_iou_thresh parameter.
              point_coords (list(list(float))): The point coordinates input
                to the model to generate this mask.
              stability_score (float): A measure of the mask's quality. This
                is filtered on using the stability_score_thresh parameter.
              crop_box (list(float)): The crop of the image used to generate
                the mask, given in XYWH format.
        """

        # Generate masks
        mask_data = self._generate_masks_point(image, point, label)
        # Filter small disconnected regions and holes in masks
        if self.min_mask_region_area > 0:
            mask_data = self.postprocess_small_regions(
                mask_data,
                self.min_mask_region_area,
                max(self.box_nms_thresh, self.crop_nms_thresh),
            )
        # transform = ResizeLongestSide(self.model.image_encoder.img_size)
        self.predictor.set_image(image, clip_input_size=336)
        total_data = MaskData()
        total_data["pred_class"] = []

        maxvalue_box = 0

        for box in mask_data["boxes"]:
            box = self.predictor.transform.apply_boxes(box, self.predictor.original_size)
            box_torch = torch.as_tensor(box, dtype=torch.float, device=self.predictor.device)
            box_torch = box_torch[None, :]
            masks, iou_preds, _, max_values, max_index = self.predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=box_torch,
                mask_input=None,
                multimask_output=False,
                mask_threshold=self.mask_threshold,
            )
            bmax_values = max_values.detach().cpu().numpy()
            bmax_index = max_index.detach().cpu().numpy()

            pred_class = []
            maxV = []
            for i in range(bmax_index.shape[0]):
                if bmax_values[i] > self.box_thresh:
                    pred_class.append(bmax_index[i])
                else:
                    pred_class.append(-1)
            for i in range(bmax_index.shape[0]):
                maxV.append(bmax_values[i])
            # Serialize predictions and store in MaskData
            data = MaskData(
                masks=masks.flatten(0, 1),
                iou_preds=iou_preds.flatten(0, 1),
                pred_class=pred_class,
                maxValue=maxV,
            )
            data["boxes"] = batched_mask_to_box(data["masks"])
            data["rles"] = mask_to_rle_pytorch(data["masks"])
            del data["masks"]
            total_data.cat(data)

        if total_data["pred_class"] == []:
            return False
        if total_data["maxValue"] != []:
            # Keep only the single highest-scoring candidate region.
            keep_mask = []
            for i in total_data["maxValue"]:
                if i != max(total_data["maxValue"]):
                    keep_mask.append(False)
                else:
                    keep_mask.append(True)
            keep_mask = torch.tensor(keep_mask)
            total_data.filter(keep_mask)
        mask_data = total_data

        # Encode masks
        if self.output_mode == "coco_rle":
            mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
        elif self.output_mode == "binary_mask":
            mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
        else:
            mask_data["segmentations"] = mask_data["rles"]

        # Write mask records
        curr_anns = []

        for idx in range(len(mask_data["segmentations"])):
            ann = {
                "segmentation": mask_data["segmentations"][idx],
                "area": area_from_rle(mask_data["rles"][idx]),
                "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
                "predicted_iou": mask_data["iou_preds"][idx].item(),
                "pred_class": mask_data["pred_class"][idx],
            }
            curr_anns.append(ann)

        return curr_anns


    def _generate_masks_point(self, image: np.ndarray, point: np.ndarray, label: np.ndarray) -> MaskData:
        self.predictor.set_image(image, clip_input_size=336)
        orig_size = image.shape[:2]

        point = np.array(point)
        transformed_points = self.predictor.transform.apply_coords(point, orig_size)
        in_points = torch.as_tensor(transformed_points, device=self.predictor.device).reshape(-1, 2)
        in_labels = torch.as_tensor(label, device=self.predictor.device)
        # self.predictor.set_image(image, clip_input_size=336)
        masks, iou_preds, _, max_values, max_index = self.predictor.predict_torch(
            in_points[:, None, :],
            in_labels[:, None],
            multimask_output=True,
            return_logits=True,
            mask_threshold=self.mask_threshold,
        )

        max_values = max_values.detach().cpu().numpy()
        max_index = max_index.detach().cpu().numpy()
        pred_class = []
        for i in range(max_index.shape[0]):
            if max_values[i] > self.crop_thresh:
                pred_class.append(max_index[i])
            else:
                pred_class.append(-1)

        # Serialize predictions and store in MaskData
        data = MaskData(
            masks=masks.flatten(0, 1),
            iou_preds=iou_preds.flatten(0, 1),
            pred_class=pred_class,
            # points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
        )
        del masks
        # Filter by predicted IoU
        # if self.pred_iou_thresh > 0.0:
        #     keep_mask = data["iou_preds"] > self.pred_iou_thresh
        #     data.filter(keep_mask)
        if data["pred_class"] != []:
            keep_mask = []
            for i in data["pred_class"]:
                if i != -1:
                    keep_mask.append(True)
                else:
                    keep_mask.append(False)

            keep_mask = torch.tensor(keep_mask)
            data.filter(keep_mask)

        # Threshold masks and calculate boxes
        data["masks"] = data["masks"] > self.predictor.model.mask_threshold
        data["boxes"] = batched_mask_to_box(data["masks"])
        data["rles"] = mask_to_rle_pytorch(data["masks"])

        del data["masks"]

        keep_by_nms = batched_nms(
            data["boxes"].float(),
            data["iou_preds"],
            torch.zeros_like(data["boxes"][:, 0]),  # categories
            iou_threshold=self.box_nms_thresh,
        )
        data.filter(keep_by_nms)

        data.to_numpy()
        return data

    def _generate_masks(self, image: np.ndarray) -> MaskData:
        orig_size = image.shape[:2]
        crop_boxes, layer_idxs = generate_crop_boxes(
            orig_size, self.crop_n_layers, self.crop_overlap_ratio
        )
        # Iterate over image crops
        data = MaskData()
        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
            data.cat(crop_data)

        # Remove duplicate masks between crops
        if len(crop_boxes) > 1:
            # Prefer masks from smaller crops
            scores = 1 / box_area(data["crop_boxes"])
            scores = scores.to(data["boxes"].device)
            keep_by_nms = batched_nms(
                data["boxes"].float(),
                scores,
                torch.zeros_like(data["boxes"][:, 0]),  # categories
                iou_threshold=self.crop_nms_thresh,
            )
            data.filter(keep_by_nms)
        data.to_numpy()
        return data

    def _process_crop(
        self,
        image: np.ndarray,
        crop_box: List[int],
        crop_layer_idx: int,
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        # Crop the image and calculate embeddings
        x0, y0, x1, y1 = crop_box
        cropped_im = image[y0:y1, x0:x1, :]
        cropped_im_size = cropped_im.shape[:2]
        self.predictor.set_image(cropped_im, clip_input_size=336)

        # Get points for this crop
        points_scale = np.array(cropped_im_size)[None, ::-1]
        points_for_image = self.point_grids[crop_layer_idx] * points_scale

        # Generate masks for this crop in batches
        data = MaskData()
        for (points,) in batch_iterator(self.points_per_batch, points_for_image):
            batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
            data.cat(batch_data)
            del batch_data
        self.predictor.reset_image()

        # Remove duplicates within this crop.
        keep_by_nms = batched_nms(
            data["boxes"].float(),
            data["iou_preds"],
            torch.zeros_like(data["boxes"][:, 0]),  # categories
            iou_threshold=self.box_nms_thresh,
        )
        data.filter(keep_by_nms)
        # Return to the original image frame
        data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
        data["points"] = uncrop_points(data["points"], crop_box)
        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])

        return data

    def _process_batch(
        self,
        points: np.ndarray,
        im_size: Tuple[int, ...],
        crop_box: List[int],
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        orig_h, orig_w = orig_size

        # Run model on this batch
        transformed_points = self.predictor.transform.apply_coords(points, im_size)
        in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
        # import ipdb; ipdb.set_trace()
        masks, iou_preds, _, max_values, max_index = self.predictor.predict_torch(
            in_points[:, None, :],
            in_labels[:, None],
            multimask_output=True,
            return_logits=True,
            mask_threshold=self.mask_threshold,
        )
        max_values = max_values.detach().cpu().numpy()
        max_index = max_index.detach().cpu().numpy()
        pred_class = []
        for i in range(max_index.shape[0]):
            if max_values[i] > self.crop_thresh:
                pred_class.append(max_index[i])
            else:
                pred_class.append(-1)

        # Serialize predictions and store in MaskData
        data = MaskData(
            masks=masks.flatten(0, 1),
            iou_preds=iou_preds.flatten(0, 1),
            pred_class=pred_class,
            points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
        )
        del masks

        if data["pred_class"] != []:
            keep_mask = []
            for i in data["pred_class"]:
                if i != -1:
                    keep_mask.append(True)
                else:
                    keep_mask.append(False)

            keep_mask = torch.tensor(keep_mask)
            data.filter(keep_mask)

        # Threshold masks and calculate boxes
        data["masks"] = data["masks"] > self.predictor.model.mask_threshold
        data["boxes"] = batched_mask_to_box(data["masks"])

        # Filter boxes that touch crop boundaries
        keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
        if not torch.all(keep_mask):
            data.filter(keep_mask)

        # Compress to RLE
        data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
        data["rles"] = mask_to_rle_pytorch(data["masks"])
        del data["masks"]

        return data

    @staticmethod
    def postprocess_small_regions(
        mask_data: MaskData, min_area: int, nms_thresh: float
    ) -> MaskData:
        """
        Removes small disconnected regions and holes in masks, then reruns
        box NMS to remove any new duplicates.

        Edits mask_data in place.

        Requires open-cv as a dependency.
        """
        if len(mask_data["rles"]) == 0:
            return mask_data

        # Filter small disconnected regions and holes
        new_masks = []
        scores = []
        for rle in mask_data["rles"]:
            mask = rle_to_mask(rle)

            mask, changed = remove_small_regions(mask, min_area, mode="holes")
            unchanged = not changed
            mask, changed = remove_small_regions(mask, min_area, mode="islands")
            unchanged = unchanged and not changed

            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
            # Give score=0 to changed masks and score=1 to unchanged masks
            # so NMS will prefer ones that didn't need postprocessing
            scores.append(float(unchanged))

        # Recalculate boxes and remove any new duplicates
        masks = torch.cat(new_masks, dim=0)
        boxes = batched_mask_to_box(masks)
        keep_by_nms = batched_nms(
            boxes.float(),
            torch.as_tensor(scores),
            torch.zeros_like(boxes[:, 0]),  # categories
            iou_threshold=nms_thresh,
        )

        # Only recalculate RLEs for masks that have changed
        for i_mask in keep_by_nms:
            if scores[i_mask] == 0.0:
                mask_torch = masks[i_mask].unsqueeze(0)
                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
        mask_data.filter(keep_by_nms)

        return mask_data
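Taken together with app.py, the generator's intended call pattern is short. The following is a hedged usage sketch rather than documented API usage; it assumes the checkpoint regionspot_bl_336.pth and the example image are available locally (both paths are the ones app.py uses), and it mirrors the demo's vocabulary convention of putting 'background' at index 0:

import numpy as np
from PIL import Image
from regionspot.modeling.regionspot import build_regionspot_model
from regionspot import SamAutomaticMaskGenerator

# Build the model with a custom open vocabulary (index 0 is 'background'
# so unmatched regions can be dropped, as in app.py's text mode).
model, _ = build_regionspot_model(is_training=False, image_size=336,
                                  clip_type='CLIP_400M_Large_336',
                                  pretrain_ckpt='regionspot_bl_336.pth',
                                  custom_vocabulary=['background', 'dog', 'cat'])
mask_generator = SamAutomaticMaskGenerator(model.cuda(), box_thresh=0.8)
anns = mask_generator.generate(np.asarray(Image.open('examples/dogs.jpg')))
# generate() returns False when no region clears box_thresh; otherwise each
# record has 'segmentation', 'bbox' (XYWH), 'area', 'predicted_iou', and
# 'pred_class' (an index into the custom vocabulary).
for ann in (anns or []):
    print(ann['bbox'], ann['pred_class'])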
regionspot/build.py
ADDED
@@ -0,0 +1,307 @@
import itertools
import logging
import torch.utils.data

from detectron2.config import CfgNode, configurable
from detectron2.data.build import (
    build_batch_data_loader,
    load_proposals_into_dataset,
    trivial_batch_collator,
)
from detectron2.data.catalog import DatasetCatalog
from detectron2.data.common import DatasetFromList, MapDataset
from detectron2.data.dataset_mapper import DatasetMapper
from detectron2.data.samplers import InferenceSampler, TrainingSampler
from detectron2.utils.comm import get_world_size

from torch.utils.data.sampler import Sampler
from collections import defaultdict
from typing import Optional
from detectron2.utils import comm


def _compute_num_images_per_worker(cfg: CfgNode):
    num_workers = get_world_size()
    images_per_batch = cfg.SOLVER.IMS_PER_BATCH
    assert (
        images_per_batch % num_workers == 0
    ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of workers ({}).".format(
        images_per_batch, num_workers
    )
    assert (
        images_per_batch >= num_workers
    ), "SOLVER.IMS_PER_BATCH ({}) must be larger than the number of workers ({}).".format(
        images_per_batch, num_workers
    )
    images_per_worker = images_per_batch // num_workers
    return images_per_worker


def filter_images_with_only_crowd_annotations(dataset_dicts, dataset_names):
    """
    Filter out images with none annotations or only crowd annotations
    (i.e., images without non-crowd annotations).
    A common training-time preprocessing on COCO dataset.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.

    Returns:
        list[dict]: the same format, but filtered.
    """
    num_before = len(dataset_dicts)

    def valid(anns):
        for ann in anns:
            if isinstance(ann, list):
                for instance in ann:
                    if instance.get("iscrowd", 0) == 0:
                        return True
            else:
                if ann.get("iscrowd", 0) == 0:
                    return True
        return False

    dataset_dicts = [x for x in dataset_dicts if valid(x["annotations"])]
    num_after = len(dataset_dicts)
    logger = logging.getLogger(__name__)
    logger.info(
        "Removed {} images with no usable annotations. {} images left.".format(
            num_before - num_after, num_after
        )
    )
    return dataset_dicts


def get_detection_dataset_dicts(
    dataset_names, filter_empty=True, proposal_files=None
):
    """
    Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.

    Args:
        dataset_names (str or list[str]): a dataset name or a list of dataset names
        filter_empty (bool): whether to filter out images without instance annotations
        proposal_files (list[str]): if given, a list of object proposal files
            that match each dataset in `dataset_names`.

    Returns:
        list[dict]: a list of dicts following the standard dataset dict format.
    """
    if isinstance(dataset_names, str):
        dataset_names = [dataset_names]
    assert len(dataset_names)
    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
    for dataset_name, dicts in zip(dataset_names, dataset_dicts):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)

    if proposal_files is not None:
        assert len(dataset_names) == len(proposal_files)
        # load precomputed proposals from proposal files
        dataset_dicts = [
            load_proposals_into_dataset(dataset_i_dicts, proposal_file)
            for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
        ]

    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))

    has_instances = "annotations" in dataset_dicts[0]
    if filter_empty and has_instances:
        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts, dataset_names)

    assert len(dataset_dicts), "No valid data found in {}.".format(",".join(dataset_names))
    return dataset_dicts


def _train_loader_from_config(cfg, mapper, *, dataset=None, sampler=None):
    if dataset is None:
        dataset = get_detection_dataset_dicts(
            cfg.DATASETS.TRAIN,
            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
        )

    if mapper is None:
        mapper = DatasetMapper(cfg, True)

    if sampler is None:
        sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
        logger = logging.getLogger(__name__)
        logger.info("Using training sampler {}".format(sampler_name))
        if sampler_name == "TrainingSampler":
            sampler = TrainingSampler(len(dataset))
        elif sampler_name == "ClassAwareSampler":
            sampler = ClassAwareSampler(dataset)

    return {
        "dataset": dataset,
        "sampler": sampler,
        "mapper": mapper,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
        "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
        "use_mixup": True
    }


# TODO can allow dataset as an iterable or IterableDataset to make this function more general
@configurable(from_config=_train_loader_from_config)
def build_detection_train_loader(
    dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0,
    use_mixup=False
):
    """
    Build a dataloader for object detection with some default features.
    This interface is experimental.

    Args:
        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
            or a map-style pytorch dataset. They can be obtained by using
            :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
        mapper (callable): a callable which takes a sample (dict) from dataset and
            returns the format to be consumed by the model.
            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that
            produces indices to be applied on ``dataset``.
            Default to :class:`TrainingSampler`, which coordinates a random shuffle
            sequence across all workers.
        total_batch_size (int): total batch size across all workers. Batching
            simply puts data into a list.
        aspect_ratio_grouping (bool): whether to group images with similar
            aspect ratio for efficiency. When enabled, it requires each
            element in dataset be a dict with keys "width" and "height".
        num_workers (int): number of parallel data loading workers

    Returns:
        torch.utils.data.DataLoader: a dataloader. Each output from it is a
            ``list[mapped_element]`` of length ``total_batch_size / num_workers``,
            where ``mapped_element`` is produced by the ``mapper``.
|
179 |
+
"""
|
180 |
+
if isinstance(dataset, list):
|
181 |
+
dataset = DatasetFromList(dataset, copy=False)
|
182 |
+
if mapper is not None:
|
183 |
+
if use_mixup:
|
184 |
+
dataset = MapDatasetMixup(dataset, mapper)
|
185 |
+
else:
|
186 |
+
dataset = MapDataset(dataset, mapper)
|
187 |
+
if sampler is None:
|
188 |
+
sampler = TrainingSampler(len(dataset))
|
189 |
+
assert isinstance(sampler, torch.utils.data.sampler.Sampler)
|
190 |
+
return build_batch_data_loader(
|
191 |
+
dataset,
|
192 |
+
sampler,
|
193 |
+
total_batch_size,
|
194 |
+
aspect_ratio_grouping=aspect_ratio_grouping,
|
195 |
+
num_workers=num_workers,
|
196 |
+
)
|
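A minimal usage sketch of the loader above (not part of the upload): it builds a train loader directly from a list of dataset dicts with an identity mapper. `toy_dicts` and the parameter values are illustrative stand-ins, and detectron2 is assumed to be installed.

# Minimal sketch, assuming detectron2 is installed; `toy_dicts` is a
# hypothetical stand-in for real annotation dicts.
toy_dicts = [
    {"file_name": "img_%d.jpg" % i, "width": 640, "height": 480, "annotations": []}
    for i in range(8)
]
loader = build_detection_train_loader(
    toy_dicts,
    mapper=lambda d: d,          # identity mapper, for illustration only
    total_batch_size=4,
    aspect_ratio_grouping=True,  # needs "width"/"height" keys, present above
    num_workers=0,
)
batch = next(iter(loader))       # a list of 4 mapped dicts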

def _test_loader_from_config(cfg, dataset_name, mapper=None):
    """
    Uses the given `dataset_name` argument (instead of the names in cfg), because the
    standard practice is to evaluate each test set individually (not combining them).
    """
    dataset = get_detection_dataset_dicts(
        [dataset_name],
        filter_empty=False,
        proposal_files=[
            cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(dataset_name)]
        ]
        if cfg.MODEL.LOAD_PROPOSALS
        else None,
    )
    if mapper is None:
        mapper = DatasetMapper(cfg, False)
    return {"dataset": dataset, "mapper": mapper, "num_workers": cfg.DATALOADER.NUM_WORKERS}


@configurable(from_config=_test_loader_from_config)
def build_detection_test_loader(dataset, *, mapper, num_workers=0):
    """
    Similar to `build_detection_train_loader`, but uses a batch size of 1.
    This interface is experimental.

    Args:
        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
            or a map-style pytorch dataset. They can be obtained by using
            :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
        mapper (callable): a callable which takes a sample (dict) from dataset
            and returns the format to be consumed by the model.
            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
        num_workers (int): number of parallel data loading workers

    Returns:
        DataLoader: a torch DataLoader, that loads the given detection
        dataset, with test-time transformation and batching.

    Examples:
    ::
        data_loader = build_detection_test_loader(
            DatasetRegistry.get("my_test"),
            mapper=DatasetMapper(...))

        # or, instantiate with a CfgNode:
        data_loader = build_detection_test_loader(cfg, "my_test")
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)
    sampler = InferenceSampler(len(dataset))
    # Always use 1 image per worker during inference since this is the
    # standard when reporting inference time in papers.
    # batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, 1, drop_last=False)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        sampler=sampler,
        drop_last=False,
        num_workers=num_workers,
        collate_fn=trivial_batch_collator,
    )
    return data_loader


class ClassAwareSampler(Sampler):
    def __init__(self, dataset_dicts, seed: Optional[int] = None):
        """
        Infinite sampler that weights each image by the sum of inverse
        frequencies of the distinct categories it contains.
        """
        self._size = len(dataset_dicts)
        assert self._size > 0
        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()
        self.weights = self._get_class_balance_factor(dataset_dicts)

    def __iter__(self):
        start = self._rank
        yield from itertools.islice(
            self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            ids = torch.multinomial(
                self.weights, self._size, generator=g,
                replacement=True)
            yield from ids

    def _get_class_balance_factor(self, dataset_dicts, l=1.):
        ret = []
        category_freq = defaultdict(int)
        for dataset_dict in dataset_dicts:  # For each image (without repeats)
            cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
            for cat_id in cat_ids:
                category_freq[cat_id] += 1
        for i, dataset_dict in enumerate(dataset_dicts):
            cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
            ret.append(sum(
                [1. / (category_freq[cat_id] ** l) for cat_id in cat_ids]))
        return torch.tensor(ret).float()
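To make the weighting rule in `_get_class_balance_factor` concrete, here is a small numeric sketch (not part of the upload) on two toy images; the numbers follow directly from the code above with l = 1.

import torch
from collections import defaultdict

dicts = [
    {"annotations": [{"category_id": 0}, {"category_id": 0}]},  # common class only
    {"annotations": [{"category_id": 0}, {"category_id": 1}]},  # common + rare class
]
freq = defaultdict(int)
for d in dicts:
    for c in {a["category_id"] for a in d["annotations"]}:
        freq[c] += 1  # freq: {0: 2, 1: 1}
weights = torch.tensor([
    sum(1.0 / freq[c] for c in {a["category_id"] for a in d["annotations"]})
    for d in dicts
])
# weights == tensor([0.5, 1.5]): the image holding the rare class is
# drawn 3x as often by torch.multinomial.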
regionspot/config.py
ADDED
@@ -0,0 +1,39 @@
from detectron2.config import CfgNode as CN


def add_regionspot_config(cfg):
    """
    Add config for RegionSpot
    """
    cfg.MODEL.RegionSpot = CN()
    cfg.MODEL.CLIP_TYPE = 'CLIP_400M_Large'
    cfg.MODEL.CLIP_INPUT_SIZE = 224
    # Inference
    cfg.MODEL.TRAINING = True
    cfg.MODEL.BOX_TYPE = 'GT'

    # Dataloader
    cfg.DATALOADER.DATASET_RATIO = [1, 1, 1]  # sample ratio
    cfg.DATALOADER.USE_RFS = [False, False, False]
    cfg.DATALOADER.MULTI_DATASET_GROUPING = True  # Always true when multi-dataset is enabled
    cfg.DATALOADER.DATASET_ANN = ['box', 'box', 'box']  # Annotation type of each dataset
    cfg.DATALOADER.USE_DIFF_BS_SIZE = False  # Use a different batch size for each dataset
    cfg.DATALOADER.DATASET_BS = [8, 32]  # Used when USE_DIFF_BS_SIZE is on

    # Optimizer.
    cfg.SOLVER.OPTIMIZER = "ADAMW"
    cfg.SOLVER.BACKBONE_MULTIPLIER = 1.0

    # TTA.
    cfg.TEST.AUG.MIN_SIZES = (400, 500, 600, 640, 700, 900, 1000, 1100, 1200, 1300, 1400, 1800, 800)
    cfg.TEST.AUG.CVPODS_TTA = True
    cfg.TEST.AUG.SCALE_FILTER = True
    cfg.TEST.AUG.SCALE_RANGES = ([96, 10000], [96, 10000],
                                 [64, 10000], [64, 10000],
                                 [64, 10000], [0, 10000],
                                 [0, 10000], [0, 256],
                                 [0, 256], [0, 192],
                                 [0, 192], [0, 96],
                                 [0, 10000])
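A usage sketch (not part of the upload) showing how these keys are meant to be attached to a stock detectron2 config; the import path is assumed from this repository's layout.

from detectron2.config import get_cfg
from regionspot.config import add_regionspot_config  # path assumed from this repo

cfg = get_cfg()                      # stock detectron2 defaults
add_regionspot_config(cfg)           # adds the RegionSpot keys defined above
print(cfg.MODEL.CLIP_TYPE)           # 'CLIP_400M_Large'
print(cfg.DATALOADER.DATASET_RATIO)  # [1, 1, 1]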
regionspot/data/__pycache__/custom_dataset_dataloader.cpython-38.pyc
ADDED
Binary file (10.3 kB).
regionspot/data/__pycache__/dataset_mapper.cpython-38.pyc
ADDED
Binary file (3.87 kB).
regionspot/data/__pycache__/v3det_categories.cpython-38.pyc
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3ba3ef77b19c8f288db6ed9ac232384853c7d27413c35d662423586824480519
size 1552733
regionspot/data/custom_dataset_dataloader.py
ADDED
@@ -0,0 +1,331 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Part of the code is from https://github.com/xingyizhou/UniDet/blob/master/projects/UniDet/unidet/data/multi_dataset_dataloader.py (Apache-2.0 License)
import copy
import logging
import numpy as np
import operator
import torch
import torch.utils.data
import json
from detectron2.utils.comm import get_world_size
from detectron2.utils.logger import _log_api_usage, log_first_n

from detectron2.config import configurable
from detectron2.data import samplers
from torch.utils.data.sampler import BatchSampler, Sampler
from detectron2.data.common import DatasetFromList, MapDataset
from detectron2.data.dataset_mapper import DatasetMapper
from detectron2.data.build import get_detection_dataset_dicts, build_batch_data_loader
from detectron2.data.samplers import TrainingSampler, RepeatFactorTrainingSampler
from detectron2.data.build import worker_init_reset_seed, print_instances_class_histogram
from detectron2.data.build import filter_images_with_only_crowd_annotations
from detectron2.data.build import filter_images_with_few_keypoints
from detectron2.data.build import check_metadata_consistency
from detectron2.data.catalog import MetadataCatalog, DatasetCatalog
from detectron2.utils import comm
import itertools
import math
from collections import defaultdict
from typing import Optional


def _custom_train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
    sampler_name = cfg.DATALOADER.SAMPLER_TRAIN  # "MultiDatasetSampler"
    if 'MultiDataset' in sampler_name:  # True
        dataset_dicts = get_detection_dataset_dicts_with_source(
            cfg.DATASETS.TRAIN,
            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
            min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
            if cfg.MODEL.KEYPOINT_ON else 0,
            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
        )
    else:  # False
        dataset_dicts = get_detection_dataset_dicts(
            cfg.DATASETS.TRAIN,
            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
            min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
            if cfg.MODEL.KEYPOINT_ON else 0,
            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
        )

    if mapper is None:  # False
        mapper = DatasetMapper(cfg, True)

    if sampler is not None:
        pass
    elif sampler_name == "TrainingSampler":  # False
        sampler = TrainingSampler(len(dataset))
    elif sampler_name == "MultiDatasetSampler":  # True
        sampler = MultiDatasetSampler(
            dataset_dicts,
            dataset_ratio=cfg.DATALOADER.DATASET_RATIO,
            use_rfs=cfg.DATALOADER.USE_RFS,
            dataset_ann=cfg.DATALOADER.DATASET_ANN,
            repeat_threshold=cfg.DATALOADER.REPEAT_THRESHOLD,
        )
    elif sampler_name == "RepeatFactorTrainingSampler":  # False
        repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
            dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD
        )
        sampler = RepeatFactorTrainingSampler(repeat_factors)
    else:
        raise ValueError("Unknown training sampler: {}".format(sampler_name))

    return {
        "dataset": dataset_dicts,
        "sampler": sampler,
        "mapper": mapper,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,  # 64
        "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,  # 8
        'multi_dataset_grouping': cfg.DATALOADER.MULTI_DATASET_GROUPING,  # True
        'use_diff_bs_size': cfg.DATALOADER.USE_DIFF_BS_SIZE,  # True
        'dataset_bs': cfg.DATALOADER.DATASET_BS,  # [8, 32]
        'num_datasets': len(cfg.DATASETS.TRAIN)  # 2
    }


@configurable(from_config=_custom_train_loader_from_config)
def build_custom_train_loader(
    dataset, *, mapper, sampler,
    total_batch_size=16,  # 64
    aspect_ratio_grouping=True,
    num_workers=0,  # 8
    num_datasets=1,  # 2
    multi_dataset_grouping=False,  # True
    use_diff_bs_size=False,  # True
    dataset_bs=[]  # [8, 32]
):
    """
    Modified from detectron2.data.build.build_detection_train_loader, but supports
    different samplers.
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:  # True
        dataset = MapDataset(dataset, mapper)
    if sampler is None:  # False
        sampler = TrainingSampler(len(dataset))
    assert isinstance(sampler, torch.utils.data.sampler.Sampler)
    if multi_dataset_grouping:  # True
        return build_multi_dataset_batch_data_loader(
            use_diff_bs_size,
            dataset_bs,
            dataset,
            sampler,
            total_batch_size,
            num_datasets=num_datasets,
            num_workers=num_workers,
        )
    else:  # False
        return build_batch_data_loader(
            dataset,
            sampler,
            total_batch_size,
            aspect_ratio_grouping=aspect_ratio_grouping,
            num_workers=num_workers,
        )


def build_multi_dataset_batch_data_loader(
    use_diff_bs_size, dataset_bs,
    dataset, sampler, total_batch_size, num_datasets, num_workers=0
):
    """
    Build a loader that yields per-dataset, aspect-ratio-grouped batches.
    """
    world_size = get_world_size()
    assert (
        total_batch_size > 0 and total_batch_size % world_size == 0
    ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
        total_batch_size, world_size
    )

    batch_size = total_batch_size // world_size
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        num_workers=num_workers,
        batch_sampler=None,
        collate_fn=operator.itemgetter(0),  # don't batch, but yield individual elements
        worker_init_fn=worker_init_reset_seed,
    )  # yields individual mapped dicts
    if use_diff_bs_size:
        return DIFFMDAspectRatioGroupedDataset(
            data_loader, dataset_bs, num_datasets)
    else:
        return MDAspectRatioGroupedDataset(
            data_loader, batch_size, num_datasets)
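The wrapper returned above groups the stream of mapped dicts so that every batch is homogeneous in both origin dataset and aspect-ratio orientation; a toy re-implementation of that bucketing logic (not part of the upload) looks like this.

from collections import defaultdict

def group_batches(stream, batch_size):
    # key: (which dataset the dict came from, landscape vs. portrait)
    buckets = defaultdict(list)
    for d in stream:
        key = (d["dataset_source"], d["width"] > d["height"])
        buckets[key].append(d)
        if len(buckets[key]) == batch_size:
            yield buckets[key][:]  # a batch from one dataset, one orientation
            buckets[key].clear()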

def get_detection_dataset_dicts_with_source(
    dataset_names, filter_empty=True, min_keypoints=0, proposal_files=None
):
    assert len(dataset_names)
    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
    for dataset_name, dicts in zip(dataset_names, dataset_dicts):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)

    for source_id, (dataset_name, dicts) in \
            enumerate(zip(dataset_names, dataset_dicts)):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
        for d in dicts:
            d['dataset_source'] = source_id  # add "dataset_source" to the original dict

        if "annotations" in dicts[0]:
            try:
                class_names = MetadataCatalog.get(dataset_name).thing_classes
                check_metadata_consistency("thing_classes", dataset_name)
                print_instances_class_histogram(dicts, class_names)
            except AttributeError:  # class names are not available for this dataset
                pass

    assert proposal_files is None

    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))  # connect multiple iterables into one

    has_instances = "annotations" in dataset_dicts[0]
    if filter_empty and has_instances:
        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
    if min_keypoints > 0 and has_instances:
        dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)

    return dataset_dicts


class MultiDatasetSampler(Sampler):
    def __init__(
        self,
        dataset_dicts,
        dataset_ratio,
        use_rfs,  # [True, False]
        dataset_ann,
        repeat_threshold=0.001,
        seed: Optional[int] = None,
    ):
        """
        Infinite sampler over the concatenated datasets. Each image is drawn
        with a weight that balances the datasets according to `dataset_ratio`,
        optionally rescaled by repeat-factor sampling (RFS) within a dataset.
        """
        sizes = [0 for _ in range(len(dataset_ratio))]
        for d in dataset_dicts:
            sizes[d['dataset_source']] += 1  # size of each dataset
        print('dataset sizes', sizes)
        self.sizes = sizes
        assert len(dataset_ratio) == len(sizes), \
            'length of dataset ratio {} should be equal to number of datasets {}'.format(
                len(dataset_ratio), len(sizes)
            )
        if seed is None:
            seed = comm.shared_random_seed()  # seed shared across all GPUs
        self._seed = int(seed)
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

        self.dataset_ids = torch.tensor(
            [d['dataset_source'] for d in dataset_dicts], dtype=torch.long)

        dataset_weight = [torch.ones(s) * max(sizes) / s * r / sum(dataset_ratio) \
            for i, (r, s) in enumerate(zip(dataset_ratio, sizes))]
        dataset_weight = torch.cat(dataset_weight)

        rfs_factors = []
        st = 0
        for i, s in enumerate(sizes):
            if use_rfs[i]:
                if dataset_ann[i] == 'box':
                    rfs_func = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency
                else:
                    rfs_func = repeat_factors_from_tag_frequency
                rfs_factor = rfs_func(
                    dataset_dicts[st: st + s],
                    repeat_thresh=repeat_threshold)
                rfs_factor = rfs_factor * (s / rfs_factor.sum())
            else:
                rfs_factor = torch.ones(s)
            rfs_factors.append(rfs_factor)
            st = st + s
        rfs_factors = torch.cat(rfs_factors)

        self.weights = dataset_weight * rfs_factors  # weights for each element in dataset_dicts
        self.sample_epoch_size = len(self.weights)

    def __iter__(self):
        start = self._rank
        yield from itertools.islice(
            self._infinite_indices(), start, None, self._world_size)  # itertools.islice(iterable, start, stop[, step])

    def _infinite_indices(self):
        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            ids = torch.multinomial(
                self.weights, self.sample_epoch_size, generator=g,
                replacement=True)  # randomly sample according to the given weights
            nums = [(self.dataset_ids[ids] == i).sum().int().item() \
                for i in range(len(self.sizes))]  # per-dataset sample counts (unused; handy for debugging)
            yield from ids


class MDAspectRatioGroupedDataset(torch.utils.data.IterableDataset):
    def __init__(self, dataset, batch_size, num_datasets):
        """
        Batch a stream of mapped dicts so that each batch comes from a single
        dataset and a single aspect-ratio group.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self._buckets = [[] for _ in range(2 * num_datasets)]  # there are (2 x num_datasets) types of data. For each dataset, there are two types: w>h or w<=h

    def __iter__(self):
        for d in self.dataset:
            w, h = d["width"], d["height"]
            aspect_ratio_bucket_id = 0 if w > h else 1
            bucket_id = d['dataset_source'] * 2 + aspect_ratio_bucket_id
            bucket = self._buckets[bucket_id]
            bucket.append(d)
            if len(bucket) == self.batch_size:
                yield bucket[:]
                del bucket[:]


class DIFFMDAspectRatioGroupedDataset(torch.utils.data.IterableDataset):
    def __init__(self, dataset, batch_sizes, num_datasets):
        """
        Same as MDAspectRatioGroupedDataset, but with a per-dataset batch size.
        """
        self.dataset = dataset
        self.batch_sizes = batch_sizes
        self._buckets = [[] for _ in range(2 * num_datasets)]

    def __iter__(self):
        for d in self.dataset:
            w, h = d["width"], d["height"]
            aspect_ratio_bucket_id = 0 if w > h else 1
            bucket_id = d['dataset_source'] * 2 + aspect_ratio_bucket_id
            bucket = self._buckets[bucket_id]
            bucket.append(d)
            if len(bucket) == self.batch_sizes[d['dataset_source']]:  # allow different batch sizes
                yield bucket[:]
                del bucket[:]


def repeat_factors_from_tag_frequency(dataset_dicts, repeat_thresh):
    """
    Compute per-image repeat factors from image-level tag ('pos_category_ids')
    frequencies, following the LVIS repeat-factor sampling formula.
    """
    category_freq = defaultdict(int)
    for dataset_dict in dataset_dicts:
        cat_ids = dataset_dict['pos_category_ids']
        for cat_id in cat_ids:
            category_freq[cat_id] += 1
    num_images = len(dataset_dicts)
    for k, v in category_freq.items():
        category_freq[k] = v / num_images

    category_rep = {
        cat_id: max(1.0, math.sqrt(repeat_thresh / cat_freq))
        for cat_id, cat_freq in category_freq.items()
    }

    rep_factors = []
    for dataset_dict in dataset_dicts:
        cat_ids = dataset_dict['pos_category_ids']
        rep_factor = max({category_rep[cat_id] for cat_id in cat_ids}, default=1.0)
        rep_factors.append(rep_factor)

    return torch.tensor(rep_factors, dtype=torch.float32)
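The per-dataset weight computed in MultiDatasetSampler.__init__ (before the optional RFS rescaling) can be checked by hand; the following numeric sketch (not part of the upload) shows that each dataset's total sampling mass ends up proportional to its entry in `dataset_ratio`, independent of its size.

sizes = [100, 1000]   # a small and a large dataset (illustrative numbers)
ratio = [1, 1]        # sample both datasets equally
weights = [max(sizes) / s * r / sum(ratio) for r, s in zip(ratio, sizes)]
# weights == [5.0, 0.5]; per-dataset mass: 100 * 5.0 == 1000 * 0.5 == 500.0,
# so torch.multinomial draws from both datasets at the same rate.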
regionspot/data/dataset_mapper.py
ADDED
@@ -0,0 +1,140 @@
# ========================================
# Modified by Shoufa Chen
# ========================================
# Modified by Peize Sun, Rufeng Zhang
# Contact: {sunpeize, cxrfzhang}@foxmail.com
#
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import logging
import numpy as np
import torch
import os
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T


__all__ = ["RegionSpotDatasetMapper"]


def build_transform_gen(cfg, is_train):
    """
    Create a list of :class:`TransformGen` from config.
    Returns:
        list[TransformGen]
    """
    if is_train:
        min_size = cfg.INPUT.MIN_SIZE_TRAIN
        max_size = cfg.INPUT.MAX_SIZE_TRAIN
        sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
    else:
        min_size = cfg.INPUT.MIN_SIZE_TEST
        max_size = cfg.INPUT.MAX_SIZE_TEST
        sample_style = "choice"
    if sample_style == "range":
        assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format(len(min_size))

    logger = logging.getLogger(__name__)
    tfm_gens = []
    if is_train:
        tfm_gens.append(T.RandomFlip())
    # ResizeShortestEdge
    tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style))

    if is_train:
        logger.info("TransformGens used in training: " + str(tfm_gens))
    return tfm_gens


class RegionSpotDatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and maps it into a format used by RegionSpot.

    The callable currently does the following:

    1. Read the image from "file_name"
    2. Apply geometric transforms to the image and annotations
    3. Find and apply suitable cropping to the image and annotations
    4. Prepare the image and annotations as Tensors
    """

    def __init__(self, cfg, is_train=True):
        if cfg.INPUT.CROP.ENABLED and is_train:
            self.crop_gen = [
                T.ResizeShortestEdge([400, 500, 600], sample_style="choice"),
                T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE),
            ]
        else:
            self.crop_gen = None

        self.tfm_gens = build_transform_gen(cfg, is_train)
        logging.getLogger(__name__).info(
            "Full TransformGens used in training: {}, crop: {}".format(str(self.tfm_gens), str(self.crop_gen))
        )

        self.img_format = cfg.INPUT.FORMAT
        self.is_train = is_train
        # if self.is_train:
        #     for dataset_name in cfg.DATASETS.TRAIN:
        #         if dataset_name.startswith("coco"):
        self.mask_tokens_dir = os.path.join('./datasets/datasets_mask_tokens_vit_b/')

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        # utils.check_image_size(dataset_dict, image)

        # get the precomputed mask tokens and the corresponding labels
        image_id = dataset_dict["image_id"]
        dataset_name = dataset_dict["file_name"].split('/')[1]
        # e.g. datasets/coco/train2017/000000566174.jpg
        # read the precomputed .pth file
        pth_file = os.path.join(self.mask_tokens_dir, os.path.join(dataset_name, str(image_id) + '.pth'))
        offline_token = torch.load(pth_file)

        if self.crop_gen is None:
            image, transforms = T.apply_transform_gens(self.tfm_gens, image)
        else:
            if np.random.rand() > 0.5:
                image, transforms = T.apply_transform_gens(self.tfm_gens, image)
            else:
                image, transforms = T.apply_transform_gens(
                    self.tfm_gens[:-1] + self.crop_gen + self.tfm_gens[-1:], image
                )

        image_shape = image.shape[:2]  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
        dataset_dict["dataset_name"] = dataset_name
        dataset_dict["extra_info"] = offline_token
        if not self.is_train:
            # USER: Modify this if you want to keep them for some reason.
            dataset_dict.pop("annotations", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            # USER: Modify this if you want to keep them for some reason.
            for anno in dataset_dict["annotations"]:
                anno.pop("segmentation", None)
                anno.pop("keypoints", None)

            # USER: Implement additional transformations if you have other types of data
            annos = [
                utils.transform_instance_annotations(obj, transforms, image_shape)
                for obj in dataset_dict.pop("annotations")
                if obj.get("iscrowd", 0) == 0
            ]
            instances = utils.annotations_to_instances(annos, image_shape)
            dataset_dict["instances"] = utils.filter_empty_instances(instances)
        return dataset_dict
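The mapper above resolves the precomputed mask tokens from the image path alone; a short sketch (not part of the upload) of that derivation, using the example path from the code comment:

import os

file_name = "datasets/coco/train2017/000000566174.jpg"  # example from the comment above
image_id = 566174
dataset_name = file_name.split("/")[1]                   # -> "coco"
pth_file = os.path.join("./datasets/datasets_mask_tokens_vit_b/",
                        dataset_name, str(image_id) + ".pth")
# -> ./datasets/datasets_mask_tokens_vit_b/coco/566174.pth  (loaded via torch.load)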
regionspot/data/objects365.py
ADDED
@@ -0,0 +1,391 @@
from detectron2.data.datasets.register_coco import register_coco_instances
import os

categories = [
    {'id': 164, 'name': 'cutting/chopping board'},
    {'id': 49, 'name': 'tie'},
    {'id': 306, 'name': 'crosswalk sign'},
    {'id': 145, 'name': 'gun'},
    {'id': 14, 'name': 'street lights'},
    {'id': 223, 'name': 'bar soap'},
    {'id': 74, 'name': 'wild bird'},
    {'id': 219, 'name': 'ice cream'},
    {'id': 37, 'name': 'stool'},
    {'id': 25, 'name': 'storage box'},
    {'id': 153, 'name': 'giraffe'},
    {'id': 52, 'name': 'pen/pencil'},
    {'id': 61, 'name': 'high heels'},
    {'id': 340, 'name': 'mangosteen'},
    {'id': 22, 'name': 'bracelet'},
    {'id': 155, 'name': 'piano'},
    {'id': 162, 'name': 'vent'},
    {'id': 75, 'name': 'laptop'},
    {'id': 236, 'name': 'toaster'},
    {'id': 231, 'name': 'fire truck'},
    {'id': 42, 'name': 'basket'},
    {'id': 150, 'name': 'zebra'},
    {'id': 124, 'name': 'head phone'},
    {'id': 90, 'name': 'sheep'},
    {'id': 322, 'name': 'steak'},
    {'id': 39, 'name': 'couch'},
    {'id': 209, 'name': 'toothbrush'},
    {'id': 59, 'name': 'bicycle'},
    {'id': 336, 'name': 'red cabbage'},
    {'id': 228, 'name': 'golf ball'},
    {'id': 120, 'name': 'tomato'},
    {'id': 132, 'name': 'computer box'},
    {'id': 8, 'name': 'cup'},
    {'id': 183, 'name': 'basketball'},
    {'id': 298, 'name': 'butterfly'},
    {'id': 250, 'name': 'garlic'},
    {'id': 12, 'name': 'desk'},
    {'id': 141, 'name': 'microwave'},
    {'id': 171, 'name': 'strawberry'},
    {'id': 200, 'name': 'kettle'},
    {'id': 63, 'name': 'van'},
    {'id': 300, 'name': 'cheese'},
    {'id': 215, 'name': 'marker'},
    {'id': 100, 'name': 'blackboard/whiteboard'},
    {'id': 186, 'name': 'printer'},
    {'id': 333, 'name': 'bread/bun'},
    {'id': 243, 'name': 'penguin'},
    {'id': 364, 'name': 'iron'},
    {'id': 180, 'name': 'ladder'},
    {'id': 34, 'name': 'flag'},
    {'id': 78, 'name': 'cell phone'},
    {'id': 97, 'name': 'fan'},
    {'id': 224, 'name': 'scale'},
    {'id': 151, 'name': 'duck'},
    {'id': 319, 'name': 'flute'},
    {'id': 156, 'name': 'stop sign'},
    {'id': 290, 'name': 'rickshaw'},
    {'id': 128, 'name': 'sailboat'},
    {'id': 165, 'name': 'tennis racket'},
    {'id': 241, 'name': 'cigar'},
    {'id': 101, 'name': 'balloon'},
    {'id': 308, 'name': 'hair drier'},
    {'id': 167, 'name': 'skating and skiing shoes'},
    {'id': 237, 'name': 'helicopter'},
    {'id': 65, 'name': 'sink'},
    {'id': 129, 'name': 'tangerine'},
    {'id': 330, 'name': 'crab'},
    {'id': 320, 'name': 'measuring cup'},
    {'id': 260, 'name': 'fishing rod'},
    {'id': 346, 'name': 'saw'},
    {'id': 216, 'name': 'ship'},
    {'id': 46, 'name': 'coffee table'},
    {'id': 194, 'name': 'facial mask'},
    {'id': 281, 'name': 'stapler'},
    {'id': 118, 'name': 'refrigerator'},
    {'id': 40, 'name': 'belt'},
    {'id': 349, 'name': 'starfish'},
    {'id': 87, 'name': 'hanger'},
    {'id': 116, 'name': 'baseball glove'},
    {'id': 261, 'name': 'cherry'},
    {'id': 334, 'name': 'baozi'},
    {'id': 267, 'name': 'screwdriver'},
    {'id': 158, 'name': 'converter'},
    {'id': 335, 'name': 'lion'},
    {'id': 170, 'name': 'baseball'},
    {'id': 111, 'name': 'skis'},
    {'id': 136, 'name': 'broccoli'},
    {'id': 342, 'name': 'eraser'},
    {'id': 337, 'name': 'polar bear'},
    {'id': 139, 'name': 'shovel'},
    {'id': 193, 'name': 'extension cord'},
    {'id': 284, 'name': 'goldfish'},
    {'id': 174, 'name': 'pepper'},
    {'id': 138, 'name': 'stroller'},
    {'id': 328, 'name': 'yak'},
    {'id': 83, 'name': 'clock'},
    {'id': 235, 'name': 'tricycle'},
    {'id': 248, 'name': 'parking meter'},
    {'id': 274, 'name': 'trophy'},
    {'id': 324, 'name': 'binoculars'},
    {'id': 51, 'name': 'traffic light'},
    {'id': 314, 'name': 'donkey'},
    {'id': 45, 'name': 'barrel/bucket'},
    {'id': 292, 'name': 'pomegranate'},
    {'id': 13, 'name': 'handbag'},
    {'id': 262, 'name': 'tablet'},
    {'id': 68, 'name': 'apple'},
    {'id': 226, 'name': 'cabbage'},
    {'id': 23, 'name': 'flower'},
    {'id': 58, 'name': 'faucet'},
    {'id': 206, 'name': 'tong'},
    {'id': 291, 'name': 'trombone'},
    {'id': 160, 'name': 'carrot'},
    {'id': 172, 'name': 'bow tie'},
    {'id': 122, 'name': 'tent'},
    {'id': 163, 'name': 'cookies'},
    {'id': 115, 'name': 'remote'},
    {'id': 175, 'name': 'coffee machine'},
    {'id': 238, 'name': 'green beans'},
    {'id': 233, 'name': 'cello'},
    {'id': 28, 'name': 'wine glass'},
    {'id': 295, 'name': 'mushroom'},
    {'id': 344, 'name': 'scallop'},
    {'id': 125, 'name': 'lantern'},
    {'id': 123, 'name': 'shampoo/shower gel'},
    {'id': 285, 'name': 'meat balls'},
    {'id': 266, 'name': 'key'},
    {'id': 296, 'name': 'calculator'},
    {'id': 168, 'name': 'scissors'},
    {'id': 103, 'name': 'cymbal'},
    {'id': 6, 'name': 'bottle'},
    {'id': 264, 'name': 'nuts'},
    {'id': 234, 'name': 'notepaper'},
    {'id': 211, 'name': 'mango'},
    {'id': 287, 'name': 'toothpaste'},
    {'id': 196, 'name': 'chopsticks'},
    {'id': 140, 'name': 'baseball bat'},
    {'id': 244, 'name': 'hurdle'},
    {'id': 195, 'name': 'tennis ball'},
    {'id': 144, 'name': 'surveillance camera'},
    {'id': 271, 'name': 'volleyball'},
    {'id': 94, 'name': 'keyboard'},
    {'id': 339, 'name': 'seal'},
    {'id': 11, 'name': 'picture/frame'},
    {'id': 348, 'name': 'okra'},
    {'id': 191, 'name': 'sausage'},
    {'id': 166, 'name': 'candy'},
    {'id': 62, 'name': 'ring'},
    {'id': 311, 'name': 'dolphin'},
    {'id': 273, 'name': 'eggplant'},
    {'id': 84, 'name': 'drum'},
    {'id': 143, 'name': 'surfboard'},
    {'id': 288, 'name': 'antelope'},
    {'id': 204, 'name': 'clutch'},
    {'id': 207, 'name': 'slide'},
    {'id': 43, 'name': 'towel/napkin'},
    {'id': 352, 'name': 'durian'},
    {'id': 276, 'name': 'board eraser'},
    {'id': 315, 'name': 'electric drill'},
    {'id': 312, 'name': 'sushi'},
    {'id': 198, 'name': 'pie'},
    {'id': 106, 'name': 'pickup truck'},
    {'id': 176, 'name': 'bathtub'},
    {'id': 26, 'name': 'vase'},
    {'id': 133, 'name': 'elephant'},
    {'id': 256, 'name': 'sandwich'},
    {'id': 327, 'name': 'noodles'},
    {'id': 10, 'name': 'glasses'},
    {'id': 109, 'name': 'airplane'},
    {'id': 95, 'name': 'tripod'},
    {'id': 247, 'name': 'CD'},
    {'id': 121, 'name': 'machinery vehicle'},
    {'id': 365, 'name': 'flashlight'},
    {'id': 53, 'name': 'microphone'},
    {'id': 270, 'name': 'pliers'},
    {'id': 362, 'name': 'chainsaw'},
    {'id': 259, 'name': 'bear'},
    {'id': 197, 'name': 'electronic stove and gas stove'},
    {'id': 89, 'name': 'pot/pan'},
    {'id': 220, 'name': 'tape'},
    {'id': 338, 'name': 'lighter'},
    {'id': 177, 'name': 'snowboard'},
    {'id': 214, 'name': 'violin'},
    {'id': 217, 'name': 'chicken'},
    {'id': 2, 'name': 'sneakers'},
    {'id': 161, 'name': 'washing machine'},
    {'id': 131, 'name': 'kite'},
    {'id': 354, 'name': 'rabbit'},
    {'id': 86, 'name': 'bus'},
    {'id': 275, 'name': 'dates'},
    {'id': 282, 'name': 'camel'},
    {'id': 88, 'name': 'nightstand'},
    {'id': 179, 'name': 'grapes'},
    {'id': 229, 'name': 'pine apple'},
    {'id': 56, 'name': 'necklace'},
    {'id': 18, 'name': 'leather shoes'},
    {'id': 358, 'name': 'hoverboard'},
    {'id': 345, 'name': 'pencil case'},
    {'id': 359, 'name': 'pasta'},
    {'id': 157, 'name': 'radiator'},
    {'id': 201, 'name': 'hamburger'},
    {'id': 268, 'name': 'globe'},
    {'id': 332, 'name': 'barbell'},
    {'id': 329, 'name': 'mop'},
    {'id': 252, 'name': 'horn'},
    {'id': 350, 'name': 'eagle'},
    {'id': 169, 'name': 'folder'},
    {'id': 137, 'name': 'toilet'},
    {'id': 5, 'name': 'lamp'},
    {'id': 27, 'name': 'bench'},
    {'id': 249, 'name': 'swan'},
    {'id': 76, 'name': 'knife'},
    {'id': 341, 'name': 'comb'},
    {'id': 64, 'name': 'watch'},
    {'id': 105, 'name': 'telephone'},
    {'id': 3, 'name': 'chair'},
    {'id': 33, 'name': 'boat'},
    {'id': 107, 'name': 'orange'},
    {'id': 60, 'name': 'bread'},
    {'id': 147, 'name': 'cat'},
    {'id': 135, 'name': 'gas stove'},
    {'id': 307, 'name': 'papaya'},
    {'id': 227, 'name': 'router/modem'},
    {'id': 357, 'name': 'asparagus'},
    {'id': 73, 'name': 'motorcycle'},
    {'id': 77, 'name': 'traffic sign'},
    {'id': 67, 'name': 'fish'},
    {'id': 326, 'name': 'radish'},
    {'id': 213, 'name': 'egg'},
    {'id': 203, 'name': 'cucumber'},
    {'id': 17, 'name': 'helmet'},
    {'id': 110, 'name': 'luggage'},
    {'id': 80, 'name': 'truck'},
    {'id': 199, 'name': 'frisbee'},
    {'id': 232, 'name': 'peach'},
    {'id': 1, 'name': 'person'},
    {'id': 29, 'name': 'boots'},
    {'id': 310, 'name': 'chips'},
    {'id': 142, 'name': 'skateboard'},
    {'id': 44, 'name': 'slippers'},
    {'id': 4, 'name': 'hat'},
    {'id': 178, 'name': 'suitcase'},
    {'id': 24, 'name': 'tv'},
    {'id': 119, 'name': 'train'},
    {'id': 82, 'name': 'power outlet'},
    {'id': 245, 'name': 'swing'},
    {'id': 15, 'name': 'book'},
    {'id': 294, 'name': 'jellyfish'},
    {'id': 192, 'name': 'fire extinguisher'},
    {'id': 212, 'name': 'deer'},
    {'id': 181, 'name': 'pear'},
    {'id': 347, 'name': 'table tennis paddle'},
    {'id': 113, 'name': 'trolley'},
    {'id': 91, 'name': 'guitar'},
    {'id': 202, 'name': 'golf club'},
    {'id': 221, 'name': 'wheelchair'},
    {'id': 254, 'name': 'saxophone'},
    {'id': 117, 'name': 'paper towel'},
    {'id': 303, 'name': 'race car'},
    {'id': 240, 'name': 'carriage'},
    {'id': 246, 'name': 'radio'},
    {'id': 318, 'name': 'parrot'},
    {'id': 251, 'name': 'french fries'},
    {'id': 98, 'name': 'dog'},
    {'id': 112, 'name': 'soccer'},
    {'id': 355, 'name': 'french horn'},
    {'id': 79, 'name': 'paddle'},
    {'id': 283, 'name': 'lettuce'},
    {'id': 9, 'name': 'car'},
    {'id': 258, 'name': 'kiwi fruit'},
    {'id': 325, 'name': 'llama'},
    {'id': 187, 'name': 'billiards'},
    {'id': 210, 'name': 'facial cleanser'},
    {'id': 81, 'name': 'cow'},
    {'id': 331, 'name': 'microscope'},
    {'id': 148, 'name': 'lemon'},
    {'id': 302, 'name': 'pomelo'},
    {'id': 85, 'name': 'fork'},
    {'id': 154, 'name': 'pumpkin'},
    {'id': 289, 'name': 'shrimp'},
    {'id': 71, 'name': 'teddy bear'},
    {'id': 184, 'name': 'potato'},
    {'id': 102, 'name': 'air conditioner'},
    {'id': 208, 'name': 'hot dog'},
    {'id': 222, 'name': 'plum'},
    {'id': 316, 'name': 'spring rolls'},
    {'id': 230, 'name': 'crane'},
    {'id': 149, 'name': 'liquid soap'},
    {'id': 55, 'name': 'canned'},
    {'id': 35, 'name': 'speaker'},
    {'id': 108, 'name': 'banana'},
    {'id': 297, 'name': 'treadmill'},
    {'id': 99, 'name': 'spoon'},
    {'id': 104, 'name': 'mouse'},
    {'id': 182, 'name': 'american football'},
    {'id': 299, 'name': 'egg tart'},
    {'id': 127, 'name': 'cleaning products'},
    {'id': 313, 'name': 'urinal'},
    {'id': 286, 'name': 'medal'},
    {'id': 239, 'name': 'brush'},
    {'id': 96, 'name': 'hockey'},
    {'id': 279, 'name': 'dumbbell'},
    {'id': 32, 'name': 'umbrella'},
    {'id': 272, 'name': 'hammer'},
    {'id': 16, 'name': 'plate'},
    {'id': 21, 'name': 'potted plant'},
    {'id': 242, 'name': 'earphone'},
    {'id': 70, 'name': 'candle'},
    {'id': 185, 'name': 'paint brush'},
    {'id': 48, 'name': 'toy'},
    {'id': 130, 'name': 'pizza'},
    {'id': 255, 'name': 'trumpet'},
    {'id': 361, 'name': 'hotair balloon'},
    {'id': 188, 'name': 'fire hydrant'},
    {'id': 50, 'name': 'bed'},
    {'id': 253, 'name': 'avocado'},
    {'id': 293, 'name': 'coconut'},
    {'id': 257, 'name': 'cue'},
    {'id': 280, 'name': 'hamimelon'},
    {'id': 66, 'name': 'horse'},
    {'id': 173, 'name': 'pigeon'},
    {'id': 190, 'name': 'projector'},
    {'id': 69, 'name': 'camera'},
    {'id': 30, 'name': 'bowl'},
    {'id': 269, 'name': 'broom'},
    {'id': 343, 'name': 'pitaya'},
    {'id': 305, 'name': 'tuba'},
    {'id': 309, 'name': 'green onion'},
    {'id': 363, 'name': 'lobster'},
    {'id': 225, 'name': 'watermelon'},
    {'id': 47, 'name': 'suv'},
    {'id': 31, 'name': 'dining table'},
    {'id': 54, 'name': 'sandals'},
    {'id': 351, 'name': 'monkey'},
    {'id': 218, 'name': 'onion'},
    {'id': 36, 'name': 'trash bin/can'},
    {'id': 20, 'name': 'glove'},
    {'id': 277, 'name': 'rice'},
    {'id': 152, 'name': 'sports car'},
    {'id': 360, 'name': 'target'},
    {'id': 205, 'name': 'blender'},
    {'id': 19, 'name': 'pillow'},
    {'id': 72, 'name': 'cake'},
    {'id': 93, 'name': 'tea pot'},
    {'id': 353, 'name': 'game board'},
    {'id': 38, 'name': 'backpack'},
    {'id': 356, 'name': 'ambulance'},
    {'id': 146, 'name': 'life saver'},
    {'id': 189, 'name': 'goose'},
    {'id': 278, 'name': 'tape measure/ruler'},
    {'id': 92, 'name': 'traffic cone'},
    {'id': 134, 'name': 'toiletries'},
    {'id': 114, 'name': 'oven'},
    {'id': 317, 'name': 'tortoise/turtle'},
    {'id': 265, 'name': 'corn'},
    {'id': 126, 'name': 'donut'},
    {'id': 57, 'name': 'mirror'},
    {'id': 7, 'name': 'cabinet/shelf'},
    {'id': 263, 'name': 'green vegetables'},
    {'id': 159, 'name': 'tissue '},
    {'id': 321, 'name': 'shark'},
    {'id': 301, 'name': 'pig'},
    {'id': 41, 'name': 'carpet'},
    {'id': 304, 'name': 'rice cooker'},
    {'id': 323, 'name': 'poker card'},
]

def _get_builtin_metadata():
    id_to_name = {x['id']: x['name'] for x in categories}
    thing_dataset_id_to_contiguous_id = {i + 1: i for i in range(365)}
    thing_classes = [id_to_name[k] for k in sorted(id_to_name)]
    return {
        "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
        "thing_classes": thing_classes}

_PREDEFINED_SPLITS_OBJECTS365 = {
    "objects365_train": ("objects365/train", "objects365/objects365_train.json"),
    "objects365_val": ("objects365/val", "objects365/objects365_val.json"),
}

for key, (image_root, json_file) in _PREDEFINED_SPLITS_OBJECTS365.items():
    register_coco_instances(
        key,
        _get_builtin_metadata(),
        os.path.join("datasets", json_file) if "://" not in json_file else json_file,
        os.path.join("datasets", image_root),
    )
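A small sketch (not part of the upload) of what `_get_builtin_metadata` produces for the list above: the 1-based Objects365 ids are remapped to contiguous 0-based training ids, and the class names are ordered by id.

id_to_name = {x['id']: x['name'] for x in categories}
contiguous = {i + 1: i for i in range(365)}  # dataset id -> contiguous id
thing_classes = [id_to_name[k] for k in sorted(id_to_name)]
assert contiguous[1] == 0
assert thing_classes[0] == 'person'          # id 1 sorts first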
regionspot/data/openimages.py
ADDED
@@ -0,0 +1,34 @@
from detectron2.data.datasets.register_coco import register_coco_instances
import os
from .openimages_categories import categories

def _get_builtin_metadata():
    id_to_name = {x['id']: x['name'] for x in categories}
    thing_dataset_id_to_contiguous_id = {i + 1: i for i in range(len(categories))}
    thing_classes = [id_to_name[k] for k in sorted(id_to_name)]
    return {
        "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
        "thing_classes": thing_classes}


_PREDEFINED_SPLITS_OPENIMAGES = {
    "openimages_train": ("openimages/detection/", "re_openimages_v6_train_bbox_splitdir_int_ids.json"),
    "openimages_val": ("openimages/detection/", "re_openimages_v6_train_bbox_splitdir_int_ids.json"),
    # NOTE: both splits point at the train json in this snapshot.
}

for key, (image_root, json_file) in _PREDEFINED_SPLITS_OPENIMAGES.items():
    register_coco_instances(
        key,
        _get_builtin_metadata(),
        os.path.join("datasets", json_file) if "://" not in json_file else json_file,
        os.path.join("datasets", image_root),
    )
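Once the loop above has run, the splits are retrievable through detectron2's catalogs; a usage sketch (not part of the upload), assuming the json and image paths exist under datasets/:

from detectron2.data import DatasetCatalog, MetadataCatalog

meta = MetadataCatalog.get("openimages_train")
print(len(meta.thing_classes))                  # number of Open Images classes
dicts = DatasetCatalog.get("openimages_train")  # lazily loads the COCO-format json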
regionspot/data/openimages_categories.py
ADDED
@@ -0,0 +1 @@
categories = [{'id': 1, 'name': 'Tortoise', 'freebase_id': '/m/011k07'}, {'id': 2, 'name': 'Container', 'freebase_id': '/m/011q46kg'}, {'id': 3, 'name': 'Magpie', 'freebase_id': '/m/012074'}, {'id': 4, 'name': 'Sea turtle', 'freebase_id': '/m/0120dh'}, {'id': 5, 'name': 'Football', 'freebase_id': '/m/01226z'}, {'id': 6, 'name': 'Ambulance', 'freebase_id': '/m/012n7d'}, {'id': 7, 'name': 'Ladder', 'freebase_id': '/m/012w5l'}, {'id': 8, 'name': 'Toothbrush', 'freebase_id': '/m/012xff'}, {'id': 9, 'name': 'Syringe', 'freebase_id': '/m/012ysf'}, {'id': 10, 'name': 'Sink', 'freebase_id': '/m/0130jx'}, {'id': 11, 'name': 'Toy', 'freebase_id': '/m/0138tl'}, {'id': 12, 'name': 'Organ (Musical Instrument)', 'freebase_id': '/m/013y1f'}, {'id': 13, 'name': 'Cassette deck', 'freebase_id': '/m/01432t'}, {'id': 14, 'name': 'Apple', 'freebase_id': '/m/014j1m'}, {'id': 15, 'name': 'Human eye', 'freebase_id': '/m/014sv8'}, {'id': 16, 'name': 'Cosmetics', 'freebase_id': '/m/014trl'}, {'id': 17, 'name': 'Paddle', 'freebase_id': '/m/014y4n'}, {'id': 18, 'name': 'Snowman', 'freebase_id': '/m/0152hh'}, {'id': 19, 'name': 'Beer', 'freebase_id': '/m/01599'}, {'id': 20, 'name': 'Chopsticks', 'freebase_id': '/m/01_5g'}, {'id': 21, 'name': 'Human beard', 'freebase_id': '/m/015h_t'}, {'id': 22, 'name': 'Bird', 'freebase_id': '/m/015p6'}, {'id': 23, 'name': 'Parking meter', 'freebase_id': '/m/015qbp'}, {'id': 24, 'name': 'Traffic light', 'freebase_id': '/m/015qff'}, {'id': 25, 'name': 'Croissant', 'freebase_id': '/m/015wgc'}, {'id': 26, 'name': 'Cucumber', 'freebase_id': '/m/015x4r'}, {'id': 27, 'name': 'Radish', 'freebase_id': '/m/015x5n'}, {'id': 28, 'name': 'Towel', 'freebase_id': '/m/0162_1'}, {'id': 29, 'name': 'Doll', 'freebase_id': '/m/0167gd'}, {'id': 30, 'name': 'Skull', 'freebase_id': '/m/016m2d'}, {'id': 31, 'name': 'Washing machine', 'freebase_id': '/m/0174k2'}, {'id': 32, 'name': 'Glove', 'freebase_id': '/m/0174n1'}, {'id': 33, 'name': 'Tick', 'freebase_id': '/m/0175cv'}, {'id': 34, 'name': 'Belt', 'freebase_id': '/m/0176mf'}, {'id': 35, 'name': 'Sunglasses', 'freebase_id': '/m/017ftj'}, {'id': 36, 'name': 'Banjo', 'freebase_id': '/m/018j2'}, {'id': 37, 'name': 'Cart', 'freebase_id': '/m/018p4k'}, {'id': 38, 'name': 'Ball', 'freebase_id': '/m/018xm'}, {'id': 39, 'name': 'Backpack', 'freebase_id': '/m/01940j'}, {'id': 40, 'name': 'Bicycle', 'freebase_id': '/m/0199g'}, {'id': 41, 'name': 'Home appliance', 'freebase_id': '/m/019dx1'}, {'id': 42, 'name': 'Centipede', 'freebase_id': '/m/019h78'}, {'id': 43, 'name': 'Boat', 'freebase_id': '/m/019jd'}, {'id': 44, 'name': 'Surfboard', 'freebase_id': '/m/019w40'}, {'id': 45, 'name': 'Boot', 'freebase_id': '/m/01b638'}, {'id': 46, 'name': 'Headphones', 'freebase_id': '/m/01b7fy'}, {'id': 47, 'name': 'Hot dog', 'freebase_id': '/m/01b9xk'}, {'id': 48, 'name': 'Shorts', 'freebase_id': '/m/01bfm9'}, {'id': 49, 'name': 'Fast food', 'freebase_id': '/m/01_bhs'}, {'id': 50, 'name': 'Bus', 'freebase_id': '/m/01bjv'}, {'id': 51, 'name': 'Boy', 'freebase_id': '/m/01bl7v'}, {'id': 52, 'name': 'Screwdriver', 'freebase_id': '/m/01bms0'}, {'id': 53, 'name': 'Bicycle wheel', 'freebase_id': '/m/01bqk0'}, {'id': 54, 'name': 'Barge', 'freebase_id': '/m/01btn'}, {'id': 55, 'name': 'Laptop', 'freebase_id': '/m/01c648'}, {'id': 56, 'name': 'Miniskirt', 'freebase_id': '/m/01cmb2'}, {'id': 57, 'name': 'Drill (Tool)', 'freebase_id': '/m/01d380'}, {'id': 58, 'name': 'Dress', 'freebase_id': '/m/01d40f'}, {'id': 59, 'name': 'Bear', 'freebase_id': '/m/01dws'}, {'id': 60, 'name': 'Waffle', 
'freebase_id': '/m/01dwsz'}, {'id': 61, 'name': 'Pancake', 'freebase_id': '/m/01dwwc'}, {'id': 62, 'name': 'Brown bear', 'freebase_id': '/m/01dxs'}, {'id': 63, 'name': 'Woodpecker', 'freebase_id': '/m/01dy8n'}, {'id': 64, 'name': 'Blue jay', 'freebase_id': '/m/01f8m5'}, {'id': 65, 'name': 'Pretzel', 'freebase_id': '/m/01f91_'}, {'id': 66, 'name': 'Bagel', 'freebase_id': '/m/01fb_0'}, {'id': 67, 'name': 'Tower', 'freebase_id': '/m/01fdzj'}, {'id': 68, 'name': 'Teapot', 'freebase_id': '/m/01fh4r'}, {'id': 69, 'name': 'Person', 'freebase_id': '/m/01g317'}, {'id': 70, 'name': 'Bow and arrow', 'freebase_id': '/m/01g3x7'}, {'id': 71, 'name': 'Swimwear', 'freebase_id': '/m/01gkx_'}, {'id': 72, 'name': 'Beehive', 'freebase_id': '/m/01gllr'}, {'id': 73, 'name': 'Brassiere', 'freebase_id': '/m/01gmv2'}, {'id': 74, 'name': 'Bee', 'freebase_id': '/m/01h3n'}, {'id': 75, 'name': 'Bat (Animal)', 'freebase_id': '/m/01h44'}, {'id': 76, 'name': 'Starfish', 'freebase_id': '/m/01h8tj'}, {'id': 77, 'name': 'Popcorn', 'freebase_id': '/m/01hrv5'}, {'id': 78, 'name': 'Burrito', 'freebase_id': '/m/01j3zr'}, {'id': 79, 'name': 'Chainsaw', 'freebase_id': '/m/01j4z9'}, {'id': 80, 'name': 'Balloon', 'freebase_id': '/m/01j51'}, {'id': 81, 'name': 'Wrench', 'freebase_id': '/m/01j5ks'}, {'id': 82, 'name': 'Tent', 'freebase_id': '/m/01j61q'}, {'id': 83, 'name': 'Vehicle registration plate', 'freebase_id': '/m/01jfm_'}, {'id': 84, 'name': 'Lantern', 'freebase_id': '/m/01jfsr'}, {'id': 85, 'name': 'Toaster', 'freebase_id': '/m/01k6s3'}, {'id': 86, 'name': 'Flashlight', 'freebase_id': '/m/01kb5b'}, {'id': 87, 'name': 'Billboard', 'freebase_id': '/m/01knjb'}, {'id': 88, 'name': 'Tiara', 'freebase_id': '/m/01krhy'}, {'id': 89, 'name': 'Limousine', 'freebase_id': '/m/01lcw4'}, {'id': 90, 'name': 'Necklace', 'freebase_id': '/m/01llwg'}, {'id': 91, 'name': 'Carnivore', 'freebase_id': '/m/01lrl'}, {'id': 92, 'name': 'Scissors', 'freebase_id': '/m/01lsmm'}, {'id': 93, 'name': 'Stairs', 'freebase_id': '/m/01lynh'}, {'id': 94, 'name': 'Computer keyboard', 'freebase_id': '/m/01m2v'}, {'id': 95, 'name': 'Printer', 'freebase_id': '/m/01m4t'}, {'id': 96, 'name': 'Traffic sign', 'freebase_id': '/m/01mqdt'}, {'id': 97, 'name': 'Chair', 'freebase_id': '/m/01mzpv'}, {'id': 98, 'name': 'Shirt', 'freebase_id': '/m/01n4qj'}, {'id': 99, 'name': 'Poster', 'freebase_id': '/m/01n5jq'}, {'id': 100, 'name': 'Cheese', 'freebase_id': '/m/01nkt'}, {'id': 101, 'name': 'Sock', 'freebase_id': '/m/01nq26'}, {'id': 102, 'name': 'Fire hydrant', 'freebase_id': '/m/01pns0'}, {'id': 103, 'name': 'Land vehicle', 'freebase_id': '/m/01prls'}, {'id': 104, 'name': 'Earrings', 'freebase_id': '/m/01r546'}, {'id': 105, 'name': 'Tie', 'freebase_id': '/m/01rkbr'}, {'id': 106, 'name': 'Watercraft', 'freebase_id': '/m/01rzcn'}, {'id': 107, 'name': 'Cabinetry', 'freebase_id': '/m/01s105'}, {'id': 108, 'name': 'Suitcase', 'freebase_id': '/m/01s55n'}, {'id': 109, 'name': 'Muffin', 'freebase_id': '/m/01tcjp'}, {'id': 110, 'name': 'Bidet', 'freebase_id': '/m/01vbnl'}, {'id': 111, 'name': 'Snack', 'freebase_id': '/m/01ww8y'}, {'id': 112, 'name': 'Snowmobile', 'freebase_id': '/m/01x3jk'}, {'id': 113, 'name': 'Clock', 'freebase_id': '/m/01x3z'}, {'id': 114, 'name': 'Medical equipment', 'freebase_id': '/m/01xgg_'}, {'id': 115, 'name': 'Cattle', 'freebase_id': '/m/01xq0k1'}, {'id': 116, 'name': 'Cello', 'freebase_id': '/m/01xqw'}, {'id': 117, 'name': 'Jet ski', 'freebase_id': '/m/01xs3r'}, {'id': 118, 'name': 'Camel', 'freebase_id': '/m/01x_v'}, {'id': 119, 'name': 'Coat', 
'freebase_id': '/m/01xygc'}, {'id': 120, 'name': 'Suit', 'freebase_id': '/m/01xyhv'}, {'id': 121, 'name': 'Desk', 'freebase_id': '/m/01y9k5'}, {'id': 122, 'name': 'Cat', 'freebase_id': '/m/01yrx'}, {'id': 123, 'name': 'Bronze sculpture', 'freebase_id': '/m/01yx86'}, {'id': 124, 'name': 'Juice', 'freebase_id': '/m/01z1kdw'}, {'id': 125, 'name': 'Gondola', 'freebase_id': '/m/02068x'}, {'id': 126, 'name': 'Beetle', 'freebase_id': '/m/020jm'}, {'id': 127, 'name': 'Cannon', 'freebase_id': '/m/020kz'}, {'id': 128, 'name': 'Computer mouse', 'freebase_id': '/m/020lf'}, {'id': 129, 'name': 'Cookie', 'freebase_id': '/m/021mn'}, {'id': 130, 'name': 'Office building', 'freebase_id': '/m/021sj1'}, {'id': 131, 'name': 'Fountain', 'freebase_id': '/m/0220r2'}, {'id': 132, 'name': 'Coin', 'freebase_id': '/m/0242l'}, {'id': 133, 'name': 'Calculator', 'freebase_id': '/m/024d2'}, {'id': 134, 'name': 'Cocktail', 'freebase_id': '/m/024g6'}, {'id': 135, 'name': 'Computer monitor', 'freebase_id': '/m/02522'}, {'id': 136, 'name': 'Box', 'freebase_id': '/m/025dyy'}, {'id': 137, 'name': 'Stapler', 'freebase_id': '/m/025fsf'}, {'id': 138, 'name': 'Christmas tree', 'freebase_id': '/m/025nd'}, {'id': 139, 'name': 'Cowboy hat', 'freebase_id': '/m/025rp__'}, {'id': 140, 'name': 'Hiking equipment', 'freebase_id': '/m/0268lbt'}, {'id': 141, 'name': 'Studio couch', 'freebase_id': '/m/026qbn5'}, {'id': 142, 'name': 'Drum', 'freebase_id': '/m/026t6'}, {'id': 143, 'name': 'Dessert', 'freebase_id': '/m/0270h'}, {'id': 144, 'name': 'Wine rack', 'freebase_id': '/m/0271qf7'}, {'id': 145, 'name': 'Drink', 'freebase_id': '/m/0271t'}, {'id': 146, 'name': 'Zucchini', 'freebase_id': '/m/027pcv'}, {'id': 147, 'name': 'Ladle', 'freebase_id': '/m/027rl48'}, {'id': 148, 'name': 'Human mouth', 'freebase_id': '/m/0283dt1'}, {'id': 149, 'name': 'Dairy Product', 'freebase_id': '/m/0284d'}, {'id': 150, 'name': 'Dice', 'freebase_id': '/m/029b3'}, {'id': 151, 'name': 'Oven', 'freebase_id': '/m/029bxz'}, {'id': 152, 'name': 'Dinosaur', 'freebase_id': '/m/029tx'}, {'id': 153, 'name': 'Ratchet (Device)', 'freebase_id': '/m/02bm9n'}, {'id': 154, 'name': 'Couch', 'freebase_id': '/m/02crq1'}, {'id': 155, 'name': 'Cricket ball', 'freebase_id': '/m/02ctlc'}, {'id': 156, 'name': 'Winter melon', 'freebase_id': '/m/02cvgx'}, {'id': 157, 'name': 'Spatula', 'freebase_id': '/m/02d1br'}, {'id': 158, 'name': 'Whiteboard', 'freebase_id': '/m/02d9qx'}, {'id': 159, 'name': 'Pencil sharpener', 'freebase_id': '/m/02ddwp'}, {'id': 160, 'name': 'Door', 'freebase_id': '/m/02dgv'}, {'id': 161, 'name': 'Hat', 'freebase_id': '/m/02dl1y'}, {'id': 162, 'name': 'Shower', 'freebase_id': '/m/02f9f_'}, {'id': 163, 'name': 'Eraser', 'freebase_id': '/m/02fh7f'}, {'id': 164, 'name': 'Fedora', 'freebase_id': '/m/02fq_6'}, {'id': 165, 'name': 'Guacamole', 'freebase_id': '/m/02g30s'}, {'id': 166, 'name': 'Dagger', 'freebase_id': '/m/02gzp'}, {'id': 167, 'name': 'Scarf', 'freebase_id': '/m/02h19r'}, {'id': 168, 'name': 'Dolphin', 'freebase_id': '/m/02hj4'}, {'id': 169, 'name': 'Sombrero', 'freebase_id': '/m/02jfl0'}, {'id': 170, 'name': 'Tin can', 'freebase_id': '/m/02jnhm'}, {'id': 171, 'name': 'Mug', 'freebase_id': '/m/02jvh9'}, {'id': 172, 'name': 'Tap', 'freebase_id': '/m/02jz0l'}, {'id': 173, 'name': 'Harbor seal', 'freebase_id': '/m/02l8p9'}, {'id': 174, 'name': 'Stretcher', 'freebase_id': '/m/02lbcq'}, {'id': 175, 'name': 'Can opener', 'freebase_id': '/m/02mqfb'}, {'id': 176, 'name': 'Goggles', 'freebase_id': '/m/02_n6y'}, {'id': 177, 'name': 'Human body', 'freebase_id': 
'/m/02p0tk3'}, {'id': 178, 'name': 'Roller skates', 'freebase_id': '/m/02p3w7d'}, {'id': 179, 'name': 'Coffee cup', 'freebase_id': '/m/02p5f1q'}, {'id': 180, 'name': 'Cutting board', 'freebase_id': '/m/02pdsw'}, {'id': 181, 'name': 'Blender', 'freebase_id': '/m/02pjr4'}, {'id': 182, 'name': 'Plumbing fixture', 'freebase_id': '/m/02pkr5'}, {'id': 183, 'name': 'Stop sign', 'freebase_id': '/m/02pv19'}, {'id': 184, 'name': 'Office supplies', 'freebase_id': '/m/02rdsp'}, {'id': 185, 'name': 'Volleyball (Ball)', 'freebase_id': '/m/02rgn06'}, {'id': 186, 'name': 'Vase', 'freebase_id': '/m/02s195'}, {'id': 187, 'name': 'Slow cooker', 'freebase_id': '/m/02tsc9'}, {'id': 188, 'name': 'Wardrobe', 'freebase_id': '/m/02vkqh8'}, {'id': 189, 'name': 'Coffee', 'freebase_id': '/m/02vqfm'}, {'id': 190, 'name': 'Whisk', 'freebase_id': '/m/02vwcm'}, {'id': 191, 'name': 'Paper towel', 'freebase_id': '/m/02w3r3'}, {'id': 192, 'name': 'Personal care', 'freebase_id': '/m/02w3_ws'}, {'id': 193, 'name': 'Food', 'freebase_id': '/m/02wbm'}, {'id': 194, 'name': 'Sun hat', 'freebase_id': '/m/02wbtzl'}, {'id': 195, 'name': 'Tree house', 'freebase_id': '/m/02wg_p'}, {'id': 196, 'name': 'Flying disc', 'freebase_id': '/m/02wmf'}, {'id': 197, 'name': 'Skirt', 'freebase_id': '/m/02wv6h6'}, {'id': 198, 'name': 'Gas stove', 'freebase_id': '/m/02wv84t'}, {'id': 199, 'name': 'Salt and pepper shakers', 'freebase_id': '/m/02x8cch'}, {'id': 200, 'name': 'Mechanical fan', 'freebase_id': '/m/02x984l'}, {'id': 201, 'name': 'Face powder', 'freebase_id': '/m/02xb7qb'}, {'id': 202, 'name': 'Fax', 'freebase_id': '/m/02xqq'}, {'id': 203, 'name': 'Fruit', 'freebase_id': '/m/02xwb'}, {'id': 204, 'name': 'French fries', 'freebase_id': '/m/02y6n'}, {'id': 205, 'name': 'Nightstand', 'freebase_id': '/m/02z51p'}, {'id': 206, 'name': 'Barrel', 'freebase_id': '/m/02zn6n'}, {'id': 207, 'name': 'Kite', 'freebase_id': '/m/02zt3'}, {'id': 208, 'name': 'Tart', 'freebase_id': '/m/02zvsm'}, {'id': 209, 'name': 'Treadmill', 'freebase_id': '/m/030610'}, {'id': 210, 'name': 'Fox', 'freebase_id': '/m/0306r'}, {'id': 211, 'name': 'Flag', 'freebase_id': '/m/03120'}, {'id': 212, 'name': 'French horn', 'freebase_id': '/m/0319l'}, {'id': 213, 'name': 'Window blind', 'freebase_id': '/m/031b6r'}, {'id': 214, 'name': 'Human foot', 'freebase_id': '/m/031n1'}, {'id': 215, 'name': 'Golf cart', 'freebase_id': '/m/0323sq'}, {'id': 216, 'name': 'Jacket', 'freebase_id': '/m/032b3c'}, {'id': 217, 'name': 'Egg (Food)', 'freebase_id': '/m/033cnk'}, {'id': 218, 'name': 'Street light', 'freebase_id': '/m/033rq4'}, {'id': 219, 'name': 'Guitar', 'freebase_id': '/m/0342h'}, {'id': 220, 'name': 'Pillow', 'freebase_id': '/m/034c16'}, {'id': 221, 'name': 'Human leg', 'freebase_id': '/m/035r7c'}, {'id': 222, 'name': 'Isopod', 'freebase_id': '/m/035vxb'}, {'id': 223, 'name': 'Grape', 'freebase_id': '/m/0388q'}, {'id': 224, 'name': 'Human ear', 'freebase_id': '/m/039xj_'}, {'id': 225, 'name': 'Power plugs and sockets', 'freebase_id': '/m/03bbps'}, {'id': 226, 'name': 'Panda', 'freebase_id': '/m/03bj1'}, {'id': 227, 'name': 'Giraffe', 'freebase_id': '/m/03bk1'}, {'id': 228, 'name': 'Woman', 'freebase_id': '/m/03bt1vf'}, {'id': 229, 'name': 'Door handle', 'freebase_id': '/m/03c7gz'}, {'id': 230, 'name': 'Rhinoceros', 'freebase_id': '/m/03d443'}, {'id': 231, 'name': 'Bathtub', 'freebase_id': '/m/03dnzn'}, {'id': 232, 'name': 'Goldfish', 'freebase_id': '/m/03fj2'}, {'id': 233, 'name': 'Houseplant', 'freebase_id': '/m/03fp41'}, {'id': 234, 'name': 'Goat', 'freebase_id': '/m/03fwl'}, {'id': 
235, 'name': 'Baseball bat', 'freebase_id': '/m/03g8mr'}, {'id': 236, 'name': 'Baseball glove', 'freebase_id': '/m/03grzl'}, {'id': 237, 'name': 'Mixing bowl', 'freebase_id': '/m/03hj559'}, {'id': 238, 'name': 'Marine invertebrates', 'freebase_id': '/m/03hl4l9'}, {'id': 239, 'name': 'Kitchen utensil', 'freebase_id': '/m/03hlz0c'}, {'id': 240, 'name': 'Light switch', 'freebase_id': '/m/03jbxj'}, {'id': 241, 'name': 'House', 'freebase_id': '/m/03jm5'}, {'id': 242, 'name': 'Horse', 'freebase_id': '/m/03k3r'}, {'id': 243, 'name': 'Stationary bicycle', 'freebase_id': '/m/03kt2w'}, {'id': 244, 'name': 'Hammer', 'freebase_id': '/m/03l9g'}, {'id': 245, 'name': 'Ceiling fan', 'freebase_id': '/m/03ldnb'}, {'id': 246, 'name': 'Sofa bed', 'freebase_id': '/m/03m3pdh'}, {'id': 247, 'name': 'Adhesive tape', 'freebase_id': '/m/03m3vtv'}, {'id': 248, 'name': 'Harp', 'freebase_id': '/m/03m5k'}, {'id': 249, 'name': 'Sandal', 'freebase_id': '/m/03nfch'}, {'id': 250, 'name': 'Bicycle helmet', 'freebase_id': '/m/03p3bw'}, {'id': 251, 'name': 'Saucer', 'freebase_id': '/m/03q5c7'}, {'id': 252, 'name': 'Harpsichord', 'freebase_id': '/m/03q5t'}, {'id': 253, 'name': 'Human hair', 'freebase_id': '/m/03q69'}, {'id': 254, 'name': 'Heater', 'freebase_id': '/m/03qhv5'}, {'id': 255, 'name': 'Harmonica', 'freebase_id': '/m/03qjg'}, {'id': 256, 'name': 'Hamster', 'freebase_id': '/m/03qrc'}, {'id': 257, 'name': 'Curtain', 'freebase_id': '/m/03rszm'}, {'id': 258, 'name': 'Bed', 'freebase_id': '/m/03ssj5'}, {'id': 259, 'name': 'Kettle', 'freebase_id': '/m/03s_tn'}, {'id': 260, 'name': 'Fireplace', 'freebase_id': '/m/03tw93'}, {'id': 261, 'name': 'Scale', 'freebase_id': '/m/03txqz'}, {'id': 262, 'name': 'Drinking straw', 'freebase_id': '/m/03v5tg'}, {'id': 263, 'name': 'Insect', 'freebase_id': '/m/03vt0'}, {'id': 264, 'name': 'Hair dryer', 'freebase_id': '/m/03wvsk'}, {'id': 265, 'name': 'Kitchenware', 'freebase_id': '/m/03_wxk'}, {'id': 266, 'name': 'Indoor rower', 'freebase_id': '/m/03wym'}, {'id': 267, 'name': 'Invertebrate', 'freebase_id': '/m/03xxp'}, {'id': 268, 'name': 'Food processor', 'freebase_id': '/m/03y6mg'}, {'id': 269, 'name': 'Bookcase', 'freebase_id': '/m/03__z0'}, {'id': 270, 'name': 'Refrigerator', 'freebase_id': '/m/040b_t'}, {'id': 271, 'name': 'Wood-burning stove', 'freebase_id': '/m/04169hn'}, {'id': 272, 'name': 'Punching bag', 'freebase_id': '/m/0420v5'}, {'id': 273, 'name': 'Common fig', 'freebase_id': '/m/043nyj'}, {'id': 274, 'name': 'Cocktail shaker', 'freebase_id': '/m/0440zs'}, {'id': 275, 'name': 'Jaguar (Animal)', 'freebase_id': '/m/0449p'}, {'id': 276, 'name': 'Golf ball', 'freebase_id': '/m/044r5d'}, {'id': 277, 'name': 'Fashion accessory', 'freebase_id': '/m/0463sg'}, {'id': 278, 'name': 'Alarm clock', 'freebase_id': '/m/046dlr'}, {'id': 279, 'name': 'Filing cabinet', 'freebase_id': '/m/047j0r'}, {'id': 280, 'name': 'Artichoke', 'freebase_id': '/m/047v4b'}, {'id': 281, 'name': 'Table', 'freebase_id': '/m/04bcr3'}, {'id': 282, 'name': 'Tableware', 'freebase_id': '/m/04brg2'}, {'id': 283, 'name': 'Kangaroo', 'freebase_id': '/m/04c0y'}, {'id': 284, 'name': 'Koala', 'freebase_id': '/m/04cp_'}, {'id': 285, 'name': 'Knife', 'freebase_id': '/m/04ctx'}, {'id': 286, 'name': 'Bottle', 'freebase_id': '/m/04dr76w'}, {'id': 287, 'name': 'Bottle opener', 'freebase_id': '/m/04f5ws'}, {'id': 288, 'name': 'Lynx', 'freebase_id': '/m/04g2r'}, {'id': 289, 'name': 'Lavender (Plant)', 'freebase_id': '/m/04gth'}, {'id': 290, 'name': 'Lighthouse', 'freebase_id': '/m/04h7h'}, {'id': 291, 'name': 'Dumbbell', 
'freebase_id': '/m/04h8sr'}, {'id': 292, 'name': 'Human head', 'freebase_id': '/m/04hgtk'}, {'id': 293, 'name': 'Bowl', 'freebase_id': '/m/04kkgm'}, {'id': 294, 'name': 'Humidifier', 'freebase_id': '/m/04lvq_'}, {'id': 295, 'name': 'Porch', 'freebase_id': '/m/04m6gz'}, {'id': 296, 'name': 'Lizard', 'freebase_id': '/m/04m9y'}, {'id': 297, 'name': 'Billiard table', 'freebase_id': '/m/04p0qw'}, {'id': 298, 'name': 'Mammal', 'freebase_id': '/m/04rky'}, {'id': 299, 'name': 'Mouse', 'freebase_id': '/m/04rmv'}, {'id': 300, 'name': 'Motorcycle', 'freebase_id': '/m/04_sv'}, {'id': 301, 'name': 'Musical instrument', 'freebase_id': '/m/04szw'}, {'id': 302, 'name': 'Swim cap', 'freebase_id': '/m/04tn4x'}, {'id': 303, 'name': 'Frying pan', 'freebase_id': '/m/04v6l4'}, {'id': 304, 'name': 'Snowplow', 'freebase_id': '/m/04vv5k'}, {'id': 305, 'name': 'Bathroom cabinet', 'freebase_id': '/m/04y4h8h'}, {'id': 306, 'name': 'Missile', 'freebase_id': '/m/04ylt'}, {'id': 307, 'name': 'Bust', 'freebase_id': '/m/04yqq2'}, {'id': 308, 'name': 'Man', 'freebase_id': '/m/04yx4'}, {'id': 309, 'name': 'Waffle iron', 'freebase_id': '/m/04z4wx'}, {'id': 310, 'name': 'Milk', 'freebase_id': '/m/04zpv'}, {'id': 311, 'name': 'Ring binder', 'freebase_id': '/m/04zwwv'}, {'id': 312, 'name': 'Plate', 'freebase_id': '/m/050gv4'}, {'id': 313, 'name': 'Mobile phone', 'freebase_id': '/m/050k8'}, {'id': 314, 'name': 'Baked goods', 'freebase_id': '/m/052lwg6'}, {'id': 315, 'name': 'Mushroom', 'freebase_id': '/m/052sf'}, {'id': 316, 'name': 'Crutch', 'freebase_id': '/m/05441v'}, {'id': 317, 'name': 'Pitcher (Container)', 'freebase_id': '/m/054fyh'}, {'id': 318, 'name': 'Mirror', 'freebase_id': '/m/054_l'}, {'id': 319, 'name': 'Personal flotation device', 'freebase_id': '/m/054xkw'}, {'id': 320, 'name': 'Table tennis racket', 'freebase_id': '/m/05_5p_0'}, {'id': 321, 'name': 'Pencil case', 'freebase_id': '/m/05676x'}, {'id': 322, 'name': 'Musical keyboard', 'freebase_id': '/m/057cc'}, {'id': 323, 'name': 'Scoreboard', 'freebase_id': '/m/057p5t'}, {'id': 324, 'name': 'Briefcase', 'freebase_id': '/m/0584n8'}, {'id': 325, 'name': 'Kitchen knife', 'freebase_id': '/m/058qzx'}, {'id': 326, 'name': 'Nail (Construction)', 'freebase_id': '/m/05bm6'}, {'id': 327, 'name': 'Tennis ball', 'freebase_id': '/m/05ctyq'}, {'id': 328, 'name': 'Plastic bag', 'freebase_id': '/m/05gqfk'}, {'id': 329, 'name': 'Oboe', 'freebase_id': '/m/05kms'}, {'id': 330, 'name': 'Chest of drawers', 'freebase_id': '/m/05kyg_'}, {'id': 331, 'name': 'Ostrich', 'freebase_id': '/m/05n4y'}, {'id': 332, 'name': 'Piano', 'freebase_id': '/m/05r5c'}, {'id': 333, 'name': 'Girl', 'freebase_id': '/m/05r655'}, {'id': 334, 'name': 'Plant', 'freebase_id': '/m/05s2s'}, {'id': 335, 'name': 'Potato', 'freebase_id': '/m/05vtc'}, {'id': 336, 'name': 'Hair spray', 'freebase_id': '/m/05w9t9'}, {'id': 337, 'name': 'Sports equipment', 'freebase_id': '/m/05y5lj'}, {'id': 338, 'name': 'Pasta', 'freebase_id': '/m/05z55'}, {'id': 339, 'name': 'Penguin', 'freebase_id': '/m/05z6w'}, {'id': 340, 'name': 'Pumpkin', 'freebase_id': '/m/05zsy'}, {'id': 341, 'name': 'Pear', 'freebase_id': '/m/061_f'}, {'id': 342, 'name': 'Infant bed', 'freebase_id': '/m/061hd_'}, {'id': 343, 'name': 'Polar bear', 'freebase_id': '/m/0633h'}, {'id': 344, 'name': 'Mixer', 'freebase_id': '/m/063rgb'}, {'id': 345, 'name': 'Cupboard', 'freebase_id': '/m/0642b4'}, {'id': 346, 'name': 'Jacuzzi', 'freebase_id': '/m/065h6l'}, {'id': 347, 'name': 'Pizza', 'freebase_id': '/m/0663v'}, {'id': 348, 'name': 'Digital clock', 'freebase_id': 
'/m/06_72j'}, {'id': 349, 'name': 'Pig', 'freebase_id': '/m/068zj'}, {'id': 350, 'name': 'Reptile', 'freebase_id': '/m/06bt6'}, {'id': 351, 'name': 'Rifle', 'freebase_id': '/m/06c54'}, {'id': 352, 'name': 'Lipstick', 'freebase_id': '/m/06c7f7'}, {'id': 353, 'name': 'Skateboard', 'freebase_id': '/m/06_fw'}, {'id': 354, 'name': 'Raven', 'freebase_id': '/m/06j2d'}, {'id': 355, 'name': 'High heels', 'freebase_id': '/m/06k2mb'}, {'id': 356, 'name': 'Red panda', 'freebase_id': '/m/06l9r'}, {'id': 357, 'name': 'Rose', 'freebase_id': '/m/06m11'}, {'id': 358, 'name': 'Rabbit', 'freebase_id': '/m/06mf6'}, {'id': 359, 'name': 'Sculpture', 'freebase_id': '/m/06msq'}, {'id': 360, 'name': 'Saxophone', 'freebase_id': '/m/06ncr'}, {'id': 361, 'name': 'Shotgun', 'freebase_id': '/m/06nrc'}, {'id': 362, 'name': 'Seafood', 'freebase_id': '/m/06nwz'}, {'id': 363, 'name': 'Submarine sandwich', 'freebase_id': '/m/06pcq'}, {'id': 364, 'name': 'Snowboard', 'freebase_id': '/m/06__v'}, {'id': 365, 'name': 'Sword', 'freebase_id': '/m/06y5r'}, {'id': 366, 'name': 'Picture frame', 'freebase_id': '/m/06z37_'}, {'id': 367, 'name': 'Sushi', 'freebase_id': '/m/07030'}, {'id': 368, 'name': 'Loveseat', 'freebase_id': '/m/0703r8'}, {'id': 369, 'name': 'Ski', 'freebase_id': '/m/071p9'}, {'id': 370, 'name': 'Squirrel', 'freebase_id': '/m/071qp'}, {'id': 371, 'name': 'Tripod', 'freebase_id': '/m/073bxn'}, {'id': 372, 'name': 'Stethoscope', 'freebase_id': '/m/073g6'}, {'id': 373, 'name': 'Submarine', 'freebase_id': '/m/074d1'}, {'id': 374, 'name': 'Scorpion', 'freebase_id': '/m/0755b'}, {'id': 375, 'name': 'Segway', 'freebase_id': '/m/076bq'}, {'id': 376, 'name': 'Training bench', 'freebase_id': '/m/076lb9'}, {'id': 377, 'name': 'Snake', 'freebase_id': '/m/078jl'}, {'id': 378, 'name': 'Coffee table', 'freebase_id': '/m/078n6m'}, {'id': 379, 'name': 'Skyscraper', 'freebase_id': '/m/079cl'}, {'id': 380, 'name': 'Sheep', 'freebase_id': '/m/07bgp'}, {'id': 381, 'name': 'Television', 'freebase_id': '/m/07c52'}, {'id': 382, 'name': 'Trombone', 'freebase_id': '/m/07c6l'}, {'id': 383, 'name': 'Tea', 'freebase_id': '/m/07clx'}, {'id': 384, 'name': 'Tank', 'freebase_id': '/m/07cmd'}, {'id': 385, 'name': 'Taco', 'freebase_id': '/m/07crc'}, {'id': 386, 'name': 'Telephone', 'freebase_id': '/m/07cx4'}, {'id': 387, 'name': 'Torch', 'freebase_id': '/m/07dd4'}, {'id': 388, 'name': 'Tiger', 'freebase_id': '/m/07dm6'}, {'id': 389, 'name': 'Strawberry', 'freebase_id': '/m/07fbm7'}, {'id': 390, 'name': 'Trumpet', 'freebase_id': '/m/07gql'}, {'id': 391, 'name': 'Tree', 'freebase_id': '/m/07j7r'}, {'id': 392, 'name': 'Tomato', 'freebase_id': '/m/07j87'}, {'id': 393, 'name': 'Train', 'freebase_id': '/m/07jdr'}, {'id': 394, 'name': 'Tool', 'freebase_id': '/m/07k1x'}, {'id': 395, 'name': 'Picnic basket', 'freebase_id': '/m/07kng9'}, {'id': 396, 'name': 'Cooking spray', 'freebase_id': '/m/07mcwg'}, {'id': 397, 'name': 'Trousers', 'freebase_id': '/m/07mhn'}, {'id': 398, 'name': 'Bowling equipment', 'freebase_id': '/m/07pj7bq'}, {'id': 399, 'name': 'Football helmet', 'freebase_id': '/m/07qxg_'}, {'id': 400, 'name': 'Truck', 'freebase_id': '/m/07r04'}, {'id': 401, 'name': 'Measuring cup', 'freebase_id': '/m/07v9_z'}, {'id': 402, 'name': 'Coffeemaker', 'freebase_id': '/m/07xyvk'}, {'id': 403, 'name': 'Violin', 'freebase_id': '/m/07y_7'}, {'id': 404, 'name': 'Vehicle', 'freebase_id': '/m/07yv9'}, {'id': 405, 'name': 'Handbag', 'freebase_id': '/m/080hkjn'}, {'id': 406, 'name': 'Paper cutter', 'freebase_id': '/m/080n7g'}, {'id': 407, 'name': 'Wine', 
'freebase_id': '/m/081qc'}, {'id': 408, 'name': 'Weapon', 'freebase_id': '/m/083kb'}, {'id': 409, 'name': 'Wheel', 'freebase_id': '/m/083wq'}, {'id': 410, 'name': 'Worm', 'freebase_id': '/m/084hf'}, {'id': 411, 'name': 'Wok', 'freebase_id': '/m/084rd'}, {'id': 412, 'name': 'Whale', 'freebase_id': '/m/084zz'}, {'id': 413, 'name': 'Zebra', 'freebase_id': '/m/0898b'}, {'id': 414, 'name': 'Auto part', 'freebase_id': '/m/08dz3q'}, {'id': 415, 'name': 'Jug', 'freebase_id': '/m/08hvt4'}, {'id': 416, 'name': 'Pizza cutter', 'freebase_id': '/m/08ks85'}, {'id': 417, 'name': 'Cream', 'freebase_id': '/m/08p92x'}, {'id': 418, 'name': 'Monkey', 'freebase_id': '/m/08pbxl'}, {'id': 419, 'name': 'Lion', 'freebase_id': '/m/096mb'}, {'id': 420, 'name': 'Bread', 'freebase_id': '/m/09728'}, {'id': 421, 'name': 'Platter', 'freebase_id': '/m/099ssp'}, {'id': 422, 'name': 'Chicken', 'freebase_id': '/m/09b5t'}, {'id': 423, 'name': 'Eagle', 'freebase_id': '/m/09csl'}, {'id': 424, 'name': 'Helicopter', 'freebase_id': '/m/09ct_'}, {'id': 425, 'name': 'Owl', 'freebase_id': '/m/09d5_'}, {'id': 426, 'name': 'Duck', 'freebase_id': '/m/09ddx'}, {'id': 427, 'name': 'Turtle', 'freebase_id': '/m/09dzg'}, {'id': 428, 'name': 'Hippopotamus', 'freebase_id': '/m/09f20'}, {'id': 429, 'name': 'Crocodile', 'freebase_id': '/m/09f_2'}, {'id': 430, 'name': 'Toilet', 'freebase_id': '/m/09g1w'}, {'id': 431, 'name': 'Toilet paper', 'freebase_id': '/m/09gtd'}, {'id': 432, 'name': 'Squid', 'freebase_id': '/m/09gys'}, {'id': 433, 'name': 'Clothing', 'freebase_id': '/m/09j2d'}, {'id': 434, 'name': 'Footwear', 'freebase_id': '/m/09j5n'}, {'id': 435, 'name': 'Lemon', 'freebase_id': '/m/09k_b'}, {'id': 436, 'name': 'Spider', 'freebase_id': '/m/09kmb'}, {'id': 437, 'name': 'Deer', 'freebase_id': '/m/09kx5'}, {'id': 438, 'name': 'Frog', 'freebase_id': '/m/09ld4'}, {'id': 439, 'name': 'Banana', 'freebase_id': '/m/09qck'}, {'id': 440, 'name': 'Rocket', 'freebase_id': '/m/09rvcxw'}, {'id': 441, 'name': 'Wine glass', 'freebase_id': '/m/09tvcd'}, {'id': 442, 'name': 'Countertop', 'freebase_id': '/m/0b3fp9'}, {'id': 443, 'name': 'Tablet computer', 'freebase_id': '/m/0bh9flk'}, {'id': 444, 'name': 'Waste container', 'freebase_id': '/m/0bjyj5'}, {'id': 445, 'name': 'Swimming pool', 'freebase_id': '/m/0b_rs'}, {'id': 446, 'name': 'Dog', 'freebase_id': '/m/0bt9lr'}, {'id': 447, 'name': 'Book', 'freebase_id': '/m/0bt_c3'}, {'id': 448, 'name': 'Elephant', 'freebase_id': '/m/0bwd_0j'}, {'id': 449, 'name': 'Shark', 'freebase_id': '/m/0by6g'}, {'id': 450, 'name': 'Candle', 'freebase_id': '/m/0c06p'}, {'id': 451, 'name': 'Leopard', 'freebase_id': '/m/0c29q'}, {'id': 452, 'name': 'Axe', 'freebase_id': '/m/0c2jj'}, {'id': 453, 'name': 'Hand dryer', 'freebase_id': '/m/0c3m8g'}, {'id': 454, 'name': 'Soap dispenser', 'freebase_id': '/m/0c3mkw'}, {'id': 455, 'name': 'Porcupine', 'freebase_id': '/m/0c568'}, {'id': 456, 'name': 'Flower', 'freebase_id': '/m/0c9ph5'}, {'id': 457, 'name': 'Canary', 'freebase_id': '/m/0ccs93'}, {'id': 458, 'name': 'Cheetah', 'freebase_id': '/m/0cd4d'}, {'id': 459, 'name': 'Palm tree', 'freebase_id': '/m/0cdl1'}, {'id': 460, 'name': 'Hamburger', 'freebase_id': '/m/0cdn1'}, {'id': 461, 'name': 'Maple', 'freebase_id': '/m/0cffdh'}, {'id': 462, 'name': 'Building', 'freebase_id': '/m/0cgh4'}, {'id': 463, 'name': 'Fish', 'freebase_id': '/m/0ch_cf'}, {'id': 464, 'name': 'Lobster', 'freebase_id': '/m/0cjq5'}, {'id': 465, 'name': 'Garden Asparagus', 'freebase_id': '/m/0cjs7'}, {'id': 466, 'name': 'Furniture', 'freebase_id': '/m/0c_jw'}, {'id': 467, 
'name': 'Hedgehog', 'freebase_id': '/m/0cl4p'}, {'id': 468, 'name': 'Airplane', 'freebase_id': '/m/0cmf2'}, {'id': 469, 'name': 'Spoon', 'freebase_id': '/m/0cmx8'}, {'id': 470, 'name': 'Otter', 'freebase_id': '/m/0cn6p'}, {'id': 471, 'name': 'Bull', 'freebase_id': '/m/0cnyhnx'}, {'id': 472, 'name': 'Oyster', 'freebase_id': '/m/0_cp5'}, {'id': 473, 'name': 'Horizontal bar', 'freebase_id': '/m/0cqn2'}, {'id': 474, 'name': 'Convenience store', 'freebase_id': '/m/0crjs'}, {'id': 475, 'name': 'Bomb', 'freebase_id': '/m/0ct4f'}, {'id': 476, 'name': 'Bench', 'freebase_id': '/m/0cvnqh'}, {'id': 477, 'name': 'Ice cream', 'freebase_id': '/m/0cxn2'}, {'id': 478, 'name': 'Caterpillar', 'freebase_id': '/m/0cydv'}, {'id': 479, 'name': 'Butterfly', 'freebase_id': '/m/0cyf8'}, {'id': 480, 'name': 'Parachute', 'freebase_id': '/m/0cyfs'}, {'id': 481, 'name': 'Orange', 'freebase_id': '/m/0cyhj_'}, {'id': 482, 'name': 'Antelope', 'freebase_id': '/m/0czz2'}, {'id': 483, 'name': 'Beaker', 'freebase_id': '/m/0d20w4'}, {'id': 484, 'name': 'Moths and butterflies', 'freebase_id': '/m/0d_2m'}, {'id': 485, 'name': 'Window', 'freebase_id': '/m/0d4v4'}, {'id': 486, 'name': 'Closet', 'freebase_id': '/m/0d4w1'}, {'id': 487, 'name': 'Castle', 'freebase_id': '/m/0d5gx'}, {'id': 488, 'name': 'Jellyfish', 'freebase_id': '/m/0d8zb'}, {'id': 489, 'name': 'Goose', 'freebase_id': '/m/0dbvp'}, {'id': 490, 'name': 'Mule', 'freebase_id': '/m/0dbzx'}, {'id': 491, 'name': 'Swan', 'freebase_id': '/m/0dftk'}, {'id': 492, 'name': 'Peach', 'freebase_id': '/m/0dj6p'}, {'id': 493, 'name': 'Coconut', 'freebase_id': '/m/0djtd'}, {'id': 494, 'name': 'Seat belt', 'freebase_id': '/m/0dkzw'}, {'id': 495, 'name': 'Raccoon', 'freebase_id': '/m/0dq75'}, {'id': 496, 'name': 'Chisel', 'freebase_id': '/m/0_dqb'}, {'id': 497, 'name': 'Fork', 'freebase_id': '/m/0dt3t'}, {'id': 498, 'name': 'Lamp', 'freebase_id': '/m/0dtln'}, {'id': 499, 'name': 'Camera', 'freebase_id': '/m/0dv5r'}, {'id': 500, 'name': 'Squash (Plant)', 'freebase_id': '/m/0dv77'}, {'id': 501, 'name': 'Racket', 'freebase_id': '/m/0dv9c'}, {'id': 502, 'name': 'Human face', 'freebase_id': '/m/0dzct'}, {'id': 503, 'name': 'Human arm', 'freebase_id': '/m/0dzf4'}, {'id': 504, 'name': 'Vegetable', 'freebase_id': '/m/0f4s2w'}, {'id': 505, 'name': 'Diaper', 'freebase_id': '/m/0f571'}, {'id': 506, 'name': 'Unicycle', 'freebase_id': '/m/0f6nr'}, {'id': 507, 'name': 'Falcon', 'freebase_id': '/m/0f6wt'}, {'id': 508, 'name': 'Chime', 'freebase_id': '/m/0f8s22'}, {'id': 509, 'name': 'Snail', 'freebase_id': '/m/0f9_l'}, {'id': 510, 'name': 'Shellfish', 'freebase_id': '/m/0fbdv'}, {'id': 511, 'name': 'Cabbage', 'freebase_id': '/m/0fbw6'}, {'id': 512, 'name': 'Carrot', 'freebase_id': '/m/0fj52s'}, {'id': 513, 'name': 'Mango', 'freebase_id': '/m/0fldg'}, {'id': 514, 'name': 'Jeans', 'freebase_id': '/m/0fly7'}, {'id': 515, 'name': 'Flowerpot', 'freebase_id': '/m/0fm3zh'}, {'id': 516, 'name': 'Pineapple', 'freebase_id': '/m/0fp6w'}, {'id': 517, 'name': 'Drawer', 'freebase_id': '/m/0fqfqc'}, {'id': 518, 'name': 'Stool', 'freebase_id': '/m/0fqt361'}, {'id': 519, 'name': 'Envelope', 'freebase_id': '/m/0frqm'}, {'id': 520, 'name': 'Cake', 'freebase_id': '/m/0fszt'}, {'id': 521, 'name': 'Dragonfly', 'freebase_id': '/m/0ft9s'}, {'id': 522, 'name': 'Common sunflower', 'freebase_id': '/m/0ftb8'}, {'id': 523, 'name': 'Microwave oven', 'freebase_id': '/m/0fx9l'}, {'id': 524, 'name': 'Honeycomb', 'freebase_id': '/m/0fz0h'}, {'id': 525, 'name': 'Marine mammal', 'freebase_id': '/m/0gd2v'}, {'id': 526, 'name': 'Sea lion', 
'freebase_id': '/m/0gd36'}, {'id': 527, 'name': 'Ladybug', 'freebase_id': '/m/0gj37'}, {'id': 528, 'name': 'Shelf', 'freebase_id': '/m/0gjbg72'}, {'id': 529, 'name': 'Watch', 'freebase_id': '/m/0gjkl'}, {'id': 530, 'name': 'Candy', 'freebase_id': '/m/0gm28'}, {'id': 531, 'name': 'Salad', 'freebase_id': '/m/0grw1'}, {'id': 532, 'name': 'Parrot', 'freebase_id': '/m/0gv1x'}, {'id': 533, 'name': 'Handgun', 'freebase_id': '/m/0gxl3'}, {'id': 534, 'name': 'Sparrow', 'freebase_id': '/m/0h23m'}, {'id': 535, 'name': 'Van', 'freebase_id': '/m/0h2r6'}, {'id': 536, 'name': 'Grinder', 'freebase_id': '/m/0h8jyh6'}, {'id': 537, 'name': 'Spice rack', 'freebase_id': '/m/0h8kx63'}, {'id': 538, 'name': 'Light bulb', 'freebase_id': '/m/0h8l4fh'}, {'id': 539, 'name': 'Corded phone', 'freebase_id': '/m/0h8lkj8'}, {'id': 540, 'name': 'Sports uniform', 'freebase_id': '/m/0h8mhzd'}, {'id': 541, 'name': 'Tennis racket', 'freebase_id': '/m/0h8my_4'}, {'id': 542, 'name': 'Wall clock', 'freebase_id': '/m/0h8mzrc'}, {'id': 543, 'name': 'Serving tray', 'freebase_id': '/m/0h8n27j'}, {'id': 544, 'name': 'Kitchen & dining room table', 'freebase_id': '/m/0h8n5zk'}, {'id': 545, 'name': 'Dog bed', 'freebase_id': '/m/0h8n6f9'}, {'id': 546, 'name': 'Cake stand', 'freebase_id': '/m/0h8n6ft'}, {'id': 547, 'name': 'Cat furniture', 'freebase_id': '/m/0h8nm9j'}, {'id': 548, 'name': 'Bathroom accessory', 'freebase_id': '/m/0h8nr_l'}, {'id': 549, 'name': 'Facial tissue holder', 'freebase_id': '/m/0h8nsvg'}, {'id': 550, 'name': 'Pressure cooker', 'freebase_id': '/m/0h8ntjv'}, {'id': 551, 'name': 'Kitchen appliance', 'freebase_id': '/m/0h99cwc'}, {'id': 552, 'name': 'Tire', 'freebase_id': '/m/0h9mv'}, {'id': 553, 'name': 'Ruler', 'freebase_id': '/m/0hdln'}, {'id': 554, 'name': 'Luggage and bags', 'freebase_id': '/m/0hf58v5'}, {'id': 555, 'name': 'Microphone', 'freebase_id': '/m/0hg7b'}, {'id': 556, 'name': 'Broccoli', 'freebase_id': '/m/0hkxq'}, {'id': 557, 'name': 'Umbrella', 'freebase_id': '/m/0hnnb'}, {'id': 558, 'name': 'Pastry', 'freebase_id': '/m/0hnyx'}, {'id': 559, 'name': 'Grapefruit', 'freebase_id': '/m/0hqkz'}, {'id': 560, 'name': 'Band-aid', 'freebase_id': '/m/0j496'}, {'id': 561, 'name': 'Animal', 'freebase_id': '/m/0jbk'}, {'id': 562, 'name': 'Bell pepper', 'freebase_id': '/m/0jg57'}, {'id': 563, 'name': 'Turkey', 'freebase_id': '/m/0jly1'}, {'id': 564, 'name': 'Lily', 'freebase_id': '/m/0jqgx'}, {'id': 565, 'name': 'Pomegranate', 'freebase_id': '/m/0jwn_'}, {'id': 566, 'name': 'Doughnut', 'freebase_id': '/m/0jy4k'}, {'id': 567, 'name': 'Glasses', 'freebase_id': '/m/0jyfg'}, {'id': 568, 'name': 'Human nose', 'freebase_id': '/m/0k0pj'}, {'id': 569, 'name': 'Pen', 'freebase_id': '/m/0k1tl'}, {'id': 570, 'name': 'Ant', 'freebase_id': '/m/0_k2'}, {'id': 571, 'name': 'Car', 'freebase_id': '/m/0k4j'}, {'id': 572, 'name': 'Aircraft', 'freebase_id': '/m/0k5j'}, {'id': 573, 'name': 'Human hand', 'freebase_id': '/m/0k65p'}, {'id': 574, 'name': 'Skunk', 'freebase_id': '/m/0km7z'}, {'id': 575, 'name': 'Teddy bear', 'freebase_id': '/m/0kmg4'}, {'id': 576, 'name': 'Watermelon', 'freebase_id': '/m/0kpqd'}, {'id': 577, 'name': 'Cantaloupe', 'freebase_id': '/m/0kpt_'}, {'id': 578, 'name': 'Dishwasher', 'freebase_id': '/m/0ky7b'}, {'id': 579, 'name': 'Flute', 'freebase_id': '/m/0l14j_'}, {'id': 580, 'name': 'Balance beam', 'freebase_id': '/m/0l3ms'}, {'id': 581, 'name': 'Sandwich', 'freebase_id': '/m/0l515'}, {'id': 582, 'name': 'Shrimp', 'freebase_id': '/m/0ll1f78'}, {'id': 583, 'name': 'Sewing machine', 'freebase_id': '/m/0llzx'}, {'id': 
584, 'name': 'Binoculars', 'freebase_id': '/m/0lt4_'}, {'id': 585, 'name': 'Rays and skates', 'freebase_id': '/m/0m53l'}, {'id': 586, 'name': 'Ipod', 'freebase_id': '/m/0mcx2'}, {'id': 587, 'name': 'Accordion', 'freebase_id': '/m/0mkg'}, {'id': 588, 'name': 'Willow', 'freebase_id': '/m/0mw_6'}, {'id': 589, 'name': 'Crab', 'freebase_id': '/m/0n28_'}, {'id': 590, 'name': 'Crown', 'freebase_id': '/m/0nl46'}, {'id': 591, 'name': 'Seahorse', 'freebase_id': '/m/0nybt'}, {'id': 592, 'name': 'Perfume', 'freebase_id': '/m/0p833'}, {'id': 593, 'name': 'Alpaca', 'freebase_id': '/m/0pcr'}, {'id': 594, 'name': 'Taxi', 'freebase_id': '/m/0pg52'}, {'id': 595, 'name': 'Canoe', 'freebase_id': '/m/0ph39'}, {'id': 596, 'name': 'Remote control', 'freebase_id': '/m/0qjjc'}, {'id': 597, 'name': 'Wheelchair', 'freebase_id': '/m/0qmmr'}, {'id': 598, 'name': 'Rugby ball', 'freebase_id': '/m/0wdt60w'}, {'id': 599, 'name': 'Armadillo', 'freebase_id': '/m/0xfy'}, {'id': 600, 'name': 'Maracas', 'freebase_id': '/m/0xzly'}, {'id': 601, 'name': 'Helmet', 'freebase_id': '/m/0zvk5'}]
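The list above maps each Open Images class id to its display name and Freebase MID. As a minimal illustrative sketch (not part of the uploaded files), the list can be turned into lookup tables for decoding predictions:

# Illustrative only: build lookup tables from the `categories` list above.
id_to_name = {c['id']: c['name'] for c in categories}
freebase_to_name = {c['freebase_id']: c['name'] for c in categories}

assert id_to_name[1] == 'Tortoise'
assert freebase_to_name['/m/0zvk5'] == 'Helmet'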
regionspot/data/v3det.py
ADDED
@@ -0,0 +1,34 @@
+from detectron2.data.datasets.register_coco import register_coco_instances
+import os
+
+from .v3det_categories import categories
+def _get_builtin_metadata(categories):
+    id_to_name = {x['id']: x['name'] for x in categories}
+    thing_dataset_id_to_contiguous_id = {i + 1: i for i in range(len(categories))}
+    thing_classes = [id_to_name[k] for k in sorted(id_to_name)]
+
+    return {
+        "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
+        "thing_classes": thing_classes}
+
+def _get_builtin_metadata():
+    id_to_name = {x['id']: x['name'] for x in categories}
+    thing_dataset_id_to_contiguous_id = {i + 1: i for i in range(len(categories))}
+    thing_classes = [id_to_name[k] for k in sorted(id_to_name)]
+    return {
+        "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
+        "thing_classes": thing_classes}
+
+
+_PREDEFINED_SPLITS_V3DET = {
+    "v3det_train": ("v3det/V3Det/", "v3det/v3det_2023_v1_train.json"),
+    "v3det_val": ("v3det/V3Det/", "v3det/v3det_2023_v1_val.json"),
+}
+
+for key, (image_root, json_file) in _PREDEFINED_SPLITS_V3DET.items():
+    register_coco_instances(
+        key,
+        _get_builtin_metadata(),
+        os.path.join("datasets", json_file) if "://" not in json_file else json_file,
+        os.path.join("datasets", image_root),
+    )
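Note that the zero-argument `_get_builtin_metadata` defined second shadows the one-argument version above it, so the registration loop builds metadata from the module-level `categories` import. Once this module is imported, both splits become reachable through detectron2's standard catalogs; a minimal usage sketch (assuming the V3Det json and images exist under ./datasets as registered):

# Illustrative only: query the splits registered above via detectron2.
from detectron2.data import DatasetCatalog, MetadataCatalog

meta = MetadataCatalog.get("v3det_train")
print(meta.thing_classes[:5])                      # first few V3Det category names
dataset_dicts = DatasetCatalog.get("v3det_train")  # parses the COCO-style json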
regionspot/data/v3det_categories.py
ADDED
The diff for this file is too large to render.
See raw diff
regionspot/detector.py
ADDED
@@ -0,0 +1,174 @@
+from collections import namedtuple
+from .modeling.regionspot import build_regionspot_model
+import torch.cuda.amp as amp
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from einops import rearrange
+import json
+from detectron2.modeling import META_ARCH_REGISTRY
+from .util.postprocessing import segmentation_postprocess
+
+from detectron2.structures import Boxes, Instances
+from .util.preprocessing import prepare_prompt_infer, prepare_prompt_train
+
+__all__ = ["RegionSpot"]
+
+
+
+@META_ARCH_REGISTRY.register()
+class RegionSpot(nn.Module):
+    """
+    Implement RegionSpot
+    """
+    def __init__(self, cfg):
+        super().__init__()
+        self.device = torch.device(cfg.MODEL.DEVICE)
+        self.clip_type = cfg.MODEL.CLIP_TYPE
+        self.inference_box_type = cfg.MODEL.BOX_TYPE
+        self.clip_input_size = cfg.MODEL.CLIP_INPUT_SIZE
+        self.clip_target_size = (self.clip_input_size, self.clip_input_size)
+        self.model, _ = build_regionspot_model(clip_type=self.clip_type, is_training=cfg.MODEL.TRAINING, image_size=self.clip_input_size)
+        self.model.to(self.device)
+        if self.inference_box_type != 'GT':
+            path = './datasets/glip_results/nms_results_glip_tiny_model_o365_goldg_cc_sbu_lvis_val.json'
+            with open(path, 'r') as file:
+                self.pred_results = json.load(file)
+        else:
+            self.pred_results = None
+
+    @torch.no_grad()
+    def foward_inference(self, batched_inputs, do_postprocess=True):
+        with amp.autocast(enabled=True):
+            with torch.no_grad():
+                logits_per_image, pred_mask = self.model.forward_eval(batched_inputs, multimask_output=False)
+
+        image_sizes = [x["original_size"] for x in batched_inputs]
+        if self.inference_box_type == 'GT':
+            boxes = torch.stack([x["instances"].gt_boxes.tensor for x in batched_inputs], dim=0)  # n, n_box, n_token, 256
+            if len(boxes[0]) == 0:
+                boxes = torch.tensor([[[0, 0, image_sizes[0][0], image_sizes[0][1]]]])
+        else:
+            boxes = torch.stack([x["pred_boxes"] for x in batched_inputs], dim=0)  # n, n_box, n_token, 256
+            scores = torch.stack([x["scores"] for x in batched_inputs], dim=0)
+
+
+        box_cls = logits_per_image
+        box_pred = boxes
+        if self.inference_box_type == 'GT':
+            results = self.inference_gt_box(box_cls, box_pred, pred_mask, image_sizes)
+        else:
+            results = self.inference_pred_box(box_cls, box_pred, scores, pred_mask, image_sizes)
+        if do_postprocess:
+            processed_results = []
+            for results_per_image, input_per_image, image_size in zip(results, batched_inputs, image_sizes):
+                height = input_per_image.get("height", image_size[0])
+                width = input_per_image.get("width", image_size[1])
+                r = segmentation_postprocess(results_per_image, height, width)
+                processed_results.append({"instances": r})
+            return processed_results
+        else:
+            return results
+
+    def foward_train(self, batched_inputs):
+        with amp.autocast(enabled=True):
+            outputs = self.model.forward_train(batched_inputs)
+            loss = {'loss': outputs}
+        return loss
+
+    def forward(self, batched_inputs, do_postprocess=True):
+        if not self.training:
+            # Prepare Prompt.
+            batched_inputs = prepare_prompt_infer(batched_inputs, pred_results=self.pred_results, target_size=self.clip_target_size)
+
+            results = self.foward_inference(batched_inputs)
+            return results
+
+        if self.training:
+            batched_inputs = prepare_prompt_train(batched_inputs, target_size=self.clip_target_size)
+            loss_dict = self.foward_train(batched_inputs)
+            return loss_dict
+
+
+
+    def inference_gt_box(self, box_cls, box_pred, pred_mask, image_sizes=None):
+
+        device = box_cls.device  # assuming all tensors are on the same device
+        results = []
+
+        for logits, boxes, masks, img_size in zip(box_cls, box_pred, pred_mask, image_sizes):
+            # Calculate probabilities and flatten them
+            probs = F.softmax(logits, dim=-1)
+            probs_flattened = probs.view(-1)
+
+            # Determine number of top predictions to consider
+            top_num = min(900, len(probs_flattened))
+
+            # Get top-k values and indices
+            topk_probs, topk_indices = torch.topk(probs_flattened, top_num)
+
+            # Decode the top-k indices to get corresponding labels and boxes
+            topk_labels = topk_indices % logits.shape[1]
+            topk_boxes_indices = topk_indices // logits.shape[1]
+
+            # Ensure boxes, masks and topk_boxes_indices are on the same device
+            topk_boxes_indices = topk_boxes_indices.to(device)
+            boxes = boxes.to(device)
+            masks = masks.to(device)
+
+            # Retrieve predictions using the top-k indices
+            boxes_for_topk = boxes[topk_boxes_indices]
+            masks_for_topk = masks[topk_boxes_indices]
+            scores_for_topk = topk_probs  # Modify accordingly if you have another score tensor
+            # Create Instances object for top-k predictions
+            result = Instances(img_size)
+            result.pred_boxes = Boxes(boxes_for_topk)
+            result.scores = scores_for_topk
+            result.pred_classes = topk_labels
+            result.pred_masks = masks_for_topk  # Added masks to the result
+            results.append(result)
+
+        return results
+
+    def inference_pred_box(self, box_cls, box_pred, box_score, masks, image_sizes=None):
+
+        results = []
+
+        for i, (logits, box_pred_i, box_score_i, mask_i, img_size) in enumerate(zip(box_cls, box_pred, box_score, masks, image_sizes)):
+
+            logits = logits.cuda()
+            box_pred_i = box_pred_i.cuda()
+            box_score_i = box_score_i.cuda()
+
+            # Calculate probabilities and flatten them
+            probs = F.softmax(logits, dim=-1)
+            probs_flattened = probs.view(-1)
+
+            # Determine number of top predictions to consider
+            top_num = min(900, len(probs_flattened))
+
+            # Get top-k values and indices
+            topk_probs, topk_indices = torch.topk(probs_flattened, top_num)
+
+            # Decode the top-k indices to get corresponding labels and boxes
+            topk_labels = topk_indices % logits.shape[1]
+            topk_boxes_indices = topk_indices // logits.shape[1]
+
+            # Retrieve predictions using the top-k indices
+            boxes = box_pred_i[topk_boxes_indices]
+            masks = mask_i[topk_boxes_indices]
+            scores = box_score_i[topk_boxes_indices] * topk_probs
+
+            # Construct result for the current image
+            result = Instances(img_size)
+            result.pred_boxes = Boxes(boxes)
+            result.scores = scores
+            result.pred_classes = topk_labels
+            result.pred_masks = masks
+            results.append(result)
+
+        return results
+
+
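Both inference paths above use the same decoding trick: the (num_boxes, num_classes) score matrix is flattened so a single torch.topk selects the best (box, class) pairs jointly, and modulo / integer division recover the two indices. A self-contained toy sketch of that scheme (illustration only, not part of the upload):

# Illustrative only: joint top-k over (box, class) pairs, as in
# inference_gt_box / inference_pred_box.
import torch
import torch.nn.functional as F

logits = torch.randn(5, 10)                  # 5 boxes, 10 candidate classes
probs = F.softmax(logits, dim=-1).view(-1)   # flatten to shape (50,)
topk_probs, topk_indices = torch.topk(probs, k=3)
topk_labels = topk_indices % logits.shape[1]         # class index of each hit
topk_box_indices = topk_indices // logits.shape[1]   # box index of each hit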
regionspot/modeling/__pycache__/constants.cpython-38.pyc
ADDED
Binary file (98.2 kB).
regionspot/modeling/__pycache__/decoder.cpython-38.pyc
ADDED
Binary file (5 kB).
regionspot/modeling/__pycache__/regionspot.cpython-38.pyc
ADDED
Binary file (8.88 kB).
regionspot/modeling/clip/__init__.py
ADDED
@@ -0,0 +1 @@
+from .clip import *
regionspot/modeling/clip/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (227 Bytes).
regionspot/modeling/clip/__pycache__/clip.cpython-38.pyc
ADDED
Binary file (8.36 kB).
regionspot/modeling/clip/__pycache__/model.cpython-38.pyc
ADDED
Binary file (16.6 kB).