forSubAnony committed on
Commit 57abc33 · 1 Parent(s): f3a73a0
Files changed (48)
  1. README.md +2 -0
  2. app.py +638 -0
  3. examples/0.png +0 -0
  4. examples/0008733.png +0 -0
  5. examples/0015849.png +0 -0
  6. examples/0021429.png +0 -0
  7. examples/1.png +0 -0
  8. examples/2.png +0 -0
  9. examples/5.jpg +0 -0
  10. examples/6.jpg +0 -0
  11. examples/8.jpg +0 -0
  12. load_models.py +132 -0
  13. models.py +208 -0
  14. ppc_decoder.py +231 -0
  15. requirements.txt +228 -0
  16. segment_anything/__init__.py +13 -0
  17. segment_anything/__pycache__/__init__.cpython-310.pyc +0 -0
  18. segment_anything/__pycache__/build_sam.cpython-310.pyc +0 -0
  19. segment_anything/build_sam.py +107 -0
  20. segment_anything/modeling/MaskDecoderHQ.py +210 -0
  21. segment_anything/modeling/__init__.py +13 -0
  22. segment_anything/modeling/__pycache__/MaskDecoderHQ.cpython-310.pyc +0 -0
  23. segment_anything/modeling/__pycache__/UpNet.cpython-310.pyc +0 -0
  24. segment_anything/modeling/__pycache__/__init__.cpython-310.pyc +0 -0
  25. segment_anything/modeling/__pycache__/common.cpython-310.pyc +0 -0
  26. segment_anything/modeling/__pycache__/image_encoder.cpython-310.pyc +0 -0
  27. segment_anything/modeling/__pycache__/mask_decoder.cpython-310.pyc +0 -0
  28. segment_anything/modeling/__pycache__/prompt_encoder.cpython-310.pyc +0 -0
  29. segment_anything/modeling/__pycache__/sam.cpython-310.pyc +0 -0
  30. segment_anything/modeling/__pycache__/transformer.cpython-310.pyc +0 -0
  31. segment_anything/modeling/common.py +43 -0
  32. segment_anything/modeling/image_encoder.py +398 -0
  33. segment_anything/modeling/mask_decoder.py +176 -0
  34. segment_anything/modeling/prompt_encoder.py +214 -0
  35. segment_anything/modeling/sam.py +182 -0
  36. segment_anything/modeling/transformer.py +240 -0
  37. segment_anything/utils/__init__.py +5 -0
  38. segment_anything/utils/transforms.py +102 -0
  39. utils/__pycache__/box_ops.cpython-310.pyc +0 -0
  40. utils/__pycache__/misc.cpython-310.pyc +0 -0
  41. utils/__pycache__/transforms.cpython-310.pyc +0 -0
  42. utils/box_ops.py +140 -0
  43. utils/datasets/__init__.py +0 -0
  44. utils/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  45. utils/datasets/__pycache__/transforms.cpython-310.pyc +0 -0
  46. utils/datasets/transforms.py +311 -0
  47. utils/misc.py +717 -0
  48. utils/transforms.py +102 -0
README.md CHANGED
@@ -11,3 +11,5 @@ license: apache-2.0
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
14
+
15
+ # PPC-SAM DEMO
app.py ADDED
@@ -0,0 +1,638 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import os
6
+ import cv2
7
+ import pathlib
8
+ from load_models import PPC_SAM
9
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
10
+ device = "cpu"
11
+ H = 512
12
+ W = 512
13
+
14
+ threshold_ppc = 0.5
15
+ threshold_sam = 0
16
+
17
+ test_example_dir = pathlib.Path("./examples")
18
+ test_examples = [str(test_example_dir / x) for x in sorted(os.listdir(test_example_dir))]
19
+
20
+ default_example = test_examples[0]
21
+
22
+ # -----------------------------------------------------------------------------
23
+ # Model initialization functions
24
+ # -----------------------------------------------------------------------------
25
+
26
+ def load_model(device = "cuda"):
27
+ exp = PPC_SAM(device=device)
28
+ return exp
29
+
30
+ # -----------------------------------------------------------------------------
31
+ # PPC-SAM helper functions
32
+ # -----------------------------------------------------------------------------
33
+ import os
34
+ import numpy as np
35
+ import matplotlib.pyplot as plt
36
+ from PIL import Image
37
+
38
+ def visualize_and_save_binary_mask(mask, save_dir, file_name_prefix):
39
+ """
40
+ Visualize and save a binary mask.
41
+
42
+ Parameters:
43
+ - mask (np.array): The binary mask to save and visualize, with shape (H, W) or (H, W, 3).
44
+ - save_dir (str): Directory where the images will be saved.
45
+ - file_name_prefix (str): Prefix for the saved file names.
46
+
47
+ Saves the following image:
48
+ - mask: "{file_name_prefix}_mask.png"
49
+ - colored mask: "{file_name_prefix}_mask_colored.png" (if mask is grayscale)
50
+ """
51
+
52
+ if isinstance(mask, np.ndarray):
53
+ # Check if mask is RGB (3 channels)
54
+ if len(mask.shape) == 3 and mask.shape[2] == 3:
55
+ mask_image = Image.fromarray(mask)
56
+ else:
57
+ # Ensure mask is binary (0 and 1) and convert to 0 and 255
58
+ mask = (mask > 0).astype(np.uint8) * 255
59
+ mask_image = Image.fromarray(mask)
60
+ else:
61
+ mask_image = mask
62
+
63
+ # Ensure the save directory exists
64
+ os.makedirs(save_dir, exist_ok=True)
65
+
66
+ # Save the binary mask or RGB mask
67
+ mask_image.save(os.path.join(save_dir, f"{file_name_prefix}_mask.png"))
68
+
69
+
70
+ print(f"Mask images saved in {save_dir}")
71
+
72
+
73
+ # -----------------------------------------------------------------------------
74
+ # Visualization functions
75
+ # -----------------------------------------------------------------------------
76
+
77
+ def _get_overlay(img, lay, const_color="l_blue"):
78
+ """
79
+ Helper function for preparing overlay
80
+ """
81
+ assert lay.ndim==2, "Overlay must be 2D, got shape: " + str(lay.shape)
82
+
83
+ if img.ndim == 2:
84
+ img = np.repeat(img[...,None], 3, axis=-1)
85
+
86
+ assert img.ndim==3, "Image must be 3D, got shape: " + str(img.shape)
87
+
88
+ if const_color == "blue":
89
+ const_color = 255*np.array([0, 0, 1])
90
+ elif const_color == "green":
91
+ const_color = 255*np.array([0, 1, 0])
92
+ elif const_color == "red":
93
+ const_color = 255*np.array([1, 0, 0])
94
+ elif const_color == "l_blue":
95
+ const_color = np.array([31, 119, 180])
96
+ elif const_color == "orange":
97
+ const_color = np.array([255, 127, 14])
98
+ else:
99
+ raise NotImplementedError
100
+
101
+ x,y = np.nonzero(lay)
102
+ for i in range(img.shape[-1]):
103
+ img[x,y,i] = const_color[i]
104
+
105
+ return img
106
+
107
+ def image_overlay(img, mask=None, scribbles=None, contour=False, alpha=0.5):
108
+ """
109
+ Overlay the ground truth mask and scribbles on the image if provided
110
+ """
111
+ # assert img.ndim == 2, "Image must be 2D, got shape: " + str(img.shape)
112
+ # output = np.repeat(img[...,None], 3, axis=-1)
113
+
114
+ output = img
115
+
116
+ if mask is not None:
117
+
118
+ assert mask.ndim == 2, "Mask must be 2D, got shape: " + str(mask.shape)
119
+
120
+ if contour:
121
+ contours = cv2.findContours((mask[...,None]>0.5).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
122
+ cv2.drawContours(output, contours[0], -1, (0, 255, 0), 2)
123
+ else:
124
+ mask_overlay = _get_overlay(img, mask)
125
+ mask2 = 0.5*np.repeat(mask[...,None], 3, axis=-1)
126
+ output = cv2.convertScaleAbs(mask_overlay * mask2 + output * (1 - mask2))
127
+
128
+ if scribbles is not None:
129
+ pos_scribble_overlay = _get_overlay(output, scribbles[0,...], const_color="green")
130
+ cv2.addWeighted(pos_scribble_overlay, alpha, output, 1 - alpha, 0, output)
131
+ neg_scribble_overlay = _get_overlay(output, scribbles[1,...], const_color="red")
132
+ cv2.addWeighted(neg_scribble_overlay, alpha, output, 1 - alpha, 0, output)
133
+
134
+ return output
135
+
136
+
137
+ def viz_pred_mask(img, mask=None, point_coords=None, point_labels=None, bbox_coords=None, seperate_scribble_masks=None, binary=True):
138
+ """
139
+ Visualize image with clicks, scribbles, predicted mask overlaid
140
+ """
141
+ assert isinstance(img, np.ndarray), "Image must be numpy array, got type: " + str(type(img))
142
+ if mask is not None:
143
+ if isinstance(mask, torch.Tensor):
144
+ mask = mask.cpu().numpy()
145
+
146
+ if binary and mask is not None:
147
+ mask = 1*(mask > 0.5)
148
+
149
+ out = image_overlay(img, mask=mask, scribbles=seperate_scribble_masks)
150
+
151
+ H,W = img.shape[:2]
152
+ marker_size = min(H,W)//100
153
+
154
+ if point_coords is not None:
155
+ for i,(col,row) in enumerate(point_coords):
156
+ if point_labels[i] == 1:
157
+ cv2.circle(out,(col, row), marker_size, (0,255,0), -1)
158
+ else:
159
+ cv2.circle(out,(col, row), marker_size, (255,0,0), -1)
160
+
161
+ if bbox_coords is not None:
162
+ for i in range(len(bbox_coords)//2):
163
+ cv2.rectangle(out, bbox_coords[2*i], bbox_coords[2*i+1], (255,165,0), marker_size)
164
+ if len(bbox_coords) % 2 == 1:
165
+ cv2.circle(out, tuple(bbox_coords[-1]), marker_size, (255,165,0), -1)
166
+
167
+ return out.astype(np.uint8)
168
+
169
+ # -----------------------------------------------------------------------------
170
+ # Collect scribbles
171
+ # -----------------------------------------------------------------------------
172
+
173
+ def get_scribbles(seperate_scribble_masks, last_scribble_mask, scribble_img):
174
+ """
175
+ Record scribbles
176
+ """
177
+ assert isinstance(seperate_scribble_masks, np.ndarray), "seperate_scribble_masks must be numpy array, got type: " + str(type(seperate_scribble_masks))
178
+
179
+ if scribble_img is not None:
180
+
181
+ # Only use first layer
182
+ color_mask = scribble_img.get('layers')[0]
183
+
184
+ positive_scribbles = 1.0*(color_mask[...,1] > 128)
185
+ negative_scribbles = 1.0*(color_mask[...,0] > 128)
186
+
187
+ seperate_scribble_masks = np.stack([positive_scribbles, negative_scribbles], axis=0)
188
+ last_scribble_mask = None
189
+
190
+ return seperate_scribble_masks, last_scribble_mask
191
+
192
+ def get_predictions(input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks,
193
+ low_res_mask, img_features, multimask_mode):
194
+ """
195
+ Make predictions
196
+ """
197
+ box = None
198
+ if len(bbox_coords) == 1:
199
+ gr.Error("Please click a second time to define the bounding box")
200
+ box = None
201
+ elif len(bbox_coords) == 2:
202
+ box = torch.Tensor(bbox_coords).flatten()[None,None,...].int().to(device) # B x n x 4
203
+
204
+ if seperate_scribble_masks is not None:
205
+ scribble = torch.from_numpy(seperate_scribble_masks)[None,...].to(device)
206
+ else:
207
+ scribble = None
208
+
209
+ #--------------------------#
210
+ # visualize_and_save_binary_mask(input_img, './output', 'example_rgb_mask')
211
+
212
+ image = input_img
213
+ box = box.squeeze(0) if box is not None else None
214
+ points_coords = torch.Tensor([click_coords]).int().to(device) if len(click_coords)>0 else None
215
+ points_labels = torch.Tensor([click_labels]).int().to(device) if len(click_labels)>0 else None
216
+ #--------------------------#
217
+
218
+
219
+ prompts = dict(
220
+ image=image,
221
+ point_coords=points_coords,
222
+ point_labels=points_labels,
223
+ scribble=scribble,
224
+ mask_input=low_res_mask.to(device) if low_res_mask is not None else None,
225
+ boxes=box,
226
+ )
227
+
228
+ masks, img_features, low_res_mask = predictor.predict([prompts], multimask_ouput=multimask_mode)
229
+
230
+
231
+ return masks.cpu(), img_features, low_res_mask
232
+
233
+ def refresh_predictions(input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
234
+ scribble_img, seperate_scribble_masks, last_scribble_mask,
235
+ best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode):
236
+
237
+ # Record any new scribbles
238
+
239
+ seperate_scribble_masks, last_scribble_mask = get_scribbles(
240
+ seperate_scribble_masks, last_scribble_mask, scribble_img
241
+ )
242
+
243
+ # Make prediction
244
+ stacked_masks, img_features, low_res_mask = get_predictions(
245
+ input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks, low_res_mask, img_features, multimask_mode
246
+ )
247
+
248
+ # Update input visualizations
249
+ # --------------------------------------#
250
+ if len(stacked_masks.shape) == 3 and stacked_masks.shape[0] == 3:
251
+ best_mask = stacked_masks[0]
252
+
253
+ input_img_copy = []
254
+ for i in range(1, stacked_masks.shape[0]):
255
+ input_img_copy.append(input_img.copy())
256
+ # --------------------------------------#
257
+
258
+ mask_to_viz = best_mask.numpy()
259
+ click_input_viz = viz_pred_mask(input_img, mask_to_viz, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox)
260
+
261
+
262
+ empty_channel = np.zeros(input_img.shape[:2]).astype(np.uint8)
263
+ full_channel = 255*np.ones(input_img.shape[:2]).astype(np.uint8)
264
+ gray_mask = (255*mask_to_viz).astype(np.uint8)
265
+
266
+ bg = viz_pred_mask(input_img, mask_to_viz, click_coords, click_labels, bbox_coords, None, binary_checkbox)
267
+ old_scribbles = scribble_img.get('layers')[0]
268
+
269
+ scribble_mask = 255*(old_scribbles > 0).any(-1)
270
+
271
+ scribble_input_viz = {
272
+ "background": np.stack([bg[...,i] for i in range(3)]+[full_channel], axis=-1),
273
+ ["layers"][0]: [np.stack([
274
+ (255*seperate_scribble_masks[1]).astype(np.uint8),
275
+ (255*seperate_scribble_masks[0]).astype(np.uint8),
276
+ empty_channel,
277
+ scribble_mask
278
+ ], axis=-1)],
279
+ "composite": np.stack([click_input_viz[...,i] for i in range(3)]+[empty_channel], axis=-1),
280
+ }
281
+
282
+ mask_img = 255*(mask_to_viz[...,None].repeat(axis=2, repeats=3)>threshold_sam) if binary_checkbox else mask_to_viz[...,None].repeat(axis=2, repeats=3)
283
+
284
+ out_viz = [
285
+ viz_pred_mask(input_img, mask_to_viz, point_coords=None, point_labels=None, bbox_coords=None, seperate_scribble_masks=None, binary=binary_checkbox),
286
+ mask_img,
287
+ ]
288
+
289
+
290
+ for i in range(1, stacked_masks.shape[0]):
291
+ mask = stacked_masks[i].numpy()
292
+ mask_img = 255*(mask[...,None].repeat(axis=2, repeats=3)>threshold_sam) if binary_checkbox else mask[...,None].repeat(axis=2, repeats=3)
293
+ tmp_viz = viz_pred_mask(input_img_copy[i-1], mask, point_coords=None, point_labels=None, bbox_coords=None, seperate_scribble_masks=None, binary=binary_checkbox)
294
+
295
+ out_viz.append(tmp_viz)
296
+ out_viz.append(mask_img)
297
+
298
+ return click_input_viz, scribble_input_viz, out_viz, best_mask, low_res_mask, img_features, seperate_scribble_masks, last_scribble_mask
299
+
300
+
301
+ def get_select_coords(input_img, brush_label, bbox_label, best_mask, low_res_mask,
302
+ click_coords, click_labels, bbox_coords,
303
+ seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
304
+ output_img, binary_checkbox, multimask_mode, autopredict_checkbox, evt: gr.SelectData):
305
+ """
306
+ Record user click and update the prediction
307
+ """
308
+ # Record click coordinates
309
+ if bbox_label:
310
+ bbox_coords.append(evt.index)
311
+ elif brush_label in ['Positive (green)', 'Negative (red)']:
312
+ click_coords.append(evt.index)
313
+ click_labels.append(1 if brush_label=='Positive (green)' else 0)
314
+ else:
315
+ raise TypeError("Invalid brush label: {brush_label}")
316
+
317
+ # Only make new prediction if not waiting for additional bounding box click
318
+ if (len(bbox_coords) % 2 == 0) and autopredict_checkbox:
319
+
320
+ click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features, seperate_scribble_masks, last_scribble_mask = refresh_predictions(
321
+ input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
322
+ scribble_img, seperate_scribble_masks, last_scribble_mask,
323
+ best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode
324
+ )
325
+ return click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features, click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask
326
+
327
+ else:
328
+ click_input_viz = viz_pred_mask(
329
+ input_img, best_mask, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox
330
+ )
331
+
332
+
333
+ scribble_input_viz = viz_pred_mask(
334
+ input_img, best_mask, click_coords, click_labels, bbox_coords, None, binary_checkbox
335
+ )
336
+ # Don't update output image if waiting for additional bounding box click
337
+ return click_input_viz, scribble_input_viz, output_img, best_mask, low_res_mask, img_features, click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask
338
+
339
+
340
+ def undo_click( input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords, click_labels, bbox_coords,
341
+ seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
342
+ output_img, binary_checkbox, multimask_mode, autopredict_checkbox):
343
+ """
344
+ Remove last click and then update the prediction
345
+ """
346
+ if bbox_label:
347
+ if len(bbox_coords) > 0:
348
+ bbox_coords.pop()
349
+ elif brush_label in ['Positive (green)', 'Negative (red)']:
350
+ if len(click_coords) > 0:
351
+ click_coords.pop()
352
+ click_labels.pop()
353
+ else:
354
+ raise TypeError("Invalid brush label: {brush_label}")
355
+
356
+ # Only make new prediction if not waiting for additional bounding box click
357
+ if (len(bbox_coords)==0 or len(bbox_coords)==2) and autopredict_checkbox:
358
+
359
+ click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features, seperate_scribble_masks, last_scribble_mask = refresh_predictions(
360
+ input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
361
+ scribble_img, seperate_scribble_masks, last_scribble_mask,
362
+ best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode
363
+ )
364
+ return click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features, click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask
365
+
366
+ else:
367
+ click_input_viz = viz_pred_mask(
368
+ input_img, best_mask, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox
369
+ )
370
+ scribble_input_viz = viz_pred_mask(
371
+ input_img, best_mask, click_coords, click_labels, bbox_coords, None, binary_checkbox
372
+ )
373
+
374
+ # Don't update output image if waiting for additional bounding box click
375
+ return click_input_viz, scribble_input_viz, output_img, best_mask, low_res_mask, img_features, click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask
376
+
377
+
378
+
379
+ # --------------------------------------------------
380
+
381
+ with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as demo:
382
+
383
+ # State variables
384
+ seperate_scribble_masks = gr.State(np.zeros((2, H, W), dtype=np.float32))
385
+ last_scribble_mask = gr.State(np.zeros((H, W), dtype=np.float32))
386
+
387
+ click_coords = gr.State([])
388
+ click_labels = gr.State([])
389
+ bbox_coords = gr.State([])
390
+
391
+ # Load default model
392
+ predictor = load_model(device=device)
393
+ img_features = gr.State(None) # For SAM models
394
+ best_mask = gr.State(None)
395
+ low_res_mask = gr.State(None)
396
+
397
+ gr.HTML("""\
398
+ <h1 style="text-align: center; font-size: 28pt;">PPC-SAM Demo</h1>
399
+ """)
400
+
401
+ with gr.Accordion("Open for instructions!", open=False):
402
+ gr.Markdown(
403
+ """
404
+ * Select an input image from the examples below or upload your own image through the <b>'Input Image'</b> tab.
405
+ * Use the <b>'Points/Boxes'</b> tab to draw <span style='color:green'>positive</span> or <span style='color:red'>negative</span> points and <span style='color:orange'>bounding boxes</span> by placing two points.
406
+ * The <b>'Output'</b> tab will show the models' prediction based on your current inputs and the previous prediction.
407
+ * The <b>'Output 1, 2'</b> are results of PPC-SAM; <b>'Output 3, 4'</b> are results of SAM-HQ; and <b>'Output 5, 6'</b> are results of SAM.
408
+ * The <b>'Clear All Inputs'</b> button will clear all inputs (including points, bounding boxes, and the last prediction).
409
+ """
410
+ )
411
+
412
+
413
+ # Interface ------------------------------------
414
+
415
+ with gr.Row():
416
+ model_dropdown = gr.Dropdown(
417
+ label="Model",
418
+ multiselect=False,
419
+ interactive=False,
420
+ visible=False
421
+ )
422
+
423
+ with gr.Row():
424
+ with gr.Column(scale=1):
425
+ brush_label = gr.Radio(["Positive (green)", "Negative (red)"],
426
+ value="Positive (green)", label="Scribble/Click Label")
427
+ bbox_label = gr.Checkbox(value=False, label="Bounding Box (2 points)")
428
+ with gr.Column(scale=1):
429
+
430
+ binary_checkbox = gr.Checkbox(value=True, label="Show binary masks", visible=False)
431
+ autopredict_checkbox = gr.Checkbox(value=False, label="Auto-update prediction on clicks", visible=False)
432
+ with gr.Accordion("Troubleshooting tips", open=False):
433
+ gr.Markdown("<span style='color:orange'>If you encounter an <span style='color:orange'>error</span> try clicking 'Clear All Inputs'.")
434
+ multimask_mode = gr.Checkbox(value=False, label="Multi-mask mode", visible=False)
435
+
436
+ with gr.Row():
437
+ display_height = 512
438
+
439
+ green_brush = gr.Brush(colors=["#00FF00"], color_mode="fixed", default_size=2)
440
+ red_brush = gr.Brush(colors=["#FF0000"], color_mode="fixed", default_size=2)
441
+
442
+ with gr.Column(scale=1):
443
+ scribble_img = gr.ImageEditor(
444
+ label="Input",
445
+ image_mode="RGB",
446
+ brush=green_brush,
447
+ type='numpy',
448
+ value=default_example,
449
+ transforms=(),
450
+ sources=(),
451
+ show_download_button=True,
452
+ # height=display_height,
453
+ visible=False
454
+ )
455
+
456
+ with gr.Tab("Points/Boxes") as click_tab:
457
+ click_img = gr.Image(
458
+ label="Input",
459
+ type='numpy',
460
+ value=default_example,
461
+ show_download_button=True,
462
+ sources=(),
463
+ container=True,
464
+ # height=display_height-50
465
+ )
466
+
467
+ with gr.Tab("Input Image"):
468
+ input_img = gr.Image(
469
+ label="Input",
470
+ image_mode="RGB",
471
+ value=default_example,
472
+ container=True
473
+ # height=display_height
474
+ )
475
+ gr.Markdown("To upload your own image: click the `x` in the top right corner to clear the current image, then drag & drop")
476
+
477
+ with gr.Row():
478
+ undo_click_button = gr.Button("Undo Last Click")
479
+ clear_click_button = gr.Button("Clear Points/Boxes", variant="stop")
480
+ with gr.Column(scale=1):
481
+ with gr.Tab("Output"):
482
+ output_img = gr.Gallery(
483
+ label='Output',
484
+ columns=1,
485
+ elem_id="gallery",
486
+ preview=True,
487
+ object_fit="scale-down",
488
+ # height=display_height,
489
+ container=True
490
+ )
491
+ gr.Markdown("Output 1, 2: PPC-SAM; Output 3, 4: SAM-HQ; Output 5, 6: SAM.")
492
+
493
+ submit_button = gr.Button("Submit", variant='primary')
494
+ clear_all_button = gr.ClearButton([scribble_img], value="Clear All Inputs", variant="stop")
495
+ clear_mask_button = gr.Button("Clear Input Mask", visible=False)
496
+
497
+
498
+ # ----------------------------------------------
499
+ # Loading Examples
500
+ # ----------------------------------------------
501
+
502
+ gr.Examples(examples=test_examples,
503
+ inputs=[input_img],
504
+ examples_per_page=12,
505
+ label='Examples from datasets unseen during training'
506
+ )
507
+
508
+ # When clear clicks button is clicked
509
+ def clear_click_history(input_img):
510
+ return input_img, input_img, [], [], [], None, None
511
+
512
+ clear_click_button.click(clear_click_history,
513
+ inputs=[input_img],
514
+ outputs=[click_img, scribble_img, click_coords, click_labels, bbox_coords, best_mask, low_res_mask])
515
+
516
+ # When clear all button is clicked
517
+ def clear_all_history(input_img):
518
+ if input_img is not None:
519
+ input_shape = input_img.shape[:2]
520
+ else:
521
+ input_shape = (H, W)
522
+ return input_img, input_img, [], [], [], [], np.zeros((2,)+input_shape, dtype=np.float32), np.zeros(input_shape, dtype=np.float32), None, None, None
523
+
524
+ # def clear_history_and_pad_input(input_img):
525
+ # if input_img is not None:
526
+ # h,w = input_img.shape[:2]
527
+ # if h != w:
528
+ # # Pad to square
529
+ # pad = abs(h-w)
530
+ # if h > w:
531
+ # padding = [(0,0), (math.ceil(pad/2),math.floor(pad/2))]
532
+ # else:
533
+ # padding = [(math.ceil(pad/2),math.floor(pad/2)), (0,0)]
534
+
535
+ # input_img = np.pad(input_img, padding, mode='constant', constant_values=0)
536
+
537
+ # return clear_all_history(input_img)
538
+
539
+
540
+ input_img.change(clear_all_history,
541
+ inputs=[input_img],
542
+ outputs=[click_img, scribble_img,
543
+ output_img, click_coords, click_labels, bbox_coords,
544
+ seperate_scribble_masks, last_scribble_mask,
545
+ best_mask, low_res_mask, img_features
546
+ ])
547
+
548
+ clear_all_button.click(clear_all_history,
549
+ inputs=[input_img],
550
+ outputs=[click_img, scribble_img,
551
+ output_img, click_coords, click_labels, bbox_coords,
552
+ seperate_scribble_masks, last_scribble_mask,
553
+ best_mask, low_res_mask, img_features
554
+ ])
555
+
556
+ # clear previous prediction mask
557
+ def clear_best_mask(input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks):
558
+
559
+ click_input_viz = viz_pred_mask(
560
+ input_img, None, click_coords, click_labels, bbox_coords, seperate_scribble_masks
561
+ )
562
+ scribble_input_viz = viz_pred_mask(
563
+ input_img, None, click_coords, click_labels, bbox_coords, None
564
+ )
565
+
566
+ return None, None, click_input_viz, scribble_input_viz
567
+
568
+ clear_mask_button.click(
569
+ clear_best_mask,
570
+ inputs=[input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks],
571
+ outputs=[best_mask, low_res_mask, click_img, scribble_img],
572
+ )
573
+
574
+ # ----------------------------------------------
575
+ # Clicks
576
+ # ----------------------------------------------
577
+
578
+ click_img.select(get_select_coords,
579
+ inputs=[
580
+ input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords, click_labels, bbox_coords,
581
+ seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
582
+ output_img, binary_checkbox, multimask_mode, autopredict_checkbox
583
+ ],
584
+ outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
585
+ click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask],
586
+ api_name = "get_select_coords"
587
+ )
588
+
589
+ submit_button.click(fn=refresh_predictions,
590
+ inputs=[input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
591
+ scribble_img, seperate_scribble_masks, last_scribble_mask,
592
+ best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode
593
+ ],
594
+ outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
595
+ seperate_scribble_masks, last_scribble_mask],
596
+ api_name="refresh_predictions"
597
+ )
598
+
599
+ undo_click_button.click(fn=undo_click,
600
+ inputs=[
601
+
602
+ input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords,
603
+ click_labels, bbox_coords,
604
+ seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
605
+ output_img, binary_checkbox, multimask_mode, autopredict_checkbox
606
+ ],
607
+ outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
608
+ click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask],
609
+ api_name="undo_click"
610
+ )
611
+
612
+ # ----------------------------------------------
613
+ # Scribbles
614
+ # ----------------------------------------------
615
+
616
+ def change_brush_color(seperate_scribble_masks, last_scribble_mask, scribble_img, label):
617
+ """
618
+ Update the brush color when the scribble/click label changes
619
+ """
620
+ if label == "Negative (red)":
621
+ brush_update = gr.update(brush=red_brush)
622
+ elif label == "Positive (green)":
623
+ brush_update = gr.update(brush=green_brush)
624
+ else:
625
+ raise TypeError("Invalid brush color")
626
+
627
+ return seperate_scribble_masks, last_scribble_mask, brush_update
628
+
629
+ brush_label.change(fn=change_brush_color,
630
+ inputs=[seperate_scribble_masks, last_scribble_mask, scribble_img, brush_label],
631
+ outputs=[seperate_scribble_masks, last_scribble_mask, scribble_img],
632
+ api_name="change_brush_color"
633
+ )
634
+
635
+
636
+ if __name__ == "__main__":
637
+
638
+ demo.queue(api_open=False).launch(show_api=False)
examples/0.png ADDED
examples/0008733.png ADDED
examples/0015849.png ADDED
examples/0021429.png ADDED
examples/1.png ADDED
examples/2.png ADDED
examples/5.jpg ADDED
examples/6.jpg ADDED
examples/8.jpg ADDED
load_models.py ADDED
@@ -0,0 +1,132 @@
1
+ from models import MaskDecoderHQ
2
+ from ppc_decoder import sam_decoder_reg
3
+ from segment_anything import sam_model_registry
4
+ import torch.nn as nn
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import matplotlib.pyplot as plt
8
+ from utils.transforms import ResizeLongestSide
9
+ from typing import List
10
+
11
+ trans = ResizeLongestSide(target_length=1024)
12
+
13
+ def save_prob_visualization(prob, filename="prob_visualization.png"):
14
+ """
15
+ Visualize a 1 x W x H probability map and save it to disk with plt.imshow
16
+ :param prob: tensor of shape 1 x W x H
17
+ :param filename: output file name, defaults to 'prob_visualization.png'
18
+ """
19
+ # Convert prob to a numpy array
20
+ prob_np = prob.squeeze(0).squeeze(0).numpy() # from 1 x W x H to W x H
21
+
22
+ # Visualize with plt.imshow
23
+ plt.imshow(prob_np)
24
+ # , cmap='gray', vmin=0, vmax=1) # cmap='gray' would force a grayscale rendering
25
+ plt.axis('off') # hide the axes
26
+
27
+ # Save the figure
28
+ plt.savefig(filename, bbox_inches='tight', pad_inches=0)
29
+ plt.close()
30
+ print(f"Probability map saved as {filename}")
31
+
32
+ def pad_to_square(x: torch.Tensor, target_size: int) -> torch.Tensor:
33
+ """Pad the input tensor to a square shape with the specified target size."""
34
+ # Get the current height and width of the image
35
+ h, w = x.shape[-2:]
36
+
37
+ # Calculate padding for height and width
38
+ padh = target_size - h
39
+ padw = target_size - w
40
+
41
+ # Pad the tensor to the target size
42
+ x = F.pad(x, (0, padw, 0, padh))
43
+ return x
44
+
45
+ def remove_none_values(input_dict):
46
+ """
47
+ Remove all items with None as their value from the dictionary.
48
+
49
+ Args:
50
+ input_dict (dict): The dictionary from which to remove None values.
51
+
52
+ Returns:
53
+ dict: A new dictionary with None values removed.
54
+ """
55
+ return {key: value for key, value in input_dict.items() if value is not None}
56
+
57
+ class PPC_SAM():
58
+ def __init__(self, model_type="vit_h",
59
+ ckpt_vit="pretrained_checkpoint/sam_vit_h_4b8939.pth",
60
+ ckpt_ppc="pretrained_checkpoint/ppc_decoder.pth",
61
+ ckpt_hq="pretrained_checkpoint/sam_hq_vit_h_decoder.pth",
62
+ device = "cpu") -> None:
63
+ # Call the parent class's __init__ method first
64
+
65
+ self.device = device
66
+
67
+ # Initialize the decoders
68
+ self.sam_hq_decoder = MaskDecoderHQ(model_type)
69
+ self.ppc_decoder = sam_decoder_reg['default']()
70
+
71
+ # Load state dictionaries
72
+ model_state_hq = torch.load(ckpt_hq, map_location=device)
73
+ self.sam_hq_decoder.load_state_dict(model_state_hq)
74
+ print(f"Loaded HQ decoder checkpoint from {ckpt_hq}")
75
+
76
+ model_state_ppc = torch.load(ckpt_ppc, map_location=device)
77
+ self.ppc_decoder.load_state_dict(model_state_ppc)
78
+ print(f"Loaded PPC decoder checkpoint from {ckpt_ppc}")
79
+
80
+ # Initialize the SAM model
81
+ self.sam = sam_model_registry[model_type](checkpoint=ckpt_vit).to(device)
82
+
83
+
84
+ def predict(self, prompts, multimask_ouput=False):
85
+ with torch.no_grad():
86
+ self.sam = self.sam.to(self.device)
87
+ self.sam_hq_decoder = self.sam_hq_decoder.to(self.device)
88
+ self.ppc_decoder = self.ppc_decoder.to(self.device)
89
+
90
+ batch_input = remove_none_values(prompts[0])
91
+ original_size = batch_input["image"].shape[:2]
92
+ batch_input["original_size"] = original_size
93
+
94
+ input_image = trans.apply_image(batch_input["image"])
95
+ input_image_torch = torch.as_tensor(input_image, device=self.device)
96
+ input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()
97
+ batch_input["image"] = input_image_torch
98
+
99
+ if "boxes" in batch_input:
100
+ batch_input["boxes"] = trans.apply_boxes_torch(batch_input["boxes"], original_size=original_size)
101
+ if "point_coords" in batch_input:
102
+ batch_input["point_coords"] = trans.apply_coords_torch(batch_input["point_coords"], original_size=original_size)
103
+
104
+
105
+ batched_output, interm_embeddings = self.sam([batch_input], multimask_output=multimask_ouput)
106
+
107
+ batch_len = len(batched_output)
108
+ encoder_embedding = torch.cat([batched_output[i_l]['encoder_embedding'] for i_l in range(batch_len)], dim=0)
109
+ image_pe = [batched_output[i_l]['image_pe'] for i_l in range(batch_len)]
110
+ sparse_embeddings = [batched_output[i_l]['sparse_embeddings'] for i_l in range(batch_len)]
111
+ dense_embeddings = [batched_output[i_l]['dense_embeddings'] for i_l in range(batch_len)]
112
+
113
+ masks_sam_in_hq, masks_hq = self.sam_hq_decoder(
114
+ image_embeddings=encoder_embedding,
115
+ image_pe=image_pe,
116
+ sparse_prompt_embeddings=sparse_embeddings,
117
+ dense_prompt_embeddings=dense_embeddings,
118
+ multimask_output=multimask_ouput,
119
+ hq_token_only=False,
120
+ interm_embeddings=interm_embeddings,
121
+ )
122
+
123
+ masks_sam = batched_output[0]["masks"]
124
+
125
+ input_images_ppc = pad_to_square(input_image_torch[None, :,:,:], target_size=1024).float()
126
+ mask_ppc = self.ppc_decoder(x_img=input_images_ppc, hidden_states_out=interm_embeddings, low_res_mask=masks_hq)
127
+
128
+ rescaled_masks_hq=self.sam.postprocess_masks(masks_hq, input_size=input_image_torch.shape[-2:], original_size=original_size)
129
+ rescaled_masks_ppc=self.sam.postprocess_masks(mask_ppc, input_size=input_image_torch.shape[-2:], original_size=original_size)
130
+
131
+ stacked_masks = torch.stack([rescaled_masks_ppc, rescaled_masks_hq, masks_sam.to(torch.uint8)], dim=0).cpu().squeeze(1).squeeze(1)
132
+ return stacked_masks, None, None
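For orientation, a minimal usage sketch of how app.py drives PPC_SAM.predict with a single positive click; the prompt keys mirror get_predictions in app.py, the click coordinate is illustrative only, and the three checkpoints listed in __init__ above are assumed to be present under pretrained_checkpoint/:

import cv2
import torch
from load_models import PPC_SAM

predictor = PPC_SAM(device="cpu")
image = cv2.cvtColor(cv2.imread("examples/0.png"), cv2.COLOR_BGR2RGB)  # H x W x 3, uint8 RGB

prompts = dict(
    image=image,
    point_coords=torch.tensor([[[256, 256]]], dtype=torch.int),  # B x N x 2, (x, y) clicks (illustrative)
    point_labels=torch.tensor([[1]], dtype=torch.int),           # 1 = positive, 0 = negative
    scribble=None,
    mask_input=None,
    boxes=None,
)
# Returns the PPC-SAM, SAM-HQ and SAM predictions stacked along dim 0 at the original resolution.
stacked_masks, _, _ = predictor.predict([prompts], multimask_ouput=False)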
models.py ADDED
@@ -0,0 +1,208 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from segment_anything.modeling import TwoWayTransformer, MaskDecoder
4
+ from typing import List, Tuple
5
+ import torch.nn.functional as F
6
+
7
+ class LayerNorm2d(nn.Module):
8
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
9
+ super().__init__()
10
+ self.weight = nn.Parameter(torch.ones(num_channels))
11
+ self.bias = nn.Parameter(torch.zeros(num_channels))
12
+ self.eps = eps
13
+
14
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
15
+ u = x.mean(1, keepdim=True)
16
+ s = (x - u).pow(2).mean(1, keepdim=True)
17
+ x = (x - u) / torch.sqrt(s + self.eps)
18
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
19
+ return x
20
+
21
+ class MLP(nn.Module):
22
+ def __init__(
23
+ self,
24
+ input_dim: int,
25
+ hidden_dim: int,
26
+ output_dim: int,
27
+ num_layers: int,
28
+ sigmoid_output: bool = False,
29
+ ) -> None:
30
+ super().__init__()
31
+ self.num_layers = num_layers
32
+ h = [hidden_dim] * (num_layers - 1)
33
+ self.layers = nn.ModuleList(
34
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
35
+ )
36
+ self.sigmoid_output = sigmoid_output
37
+
38
+ def forward(self, x):
39
+ for i, layer in enumerate(self.layers):
40
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
41
+ if self.sigmoid_output:
42
+ x = F.sigmoid(x)
43
+ return x
44
+
45
+ class MaskDecoderHQ(MaskDecoder):
46
+ def __init__(self, model_type):
47
+ super().__init__(transformer_dim=256,
48
+ transformer=TwoWayTransformer(
49
+ depth=2,
50
+ embedding_dim=256,
51
+ mlp_dim=2048,
52
+ num_heads=8,
53
+ ),
54
+ num_multimask_outputs=3,
55
+ activation=nn.GELU,
56
+ iou_head_depth= 3,
57
+ iou_head_hidden_dim= 256,)
58
+ assert model_type in ["vit_b","vit_l","vit_h"]
59
+
60
+ checkpoint_dict = {"vit_b":"pretrained_checkpoint/sam_vit_b_maskdecoder.pth",
61
+ "vit_l":"pretrained_checkpoint/sam_vit_l_maskdecoder.pth",
62
+ 'vit_h':"pretrained_checkpoint/sam_vit_h_maskdecoder.pth"}
63
+ checkpoint_path = checkpoint_dict[model_type]
64
+ self.load_state_dict(torch.load(checkpoint_path))
65
+ print("HQ Decoder init from SAM MaskDecoder")
66
+ for n,p in self.named_parameters():
67
+ p.requires_grad = False
68
+
69
+ transformer_dim=256
70
+ vit_dim_dict = {"vit_b":768,"vit_l":1024,"vit_h":1280}
71
+ vit_dim = vit_dim_dict[model_type]
72
+
73
+ self.hf_token = nn.Embedding(1, transformer_dim)
74
+ self.hf_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
75
+ self.num_mask_tokens = self.num_mask_tokens + 1
76
+
77
+ self.compress_vit_feat = nn.Sequential(
78
+ nn.ConvTranspose2d(vit_dim, transformer_dim, kernel_size=2, stride=2),
79
+ LayerNorm2d(transformer_dim),
80
+ nn.GELU(),
81
+ nn.ConvTranspose2d(transformer_dim, transformer_dim // 8, kernel_size=2, stride=2))
82
+
83
+ self.embedding_encoder = nn.Sequential(
84
+ nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
85
+ LayerNorm2d(transformer_dim // 4),
86
+ nn.GELU(),
87
+ nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
88
+ )
89
+
90
+ self.embedding_maskfeature = nn.Sequential(
91
+ nn.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1),
92
+ LayerNorm2d(transformer_dim // 4),
93
+ nn.GELU(),
94
+ nn.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1))
95
+
96
+ def forward(
97
+ self,
98
+ image_embeddings: torch.Tensor,
99
+ image_pe: torch.Tensor,
100
+ sparse_prompt_embeddings: torch.Tensor,
101
+ dense_prompt_embeddings: torch.Tensor,
102
+ multimask_output: bool,
103
+ hq_token_only: bool,
104
+ interm_embeddings: torch.Tensor,
105
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
106
+ """
107
+ Predict masks given image and prompt embeddings.
108
+
109
+ Arguments:
110
+ image_embeddings (torch.Tensor): the embeddings from the ViT image encoder
111
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
112
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
113
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
114
+ multimask_output (bool): Whether to return multiple masks or a single
115
+ mask.
116
+
117
+ Returns:
118
+ torch.Tensor: batched predicted hq masks
119
+ """
120
+
121
+ vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT
122
+ hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features)
123
+
124
+ batch_len = len(image_embeddings)
125
+ masks = []
126
+ iou_preds = []
127
+ for i_batch in range(batch_len):
128
+ mask, iou_pred = self.predict_masks(
129
+ image_embeddings=image_embeddings[i_batch].unsqueeze(0),
130
+ image_pe=image_pe[i_batch],
131
+ sparse_prompt_embeddings=sparse_prompt_embeddings[i_batch],
132
+ dense_prompt_embeddings=dense_prompt_embeddings[i_batch],
133
+ hq_feature = hq_features[i_batch].unsqueeze(0)
134
+ )
135
+ masks.append(mask)
136
+ iou_preds.append(iou_pred)
137
+ masks = torch.cat(masks,0)
138
+ iou_preds = torch.cat(iou_preds,0)
139
+
140
+ # Select the correct mask or masks for output
141
+ if multimask_output:
142
+ # mask with highest score
143
+ mask_slice = slice(1,self.num_mask_tokens-1)
144
+ iou_preds = iou_preds[:, mask_slice]
145
+ iou_preds, max_iou_idx = torch.max(iou_preds,dim=1)
146
+ iou_preds = iou_preds.unsqueeze(1)
147
+ masks_multi = masks[:, mask_slice, :, :]
148
+ masks_sam = masks_multi[torch.arange(masks_multi.size(0)),max_iou_idx].unsqueeze(1)
149
+ else:
150
+ # singale mask output, default
151
+ mask_slice = slice(0, 1)
152
+ masks_sam = masks[:,mask_slice]
153
+
154
+ masks_hq = masks[:,slice(self.num_mask_tokens-1, self.num_mask_tokens), :, :]
155
+
156
+ if hq_token_only:
157
+ return masks_hq
158
+ else:
159
+ return masks_sam, masks_hq
160
+
161
+ def predict_masks(
162
+ self,
163
+ image_embeddings: torch.Tensor,
164
+ image_pe: torch.Tensor,
165
+ sparse_prompt_embeddings: torch.Tensor,
166
+ dense_prompt_embeddings: torch.Tensor,
167
+ hq_feature: torch.Tensor,
168
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
169
+ """Predicts masks. See 'forward' for more details."""
170
+
171
+ output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight], dim=0)
172
+ output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
173
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
174
+
175
+ # Expand per-image data in batch direction to be per-mask
176
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
177
+ src = src + dense_prompt_embeddings
178
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
179
+ b, c, h, w = src.shape
180
+
181
+ # Run the transformer
182
+ hs, src = self.transformer(src, pos_src, tokens)
183
+ iou_token_out = hs[:, 0, :]
184
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
185
+
186
+ # Upscale mask embeddings and predict masks using the mask tokens
187
+ src = src.transpose(1, 2).view(b, c, h, w)
188
+
189
+ upscaled_embedding_sam = self.output_upscaling(src)
190
+ upscaled_embedding_ours = self.embedding_maskfeature(upscaled_embedding_sam) + hq_feature
191
+
192
+ hyper_in_list: List[torch.Tensor] = []
193
+ for i in range(self.num_mask_tokens):
194
+ if i < 4:
195
+ hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
196
+ else:
197
+ hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, i, :]))
198
+
199
+ hyper_in = torch.stack(hyper_in_list, dim=1)
200
+ b, c, h, w = upscaled_embedding_sam.shape
201
+
202
+ masks_sam = (hyper_in[:,:4] @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w)
203
+ masks_ours = (hyper_in[:,4:] @ upscaled_embedding_ours.view(b, c, h * w)).view(b, -1, h, w)
204
+ masks = torch.cat([masks_sam,masks_ours],dim=1)
205
+
206
+ iou_pred = self.iou_prediction_head(iou_token_out)
207
+
208
+ return masks, iou_pred
ppc_decoder.py ADDED
@@ -0,0 +1,231 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Tuple, Union
4
+
5
+
6
+ from monai.networks.blocks.dynunet_block import UnetOutBlock
7
+ from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock
8
+
9
+ def build_sam_decoder_vit_h():
10
+ return _build_sam_decoder(
11
+ encoder_embed_dim=1280,
12
+ encoder_num_heads=16,
13
+ )
14
+
15
+ def build_sam_decoder_vit_l():
16
+ return _build_sam_decoder(
17
+ encoder_embed_dim=1024,
18
+ encoder_num_heads=16,
19
+ )
20
+
21
+ def build_sam_decoder_vit_b():
22
+ return _build_sam_decoder(
23
+ encoder_embed_dim=768,
24
+ encoder_num_heads=12,
25
+ )
26
+
27
+ sam_decoder_reg = {
28
+ "default": build_sam_decoder_vit_h,
29
+ "vit_h": build_sam_decoder_vit_h,
30
+ "vit_l": build_sam_decoder_vit_l,
31
+ "vit_b": build_sam_decoder_vit_b,
32
+ }
33
+
34
+ def _build_sam_decoder(
35
+ encoder_embed_dim,
36
+ encoder_num_heads,
37
+ ):
38
+ image_size = 1024
39
+ vit_patch_size = 16
40
+
41
+ return ImageDecoderViT(
42
+ hidden_size=encoder_embed_dim,
43
+ img_size=image_size,
44
+ num_heads=encoder_num_heads,
45
+ patch_size=vit_patch_size,
46
+ )
47
+
48
+ class ImageDecoderViT(nn.Module):
49
+
50
+ def __init__(
51
+ self,
52
+ in_channels: int = 3,
53
+
54
+ feature_size: int = 64,
55
+ hidden_size: int = 1280,
56
+ conv_block: bool = True,
57
+ res_block: bool = True,
58
+ norm_name: Union[Tuple, str] = "instance",
59
+ dropout_rate: float = 0.0,
60
+ spatial_dims: int = 2,
61
+
62
+ img_size: int = 1024,
63
+ patch_size: int = 16,
64
+ out_channels: int = 1,
65
+ num_heads: int = 12,
66
+ ) -> None:
67
+
68
+ super().__init__()
69
+
70
+ if not (0 <= dropout_rate <= 1):
71
+ raise AssertionError("dropout_rate should be between 0 and 1.")
72
+
73
+ if hidden_size % num_heads != 0:
74
+ raise AssertionError("hidden size should be divisible by num_heads.")
75
+
76
+ self.patch_size = patch_size
77
+ self.feat_size = (
78
+ img_size // self.patch_size,
79
+ img_size // self.patch_size
80
+ )
81
+ self.hidden_size = hidden_size
82
+ self.classification = False
83
+
84
+ self.encoder_low_res_mask = nn.Sequential(
85
+ UnetrBasicBlock(
86
+ spatial_dims=spatial_dims,
87
+ in_channels=out_channels,
88
+ out_channels=feature_size,
89
+ kernel_size=3,
90
+ stride=1,
91
+ norm_name=norm_name,
92
+ res_block=res_block,
93
+ ),
94
+ UnetrBasicBlock(
95
+ spatial_dims=spatial_dims,
96
+ in_channels=feature_size,
97
+ out_channels=feature_size * 4,
98
+ kernel_size=3,
99
+ stride=1,
100
+ norm_name=norm_name,
101
+ res_block=res_block,
102
+ ),
103
+ )
104
+
105
+ self.decoder_fuse = UnetrBasicBlock(
106
+ spatial_dims=spatial_dims,
107
+ in_channels=feature_size * 8,
108
+ out_channels=feature_size * 4,
109
+ kernel_size=3,
110
+ stride=1,
111
+ norm_name=norm_name,
112
+ res_block=res_block,
113
+ )
114
+
115
+ self.encoder1 = UnetrBasicBlock(
116
+ spatial_dims=spatial_dims,
117
+ in_channels=in_channels,
118
+ out_channels=feature_size,
119
+ kernel_size=3,
120
+ stride=1,
121
+ norm_name=norm_name,
122
+ res_block=res_block,
123
+ )
124
+ self.encoder2 = UnetrPrUpBlock(
125
+ spatial_dims=spatial_dims,
126
+ in_channels=hidden_size,
127
+ out_channels=feature_size * 2,
128
+ num_layer=2,
129
+ kernel_size=3,
130
+ stride=1,
131
+ upsample_kernel_size=2,
132
+ norm_name=norm_name,
133
+ conv_block=conv_block,
134
+ res_block=res_block,
135
+ )
136
+ self.encoder3 = UnetrPrUpBlock(
137
+ spatial_dims=spatial_dims,
138
+ in_channels=hidden_size,
139
+ out_channels=feature_size * 4,
140
+ num_layer=1,
141
+ kernel_size=3,
142
+ stride=1,
143
+ upsample_kernel_size=2,
144
+ norm_name=norm_name,
145
+ conv_block=conv_block,
146
+ res_block=res_block,
147
+ )
148
+ self.encoder4 = UnetrPrUpBlock(
149
+ spatial_dims=spatial_dims,
150
+ in_channels=hidden_size,
151
+ out_channels=feature_size * 8,
152
+ num_layer=0,
153
+ kernel_size=3,
154
+ stride=1,
155
+ upsample_kernel_size=2,
156
+ norm_name=norm_name,
157
+ conv_block=conv_block,
158
+ res_block=res_block,
159
+ )
160
+ self.decoder5 = UnetrUpBlock(
161
+ spatial_dims=spatial_dims,
162
+ in_channels=hidden_size,
163
+ out_channels=feature_size * 8,
164
+ kernel_size=3,
165
+ upsample_kernel_size=2,
166
+ norm_name=norm_name,
167
+ res_block=res_block,
168
+ )
169
+ self.decoder4 = UnetrUpBlock(
170
+ spatial_dims=spatial_dims,
171
+ in_channels=feature_size * 8,
172
+ out_channels=feature_size * 4,
173
+ kernel_size=3,
174
+ upsample_kernel_size=2,
175
+ norm_name=norm_name,
176
+ res_block=res_block,
177
+ )
178
+ self.decoder3 = UnetrUpBlock(
179
+ spatial_dims=spatial_dims,
180
+ in_channels=feature_size * 4,
181
+ out_channels=feature_size * 2,
182
+ kernel_size=3,
183
+ upsample_kernel_size=2,
184
+ norm_name=norm_name,
185
+ res_block=res_block,
186
+ )
187
+ self.decoder2 = UnetrUpBlock(
188
+ spatial_dims=spatial_dims,
189
+ in_channels=feature_size * 2,
190
+ out_channels=feature_size,
191
+ kernel_size=3,
192
+ upsample_kernel_size=2,
193
+ norm_name=norm_name,
194
+ res_block=res_block,
195
+ )
196
+ self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels)
197
+ self.proj_axes = (0, spatial_dims + 1) + tuple(d + 1 for d in range(spatial_dims))
198
+ self.proj_view_shape = list(self.feat_size) + [self.hidden_size]
199
+
200
+
201
+ def proj_feat(self, x):
202
+ new_view = [x.size(0)] + self.proj_view_shape
203
+ x = x.view(new_view)
204
+ x = x.permute(self.proj_axes).contiguous()
205
+ return x
206
+
207
+ def forward(self, x_img, hidden_states_out, low_res_mask):
208
+
209
+ enc1 = self.encoder1(x_img)
210
+ x2 = hidden_states_out[0]
211
+ enc2 = self.encoder2(self.proj_feat(x2))
212
+ x3 = hidden_states_out[1]
213
+ enc3 = self.encoder3(self.proj_feat(x3))
214
+ x4 = hidden_states_out[2]
215
+ enc4 = self.encoder4(self.proj_feat(x4))
216
+
217
+ dec4 = self.proj_feat(hidden_states_out[3])
218
+ dec3 = self.decoder5(dec4, enc4)
219
+ dec2 = self.decoder4(dec3, enc3)
220
+
221
+ if low_res_mask != None:
222
+ enc_mask = self.encoder_low_res_mask(low_res_mask)
223
+ fused_dec2 = torch.cat([dec2, enc_mask], dim=1)
224
+ fused_dec2 = self.decoder_fuse(fused_dec2)
225
+ dec1 = self.decoder3(fused_dec2, enc2)
226
+ else:
227
+ dec1 = self.decoder3(dec2, enc2)
228
+
229
+ out = self.decoder2(dec1, enc1)
230
+
231
+ return self.out(out)
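As a quick sanity check on the interface, a hedged shape sketch for ImageDecoderViT.forward under the default ViT-H configuration (hidden_size=1280, 1024-pixel input, 16-pixel patches, i.e. a 64x64 token grid). The dummy tensors only mirror the shapes that load_models.py passes in and are not meaningful inputs:

import torch
from ppc_decoder import sam_decoder_reg

decoder = sam_decoder_reg["default"]()
x_img = torch.zeros(1, 3, 1024, 1024)                              # image padded to a 1024 square
hidden_states = [torch.zeros(1, 64, 64, 1280) for _ in range(4)]   # intermediate ViT embeddings (B, H, W, C)
low_res_mask = torch.zeros(1, 1, 256, 256)                         # e.g. masks_hq from MaskDecoderHQ
out = decoder(x_img=x_img, hidden_states_out=hidden_states, low_res_mask=low_res_mask)
print(out.shape)  # expected: torch.Size([1, 1, 1024, 1024]), before sam.postprocess_masks rescales it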
requirements.txt ADDED
@@ -0,0 +1,228 @@
1
+ accelerate==0.31.0
2
+ aiofiles==23.2.1
3
+ aiohttp==3.9.5
4
+ aiohttp-socks==0.8.4
5
+ aioimaplib==1.0.1
6
+ aiosignal==1.3.1
7
+ aiosmtplib==2.0.2
8
+ albucore==0.0.12
9
+ albumentations==1.4.10
10
+ annotated-types==0.7.0
11
+ antlr4-python3-runtime==4.9.3
12
+ anyio==4.3.0
13
+ argon2-cffi==23.1.0
14
+ argon2-cffi-bindings==21.2.0
15
+ arrow==1.3.0
16
+ asttokens==2.4.1
17
+ async-lru==2.0.4
18
+ async-timeout==4.0.3
19
+ asyncio-atexit==1.0.1
20
+ attrs==23.2.0
21
+ Babel==2.14.0
22
+ backoff==2.2.1
23
+ base58==2.1.1
24
+ beautifulsoup4==4.12.3
25
+ bech32==1.2.0
26
+ bleach==6.1.0
27
+ cbor2==5.4.6
28
+ certifi==2024.2.2
29
+ cffi==1.16.0
30
+ charset-normalizer==3.3.2
31
+ click==8.1.7
32
+ colorama==0.4.6
33
+ comm==0.2.2
34
+ contourpy==1.2.1
35
+ cpe==1.2.1
36
+ cryptography==40.0.2
37
+ cycler==0.12.1
38
+ debugpy==1.8.1
39
+ decorator==5.1.1
40
+ defusedxml==0.7.1
41
+ Deprecated==1.2.14
42
+ dominate==2.9.1
43
+ einops==0.8.0
44
+ exceptiongroup==1.2.0
45
+ executing==2.0.1
46
+ fastapi==0.112.1
47
+ fastjsonschema==2.17.1
48
+ ffmpy==0.4.0
49
+ filelock==3.13.3
50
+ fonttools==4.51.0
51
+ fqdn==1.5.1
52
+ frozenlist==1.4.1
53
+ fsspec==2024.3.1
54
+ googleapis-common-protos==1.63.2
55
+ gradio==4.41.0
56
+ gradio_client==1.3.0
57
+ h11==0.14.0
58
+ html5lib==1.1
59
+ httpcore==1.0.5
60
+ httpx==0.27.0
61
+ huggingface-hub==0.23.4
62
+ idna==3.4
63
+ imageio==2.34.0
64
+ importlib-metadata==6.11.0
65
+ importlib_resources==6.4.3
66
+ ipykernel==6.29.4
67
+ ipympl==0.9.3
68
+ ipython==8.23.0
69
+ ipython-genutils==0.2.0
70
+ ipywidgets==8.1.2
71
+ isoduration==20.11.0
72
+ jedi==0.19.1
73
+ Jinja2==3.1.3
74
+ joblib==1.4.2
75
+ json5==0.9.24
76
+ jsonpointer==2.4
77
+ jsonschema==4.21.1
78
+ jsonschema-specifications==2023.12.1
79
+ jupyter-events==0.10.0
80
+ jupyter-lsp==2.2.4
81
+ jupyter_client==8.6.1
82
+ jupyter_core==5.7.2
83
+ jupyter_server==2.13.0
84
+ jupyter_server_terminals==0.5.3
85
+ jupyterlab==4.1.5
86
+ jupyterlab_pygments==0.3.0
87
+ jupyterlab_server==2.25.4
88
+ jupyterlab_widgets==3.0.10
89
+ kiwisolver==1.4.5
90
+ lark==1.1.5
91
+ lazy_loader==0.4
92
+ lmdb==1.4.1
93
+ markdown-it-py==3.0.0
94
+ MarkupSafe==2.1.5
95
+ matplotlib==3.8.4
96
+ matplotlib-inline==0.1.6
97
+ mdurl==0.1.2
98
+ -e git+https://github.com/bowang-lab/MedSAM.git@2b7c64cf80bf1aba546627db9b13db045dd1cbab#egg=medsam
99
+ mistune==3.0.2
100
+ -e git+https://github.com/Project-MONAI/MONAI.git@12d00ce1369e37cb06f483735ef83674a208b031#egg=monai
101
+ more-itertools==10.3.0
102
+ mpmath==1.3.0
103
+ msgpack==1.0.8
104
+ multidict==6.0.5
105
+ nbclient==0.10.0
106
+ nbconvert==7.16.3
107
+ nbformat==5.10.4
108
+ nest-asyncio==1.6.0
109
+ networkx==3.3
110
+ nibabel==5.2.1
111
+ notebook_shim==0.2.4
112
+ numpy==1.26.4
113
+ nvidia-cublas-cu12==12.1.3.1
114
+ nvidia-cuda-cupti-cu12==12.1.105
115
+ nvidia-cuda-nvrtc-cu12==12.1.105
116
+ nvidia-cuda-runtime-cu12==12.1.105
117
+ nvidia-cudnn-cu12==8.9.2.26
118
+ nvidia-cufft-cu12==11.0.2.54
119
+ nvidia-curand-cu12==10.3.2.106
120
+ nvidia-cusolver-cu12==11.4.5.107
121
+ nvidia-cusparse-cu12==12.1.0.106
122
+ nvidia-nccl-cu12==2.19.3
123
+ nvidia-nvjitlink-cu12==12.4.127
124
+ nvidia-nvtx-cu12==12.1.105
125
+ oauthlib==3.2.2
126
+ opencv-python==4.9.0.80
127
+ opencv-python-headless==4.10.0.84
128
+ opentelemetry-api==1.21.0
129
+ opentelemetry-exporter-otlp-proto-common==1.21.0
130
+ opentelemetry-exporter-otlp-proto-http==1.21.0
131
+ opentelemetry-proto==1.21.0
132
+ opentelemetry-sdk==1.21.0
133
+ opentelemetry-semantic-conventions==0.42b0
134
+ orjson==3.10.7
135
+ overrides==7.7.0
136
+ packaging==23.2
137
+ pandas==2.2.2
138
+ pandocfilters==1.5.1
139
+ parso==0.8.4
140
+ pexpect==4.9.0
141
+ pillow==10.3.0
142
+ platformdirs==4.2.0
143
+ prometheus_client==0.20.0
144
+ prompt-toolkit==3.0.43
145
+ protobuf==4.25.4
146
+ psutil==5.9.8
147
+ ptyprocess==0.7.0
148
+ pure-eval==0.2.2
149
+ pycparser==2.22
150
+ pycryptodome==3.18.0
151
+ pydantic==2.7.4
152
+ pydantic_core==2.18.4
153
+ pydub==0.25.1
154
+ Pygments==2.15.1
155
+ pyOpenSSL==23.2.0
156
+ pyparsing==3.1.2
157
+ PyQt5==5.15.10
158
+ PyQt5-Qt5==5.15.2
159
+ PyQt5-sip==12.13.0
160
+ python-bitcoinlib==0.12.2
161
+ python-dateutil==2.9.0.post0
162
+ python-json-logger==2.0.7
163
+ python-multipart==0.0.9
164
+ python-socks==2.5.0
165
+ pytz==2023.4
166
+ PyYAML==6.0.1
167
+ pyzmq==25.1.2
168
+ referencing==0.34.0
169
+ regex==2024.5.15
170
+ requests==2.31.0
171
+ rfc3339-validator==0.1.4
172
+ rfc3986-validator==0.1.1
173
+ rich==13.7.1
174
+ rpds-py==0.18.0
175
+ ruff==0.6.1
176
+ safetensors==0.4.3
177
+ scalecodec==1.2.11
178
+ scikit-image==0.22.0
179
+ scikit-learn==1.5.0
180
+ scipy==1.13.0
181
+ semantic-version==2.10.0
182
+ Send2Trash==1.8.3
183
+ shellingham==1.5.4
184
+ SimpleITK==2.3.1
185
+ simplejson==3.19.2
186
+ six==1.16.0
187
+ sniffio==1.3.1
188
+ soupsieve==2.5
189
+ stack-data==0.6.3
190
+ starlette==0.38.2
191
+ stix2-patterns==2.0.0
192
+ stix2-validator==3.2.0
193
+ sympy==1.12
194
+ synapse==2.139.0
195
+ synapseclient==4.4.0
196
+ terminado==0.18.1
197
+ threadpoolctl==3.5.0
198
+ tifffile==2024.2.12
199
+ tinycss2==1.2.1
200
+ tokenizers==0.19.1
201
+ tomli==2.0.1
202
+ tomlkit==0.12.0
203
+ torch==2.2.2
204
+ torchaudio==2.2.2
205
+ torchvision==0.17.2
206
+ tornado==6.4
207
+ tqdm==4.66.2
208
+ traitlets==5.14.2
209
+ transformers==4.41.2
210
+ triton==2.2.0
211
+ typer==0.12.4
212
+ types-python-dateutil==2.9.0.20240316
213
+ typing_extensions==4.11.0
214
+ tzdata==2024.1
215
+ uri-template==1.3.0
216
+ urllib3==2.2.2
217
+ uvicorn==0.30.6
218
+ vcrpy==4.3.1
219
+ wcwidth==0.2.13
220
+ webcolors==1.13
221
+ webencodings==0.5.1
222
+ websocket-client==1.7.0
223
+ websockets==12.0
224
+ widgetsnbextension==4.0.10
225
+ wrapt==1.16.0
226
+ xxhash==3.2.0
227
+ yarl==1.9.4
228
+ zipp==3.19.2
segment_anything/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .build_sam import (
8
+ build_sam,
9
+ build_sam_vit_h,
10
+ build_sam_vit_l,
11
+ build_sam_vit_b,
12
+ sam_model_registry,
13
+ )
segment_anything/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (282 Bytes)
 
segment_anything/__pycache__/build_sam.cpython-310.pyc ADDED
Binary file (2.17 kB)
 
segment_anything/build_sam.py ADDED
@@ -0,0 +1,107 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+
9
+ from functools import partial
10
+
11
+ from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
12
+
13
+
14
+ def build_sam_vit_h(checkpoint=None):
15
+ return _build_sam(
16
+ encoder_embed_dim=1280,
17
+ encoder_depth=32,
18
+ encoder_num_heads=16,
19
+ encoder_global_attn_indexes=[7, 15, 23, 31],
20
+ checkpoint=checkpoint,
21
+ )
22
+
23
+
24
+ build_sam = build_sam_vit_h
25
+
26
+
27
+ def build_sam_vit_l(checkpoint=None):
28
+ return _build_sam(
29
+ encoder_embed_dim=1024,
30
+ encoder_depth=24,
31
+ encoder_num_heads=16,
32
+ encoder_global_attn_indexes=[5, 11, 17, 23],
33
+ checkpoint=checkpoint,
34
+ )
35
+
36
+
37
+ def build_sam_vit_b(checkpoint=None):
38
+ return _build_sam(
39
+ encoder_embed_dim=768,
40
+ encoder_depth=12,
41
+ encoder_num_heads=12,
42
+ encoder_global_attn_indexes=[2, 5, 8, 11],
43
+ checkpoint=checkpoint,
44
+ )
45
+
46
+
47
+ sam_model_registry = {
48
+ "default": build_sam,
49
+ "vit_h": build_sam,
50
+ "vit_l": build_sam_vit_l,
51
+ "vit_b": build_sam_vit_b,
52
+ }
53
+
54
+
55
+ def _build_sam(
56
+ encoder_embed_dim,
57
+ encoder_depth,
58
+ encoder_num_heads,
59
+ encoder_global_attn_indexes,
60
+ checkpoint=None,
61
+ ):
62
+ prompt_embed_dim = 256
63
+ image_size = 1024
64
+ vit_patch_size = 16
65
+ image_embedding_size = image_size // vit_patch_size
66
+ sam = Sam(
67
+ image_encoder=ImageEncoderViT(
68
+ depth=encoder_depth,
69
+ embed_dim=encoder_embed_dim,
70
+ img_size=image_size,
71
+ mlp_ratio=4,
72
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
73
+ num_heads=encoder_num_heads,
74
+ patch_size=vit_patch_size,
75
+ qkv_bias=True,
76
+ use_rel_pos=True,
77
+ global_attn_indexes=encoder_global_attn_indexes,
78
+ window_size=14,
79
+ out_chans=prompt_embed_dim,
80
+ ),
81
+ prompt_encoder=PromptEncoder(
82
+ embed_dim=prompt_embed_dim,
83
+ image_embedding_size=(image_embedding_size, image_embedding_size),
84
+ input_image_size=(image_size, image_size),
85
+ mask_in_chans=16,
86
+ ),
87
+ mask_decoder=MaskDecoder(
88
+ num_multimask_outputs=3,
89
+ transformer=TwoWayTransformer(
90
+ depth=2,
91
+ embedding_dim=prompt_embed_dim,
92
+ mlp_dim=2048,
93
+ num_heads=8,
94
+ ),
95
+ transformer_dim=prompt_embed_dim,
96
+ iou_head_depth=3,
97
+ iou_head_hidden_dim=256,
98
+ ),
99
+ pixel_mean=[123.675, 116.28, 103.53],
100
+ pixel_std=[58.395, 57.12, 57.375],
101
+ )
102
+ sam.eval()
103
+ if checkpoint is not None:
104
+ with open(checkpoint, "rb") as f:
105
+ state_dict = torch.load(f)
106
+ sam.load_state_dict(state_dict)
107
+ return sam
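
A minimal usage sketch (not part of the commit) of the registry defined above. `checkpoint=None` builds the ViT-B variant with randomly initialized weights; the checkpoint filename in the comment is only an assumption for illustration. Note that the image encoder in this repo returns both the image embedding and the intermediate ViT features.

```python
# Minimal sketch: build a SAM model from the registry in build_sam.py.
# checkpoint=None keeps random weights; a real path such as a downloaded
# "sam_vit_b_01ec64.pth" is an assumption and must exist on disk.
import torch
from segment_anything import sam_model_registry

sam = sam_model_registry["vit_b"](checkpoint=None)
sam.eval()

with torch.no_grad():
    dummy = torch.zeros(1, 3, 1024, 1024)            # already resized/padded to 1024x1024
    embeddings, interm = sam.image_encoder(dummy)    # image embedding + early ViT features
print(embeddings.shape, len(interm))                 # torch.Size([1, 256, 64, 64]) 4
```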
segment_anything/modeling/MaskDecoderHQ.py ADDED
@@ -0,0 +1,210 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from segment_anything.modeling import TwoWayTransformer, MaskDecoder
5
+ from typing import Dict, List, Tuple
6
+
7
+ class LayerNorm2d(nn.Module):
8
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
9
+ super().__init__()
10
+ self.weight = nn.Parameter(torch.ones(num_channels))
11
+ self.bias = nn.Parameter(torch.zeros(num_channels))
12
+ self.eps = eps
13
+
14
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
15
+ u = x.mean(1, keepdim=True)
16
+ s = (x - u).pow(2).mean(1, keepdim=True)
17
+ x = (x - u) / torch.sqrt(s + self.eps)
18
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
19
+ return x
20
+
21
+ class MLP(nn.Module):
22
+ def __init__(
23
+ self,
24
+ input_dim: int,
25
+ hidden_dim: int,
26
+ output_dim: int,
27
+ num_layers: int,
28
+ sigmoid_output: bool = False,
29
+ ) -> None:
30
+ super().__init__()
31
+ self.num_layers = num_layers
32
+ h = [hidden_dim] * (num_layers - 1)
33
+ self.layers = nn.ModuleList(
34
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
35
+ )
36
+ self.sigmoid_output = sigmoid_output
37
+
38
+ def forward(self, x):
39
+ for i, layer in enumerate(self.layers):
40
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
41
+ if self.sigmoid_output:
42
+ x = F.sigmoid(x)
43
+ return x
44
+
45
+ class MaskDecoderHQ(MaskDecoder):
46
+ def __init__(self, model_type):
47
+ super().__init__(transformer_dim=256,
48
+ transformer=TwoWayTransformer(
49
+ depth=2,
50
+ embedding_dim=256,
51
+ mlp_dim=2048,
52
+ num_heads=8,
53
+ ),
54
+ num_multimask_outputs=3,
55
+ activation=nn.GELU,
56
+ iou_head_depth= 3,
57
+ iou_head_hidden_dim= 256,)
58
+ assert model_type in ["vit_b","vit_l","vit_h"]
59
+
60
+ checkpoint_dict = {"vit_b":"pretrained_checkpoint/sam_vit_b_maskdecoder.pth",
61
+ "vit_l":"pretrained_checkpoint/sam_vit_l_maskdecoder.pth",
62
+ 'vit_h':"pretrained_checkpoint/sam_vit_h_maskdecoder.pth"}
63
+ checkpoint_path = checkpoint_dict[model_type]
64
+ self.load_state_dict(torch.load(checkpoint_path))
65
+ print("HQ Decoder init from SAM MaskDecoder")
66
+ for n,p in self.named_parameters():
67
+ p.requires_grad = False
68
+
69
+ transformer_dim=256
70
+ vit_dim_dict = {"vit_b":768,"vit_l":1024,"vit_h":1280}
71
+ vit_dim = vit_dim_dict[model_type]
72
+
73
+ self.hf_token = nn.Embedding(1, transformer_dim)
74
+ self.hf_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
75
+ self.num_mask_tokens = self.num_mask_tokens + 1
76
+
77
+ self.compress_vit_feat = nn.Sequential(
78
+ nn.ConvTranspose2d(vit_dim, transformer_dim, kernel_size=2, stride=2),
79
+ LayerNorm2d(transformer_dim),
80
+ nn.GELU(),
81
+ nn.ConvTranspose2d(transformer_dim, transformer_dim // 8, kernel_size=2, stride=2))
82
+
83
+ self.embedding_encoder = nn.Sequential(
84
+ nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
85
+ LayerNorm2d(transformer_dim // 4),
86
+ nn.GELU(),
87
+ nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
88
+ )
89
+
90
+ self.embedding_maskfeature = nn.Sequential(
91
+ nn.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1),
92
+ LayerNorm2d(transformer_dim // 4),
93
+ nn.GELU(),
94
+ nn.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1))
95
+
96
+
97
+ def forward(
98
+ self,
99
+ image_embeddings: torch.Tensor,
100
+ image_pe: torch.Tensor,
101
+ sparse_prompt_embeddings: torch.Tensor,
102
+ dense_prompt_embeddings: torch.Tensor,
103
+ multimask_output: bool,
104
+ hq_token_only: bool,
105
+ interm_embeddings: torch.Tensor,
106
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
107
+ """
108
+ Predict masks given image and prompt embeddings.
109
+
110
+ Arguments:
111
+ image_embeddings (torch.Tensor): the embeddings from the ViT image encoder
112
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
113
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
114
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
115
+ multimask_output (bool): Whether to return multiple masks or a single
116
+ mask.
117
+
118
+ Returns:
119
+ torch.Tensor: batched predicted hq masks
120
+ """
121
+
122
+ vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT
123
+ hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features)
124
+
125
+ batch_len = len(image_embeddings)
126
+ masks = []
127
+ iou_preds = []
128
+ for i_batch in range(batch_len):
129
+ mask, iou_pred = self.predict_masks(
130
+ image_embeddings=image_embeddings[i_batch].unsqueeze(0),
131
+ image_pe=image_pe[i_batch],
132
+ sparse_prompt_embeddings=sparse_prompt_embeddings[i_batch],
133
+ dense_prompt_embeddings=dense_prompt_embeddings[i_batch],
134
+ hq_feature = hq_features[i_batch].unsqueeze(0)
135
+ )
136
+ masks.append(mask)
137
+ iou_preds.append(iou_pred)
138
+ masks = torch.cat(masks,0)
139
+ iou_preds = torch.cat(iou_preds,0)
140
+
141
+ # Select the correct mask or masks for output
142
+ if multimask_output:
143
+ # mask with highest score
144
+ mask_slice = slice(1,self.num_mask_tokens-1)
145
+ iou_preds = iou_preds[:, mask_slice]
146
+ iou_preds, max_iou_idx = torch.max(iou_preds,dim=1)
147
+ iou_preds = iou_preds.unsqueeze(1)
148
+ masks_multi = masks[:, mask_slice, :, :]
149
+ masks_sam = masks_multi[torch.arange(masks_multi.size(0)),max_iou_idx].unsqueeze(1)
150
+ else:
151
+ # singale mask output, default
152
+ mask_slice = slice(0, 1)
153
+ masks_sam = masks[:,mask_slice]
154
+
155
+ masks_hq = masks[:,slice(self.num_mask_tokens-1, self.num_mask_tokens), :, :]
156
+
157
+ if hq_token_only:
158
+ return masks_hq
159
+ else:
160
+ return masks_sam, masks_hq
161
+
162
+ def predict_masks(
163
+ self,
164
+ image_embeddings: torch.Tensor,
165
+ image_pe: torch.Tensor,
166
+ sparse_prompt_embeddings: torch.Tensor,
167
+ dense_prompt_embeddings: torch.Tensor,
168
+ hq_feature: torch.Tensor,
169
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
170
+ """Predicts masks. See 'forward' for more details."""
171
+
172
+ output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight], dim=0)
173
+ output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
174
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
175
+
176
+ # Expand per-image data in batch direction to be per-mask
177
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
178
+ src = src + dense_prompt_embeddings
179
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
180
+ b, c, h, w = src.shape
181
+
182
+ # Run the transformer
183
+ hs, src = self.transformer(src, pos_src, tokens)
184
+ iou_token_out = hs[:, 0, :]
185
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
186
+
187
+ # Upscale mask embeddings and predict masks using the mask tokens
188
+ src = src.transpose(1, 2).view(b, c, h, w)
189
+
190
+ upscaled_embedding_sam = self.output_upscaling(src)
191
+ upscaled_embedding_ours = self.embedding_maskfeature(upscaled_embedding_sam) + hq_feature
192
+
193
+ hyper_in_list: List[torch.Tensor] = []
194
+ for i in range(self.num_mask_tokens):
195
+ if i < 4:
196
+ hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
197
+ else:
198
+ hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, i, :]))
199
+
200
+ hyper_in = torch.stack(hyper_in_list, dim=1)
201
+ b, c, h, w = upscaled_embedding_sam.shape
202
+
203
+ masks_sam = (hyper_in[:,:4] @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w)
204
+ masks_ours = (hyper_in[:,4:] @ upscaled_embedding_ours.view(b, c, h * w)).view(b, -1, h, w)
205
+
206
+ masks = torch.cat([masks_sam,masks_ours],dim=1)
207
+
208
+ iou_pred = self.iou_prediction_head(iou_token_out)
209
+
210
+ return masks, iou_pred
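
The multimask branch of `forward` keeps SAM's three candidate masks (token indices 1-3), picks the one with the highest predicted IoU, and reads the HQ mask from the extra token appended at the end. A toy sketch of that selection logic with dummy tensors (all shapes are assumptions for illustration):

```python
# Toy illustration of the mask-selection logic in MaskDecoderHQ.forward:
# choose the multimask candidate with the best IoU score, then take the HQ
# mask from the final (extra) token.
import torch

num_mask_tokens = 5                                   # 4 SAM tokens + 1 HQ token
masks = torch.randn(2, num_mask_tokens, 256, 256)     # (B, tokens, H, W)
iou_preds = torch.rand(2, num_mask_tokens)

mask_slice = slice(1, num_mask_tokens - 1)            # SAM multimask candidates: tokens 1..3
best_iou, best_idx = iou_preds[:, mask_slice].max(dim=1)
masks_multi = masks[:, mask_slice]
masks_sam = masks_multi[torch.arange(masks_multi.size(0)), best_idx].unsqueeze(1)
masks_hq = masks[:, num_mask_tokens - 1:]             # mask predicted by the HQ token

print(masks_sam.shape, masks_hq.shape)                # both torch.Size([2, 1, 256, 256])
```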
segment_anything/modeling/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .sam import Sam
8
+ from .image_encoder import ImageEncoderViT
9
+ from .mask_decoder import MaskDecoder
10
+ from .prompt_encoder import PromptEncoder
11
+ from .transformer import TwoWayTransformer
12
+ from .MaskDecoderHQ import MaskDecoderHQ
13
+
segment_anything/modeling/__pycache__/MaskDecoderHQ.cpython-310.pyc ADDED
Binary file (6.8 kB)
 
segment_anything/modeling/__pycache__/UpNet.cpython-310.pyc ADDED
Binary file (1.85 kB)
 
segment_anything/modeling/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (447 Bytes)
 
segment_anything/modeling/__pycache__/common.cpython-310.pyc ADDED
Binary file (1.77 kB)
 
segment_anything/modeling/__pycache__/image_encoder.cpython-310.pyc ADDED
Binary file (12.7 kB)
 
segment_anything/modeling/__pycache__/mask_decoder.cpython-310.pyc ADDED
Binary file (5.49 kB)
 
segment_anything/modeling/__pycache__/prompt_encoder.cpython-310.pyc ADDED
Binary file (7.7 kB)
 
segment_anything/modeling/__pycache__/sam.cpython-310.pyc ADDED
Binary file (6.76 kB)
 
segment_anything/modeling/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (6.62 kB)
 
segment_anything/modeling/common.py ADDED
@@ -0,0 +1,43 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from typing import Type
11
+
12
+
13
+ class MLPBlock(nn.Module):
14
+ def __init__(
15
+ self,
16
+ embedding_dim: int,
17
+ mlp_dim: int,
18
+ act: Type[nn.Module] = nn.GELU,
19
+ ) -> None:
20
+ super().__init__()
21
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
22
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
23
+ self.act = act()
24
+
25
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
26
+ return self.lin2(self.act(self.lin1(x)))
27
+
28
+
29
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
30
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
31
+ class LayerNorm2d(nn.Module):
32
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
33
+ super().__init__()
34
+ self.weight = nn.Parameter(torch.ones(num_channels))
35
+ self.bias = nn.Parameter(torch.zeros(num_channels))
36
+ self.eps = eps
37
+
38
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
39
+ u = x.mean(1, keepdim=True)
40
+ s = (x - u).pow(2).mean(1, keepdim=True)
41
+ x = (x - u) / torch.sqrt(s + self.eps)
42
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
43
+ return x
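
Unlike `nn.LayerNorm`, which normalizes over trailing dimensions, `LayerNorm2d` normalizes over the channel axis of NCHW feature maps. A quick equivalence check (illustrative only, not part of the commit):

```python
# Sanity check: LayerNorm2d over NCHW channels matches nn.LayerNorm applied
# to a channels-last view of the same tensor.
import torch
import torch.nn as nn
from segment_anything.modeling.common import LayerNorm2d

x = torch.randn(2, 8, 4, 4)
out = LayerNorm2d(8)(x)
ref = nn.LayerNorm(8, eps=1e-6)(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(torch.allclose(out, ref, atol=1e-5))            # expected: True
```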
segment_anything/modeling/image_encoder.py ADDED
@@ -0,0 +1,398 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from typing import Optional, Tuple, Type
12
+
13
+ from .common import LayerNorm2d, MLPBlock
14
+
15
+
16
+ # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
17
+ class ImageEncoderViT(nn.Module):
18
+ def __init__(
19
+ self,
20
+ img_size: int = 1024,
21
+ patch_size: int = 16,
22
+ in_chans: int = 3,
23
+ embed_dim: int = 768,
24
+ depth: int = 12,
25
+ num_heads: int = 12,
26
+ mlp_ratio: float = 4.0,
27
+ out_chans: int = 256,
28
+ qkv_bias: bool = True,
29
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
30
+ act_layer: Type[nn.Module] = nn.GELU,
31
+ use_abs_pos: bool = True,
32
+ use_rel_pos: bool = False,
33
+ rel_pos_zero_init: bool = True,
34
+ window_size: int = 0,
35
+ global_attn_indexes: Tuple[int, ...] = (),
36
+ ) -> None:
37
+ """
38
+ Args:
39
+ img_size (int): Input image size.
40
+ patch_size (int): Patch size.
41
+ in_chans (int): Number of input image channels.
42
+ embed_dim (int): Patch embedding dimension.
43
+ depth (int): Depth of ViT.
44
+ num_heads (int): Number of attention heads in each ViT block.
45
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
46
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
47
+ norm_layer (nn.Module): Normalization layer.
48
+ act_layer (nn.Module): Activation layer.
49
+ use_abs_pos (bool): If True, use absolute positional embeddings.
50
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
51
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
52
+ window_size (int): Window size for window attention blocks.
53
+ global_attn_indexes (list): Indexes for blocks using global attention.
54
+ """
55
+ super().__init__()
56
+ self.img_size = img_size
57
+
58
+ self.patch_embed = PatchEmbed(
59
+ kernel_size=(patch_size, patch_size),
60
+ stride=(patch_size, patch_size),
61
+ in_chans=in_chans,
62
+ embed_dim=embed_dim,
63
+ )
64
+
65
+ self.pos_embed: Optional[nn.Parameter] = None
66
+ if use_abs_pos:
67
+ # Initialize absolute positional embedding with pretrain image size.
68
+ self.pos_embed = nn.Parameter(
69
+ torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
70
+ )
71
+
72
+ self.blocks = nn.ModuleList()
73
+
74
+ for i in range(depth):
75
+ block = Block(
76
+ dim=embed_dim,
77
+ num_heads=num_heads,
78
+ mlp_ratio=mlp_ratio,
79
+ qkv_bias=qkv_bias,
80
+ norm_layer=norm_layer,
81
+ act_layer=act_layer,
82
+ use_rel_pos=use_rel_pos,
83
+ rel_pos_zero_init=rel_pos_zero_init,
84
+ window_size=window_size if i not in global_attn_indexes else 0,
85
+ input_size=(img_size // patch_size, img_size // patch_size),
86
+ )
87
+ self.blocks.append(block)
88
+
89
+ self.neck = nn.Sequential(
90
+ nn.Conv2d(
91
+ embed_dim,
92
+ out_chans,
93
+ kernel_size=1,
94
+ bias=False,
95
+ ),
96
+ LayerNorm2d(out_chans),
97
+ nn.Conv2d(
98
+ out_chans,
99
+ out_chans,
100
+ kernel_size=3,
101
+ padding=1,
102
+ bias=False,
103
+ ),
104
+ LayerNorm2d(out_chans),
105
+ )
106
+
107
+
108
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
109
+ x = self.patch_embed(x)
110
+ if self.pos_embed is not None:
111
+ x = x + self.pos_embed
112
+ interm_embeddings=[]
113
+ for blk in self.blocks:
114
+ x = blk(x)
115
+ if blk.window_size == 0:
116
+ interm_embeddings.append(x)
117
+
118
+ x = self.neck(x.permute(0, 3, 1, 2))
119
+
120
+ return x, interm_embeddings
121
+
122
+
123
+ class Block(nn.Module):
124
+ """Transformer blocks with support of window attention and residual propagation blocks"""
125
+
126
+ def __init__(
127
+ self,
128
+ dim: int,
129
+ num_heads: int,
130
+ mlp_ratio: float = 4.0,
131
+ qkv_bias: bool = True,
132
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
133
+ act_layer: Type[nn.Module] = nn.GELU,
134
+ use_rel_pos: bool = False,
135
+ rel_pos_zero_init: bool = True,
136
+ window_size: int = 0,
137
+ input_size: Optional[Tuple[int, int]] = None,
138
+ ) -> None:
139
+ """
140
+ Args:
141
+ dim (int): Number of input channels.
142
+ num_heads (int): Number of attention heads in each ViT block.
143
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
144
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
145
+ norm_layer (nn.Module): Normalization layer.
146
+ act_layer (nn.Module): Activation layer.
147
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
148
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
149
+ window_size (int): Window size for window attention blocks. If it equals 0, then
150
+ use global attention.
151
+ input_size (int or None): Input resolution for calculating the relative positional
152
+ parameter size.
153
+ """
154
+ super().__init__()
155
+ self.norm1 = norm_layer(dim)
156
+ self.attn = Attention(
157
+ dim,
158
+ num_heads=num_heads,
159
+ qkv_bias=qkv_bias,
160
+ use_rel_pos=use_rel_pos,
161
+ rel_pos_zero_init=rel_pos_zero_init,
162
+ input_size=input_size if window_size == 0 else (window_size, window_size),
163
+ )
164
+
165
+ self.norm2 = norm_layer(dim)
166
+ self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
167
+
168
+ self.window_size = window_size
169
+
170
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
171
+ shortcut = x
172
+ x = self.norm1(x)
173
+ # Window partition
174
+ if self.window_size > 0:
175
+ H, W = x.shape[1], x.shape[2]
176
+ x, pad_hw = window_partition(x, self.window_size)
177
+
178
+ x = self.attn(x)
179
+ # Reverse window partition
180
+ if self.window_size > 0:
181
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
182
+
183
+ x = shortcut + x
184
+ x = x + self.mlp(self.norm2(x))
185
+
186
+ return x
187
+
188
+ class Attention(nn.Module):
189
+ """Multi-head Attention block with relative position embeddings."""
190
+
191
+ def __init__(
192
+ self,
193
+ dim: int,
194
+ num_heads: int = 8,
195
+ qkv_bias: bool = True,
196
+ use_rel_pos: bool = False,
197
+ rel_pos_zero_init: bool = True,
198
+ input_size: Optional[Tuple[int, int]] = None,
199
+ ) -> None:
200
+ """
201
+ Args:
202
+ dim (int): Number of input channels.
203
+ num_heads (int): Number of attention heads.
204
+ qkv_bias (bool: If True, add a learnable bias to query, key, value.
205
+ rel_pos (bool): If True, add relative positional embeddings to the attention map.
206
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
207
+ input_size (int or None): Input resolution for calculating the relative positional
208
+ parameter size.
209
+ """
210
+ super().__init__()
211
+ self.num_heads = num_heads
212
+ head_dim = dim // num_heads
213
+ self.scale = head_dim**-0.5
214
+
215
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
216
+ self.proj = nn.Linear(dim, dim)
217
+
218
+ self.use_rel_pos = use_rel_pos
219
+ if self.use_rel_pos:
220
+ assert (
221
+ input_size is not None
222
+ ), "Input size must be provided if using relative positional encoding."
223
+ # initialize relative positional embeddings
224
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
225
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
226
+
227
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
228
+ B, H, W, _ = x.shape
229
+ # qkv with shape (3, B, nHead, H * W, C)
230
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
231
+ # q, k, v with shape (B * nHead, H * W, C)
232
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
233
+
234
+ attn = (q * self.scale) @ k.transpose(-2, -1)
235
+
236
+ if self.use_rel_pos:
237
+ attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
238
+
239
+ attn = attn.softmax(dim=-1)
240
+ x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
241
+ x = self.proj(x)
242
+ return x
243
+
244
+
245
+
246
+ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
247
+ """
248
+ Partition into non-overlapping windows with padding if needed.
249
+ Args:
250
+ x (tensor): input tokens with [B, H, W, C].
251
+ window_size (int): window size.
252
+
253
+ Returns:
254
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
255
+ (Hp, Wp): padded height and width before partition
256
+ """
257
+ B, H, W, C = x.shape
258
+
259
+ pad_h = (window_size - H % window_size) % window_size
260
+ pad_w = (window_size - W % window_size) % window_size
261
+ if pad_h > 0 or pad_w > 0:
262
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
263
+ Hp, Wp = H + pad_h, W + pad_w
264
+
265
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
266
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
267
+ return windows, (Hp, Wp)
268
+
269
+
270
+ def window_unpartition(
271
+ windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
272
+ ) -> torch.Tensor:
273
+ """
274
+ Window unpartition into original sequences and remove padding.
275
+ Args:
276
+ x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
277
+ window_size (int): window size.
278
+ pad_hw (Tuple): padded height and width (Hp, Wp).
279
+ hw (Tuple): original height and width (H, W) before padding.
280
+
281
+ Returns:
282
+ x: unpartitioned sequences with [B, H, W, C].
283
+ """
284
+ Hp, Wp = pad_hw
285
+ H, W = hw
286
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
287
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
288
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
289
+
290
+ if Hp > H or Wp > W:
291
+ x = x[:, :H, :W, :].contiguous()
292
+ return x
293
+
294
+
295
+ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
296
+ """
297
+ Get relative positional embeddings according to the relative positions of
298
+ query and key sizes.
299
+ Args:
300
+ q_size (int): size of query q.
301
+ k_size (int): size of key k.
302
+ rel_pos (Tensor): relative position embeddings (L, C).
303
+
304
+ Returns:
305
+ Extracted positional embeddings according to relative positions.
306
+ """
307
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
308
+ # Interpolate rel pos if needed.
309
+ if rel_pos.shape[0] != max_rel_dist:
310
+ # Interpolate rel pos.
311
+ rel_pos_resized = F.interpolate(
312
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
313
+ size=max_rel_dist,
314
+ mode="linear",
315
+ )
316
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
317
+ else:
318
+ rel_pos_resized = rel_pos
319
+
320
+ # Scale the coords with short length if shapes for q and k are different.
321
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
322
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
323
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
324
+
325
+ return rel_pos_resized[relative_coords.long()]
326
+
327
+
328
+ def add_decomposed_rel_pos(
329
+ attn: torch.Tensor,
330
+ q: torch.Tensor,
331
+ rel_pos_h: torch.Tensor,
332
+ rel_pos_w: torch.Tensor,
333
+ q_size: Tuple[int, int],
334
+ k_size: Tuple[int, int],
335
+ ) -> torch.Tensor:
336
+ """
337
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
338
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
339
+ Args:
340
+ attn (Tensor): attention map.
341
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
342
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
343
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
344
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
345
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
346
+
347
+ Returns:
348
+ attn (Tensor): attention map with added relative positional embeddings.
349
+ """
350
+ q_h, q_w = q_size
351
+ k_h, k_w = k_size
352
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
353
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
354
+
355
+ B, _, dim = q.shape
356
+ r_q = q.reshape(B, q_h, q_w, dim)
357
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
358
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
359
+
360
+ attn = (
361
+ attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
362
+ ).view(B, q_h * q_w, k_h * k_w)
363
+
364
+ return attn
365
+
366
+
367
+ class PatchEmbed(nn.Module):
368
+ """
369
+ Image to Patch Embedding.
370
+ """
371
+
372
+ def __init__(
373
+ self,
374
+ kernel_size: Tuple[int, int] = (16, 16),
375
+ stride: Tuple[int, int] = (16, 16),
376
+ padding: Tuple[int, int] = (0, 0),
377
+ in_chans: int = 3,
378
+ embed_dim: int = 768,
379
+ ) -> None:
380
+ """
381
+ Args:
382
+ kernel_size (Tuple): kernel size of the projection layer.
383
+ stride (Tuple): stride of the projection layer.
384
+ padding (Tuple): padding size of the projection layer.
385
+ in_chans (int): Number of input image channels.
386
+ embed_dim (int): Patch embedding dimension.
387
+ """
388
+ super().__init__()
389
+
390
+ self.proj = nn.Conv2d(
391
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
392
+ )
393
+
394
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
395
+ x = self.proj(x)
396
+ # B C H W -> B H W C
397
+ x = x.permute(0, 2, 3, 1)
398
+ return x
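
`window_partition` pads the (B, H, W, C) feature map so H and W become multiples of the window size, and `window_unpartition` crops that padding back off, so a round trip recovers the input exactly. A small check with shapes chosen only for illustration:

```python
# Round-trip check for the windowing helpers: partition pads to a multiple of
# the window size, unpartition crops the padding off again.
import torch
from segment_anything.modeling.image_encoder import window_partition, window_unpartition

x = torch.randn(1, 30, 30, 64)                        # (B, H, W, C), not divisible by 14
windows, pad_hw = window_partition(x, window_size=14)
print(windows.shape, pad_hw)                          # torch.Size([9, 14, 14, 64]) (42, 42)

x_back = window_unpartition(windows, 14, pad_hw, (30, 30))
print(torch.equal(x, x_back))                         # expected: True
```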
segment_anything/modeling/mask_decoder.py ADDED
@@ -0,0 +1,176 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+ from typing import List, Tuple, Type
12
+
13
+ from .common import LayerNorm2d
14
+
15
+
16
+ class MaskDecoder(nn.Module):
17
+ def __init__(
18
+ self,
19
+ *,
20
+ transformer_dim: int,
21
+ transformer: nn.Module,
22
+ num_multimask_outputs: int = 3,
23
+ activation: Type[nn.Module] = nn.GELU,
24
+ iou_head_depth: int = 3,
25
+ iou_head_hidden_dim: int = 256,
26
+ ) -> None:
27
+ """
28
+ Predicts masks given an image and prompt embeddings, using a
29
+ transformer architecture.
30
+
31
+ Arguments:
32
+ transformer_dim (int): the channel dimension of the transformer
33
+ transformer (nn.Module): the transformer used to predict masks
34
+ num_multimask_outputs (int): the number of masks to predict
35
+ when disambiguating masks
36
+ activation (nn.Module): the type of activation to use when
37
+ upscaling masks
38
+ iou_head_depth (int): the depth of the MLP used to predict
39
+ mask quality
40
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
41
+ used to predict mask quality
42
+ """
43
+ super().__init__()
44
+ self.transformer_dim = transformer_dim
45
+ self.transformer = transformer
46
+
47
+ self.num_multimask_outputs = num_multimask_outputs
48
+
49
+ self.iou_token = nn.Embedding(1, transformer_dim)
50
+ self.num_mask_tokens = num_multimask_outputs + 1
51
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
52
+
53
+ self.output_upscaling = nn.Sequential(
54
+ nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
55
+ LayerNorm2d(transformer_dim // 4),
56
+ activation(),
57
+ nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
58
+ activation(),
59
+ )
60
+ self.output_hypernetworks_mlps = nn.ModuleList(
61
+ [
62
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
63
+ for i in range(self.num_mask_tokens)
64
+ ]
65
+ )
66
+
67
+ self.iou_prediction_head = MLP(
68
+ transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
69
+ )
70
+
71
+ def forward(
72
+ self,
73
+ image_embeddings: torch.Tensor,
74
+ image_pe: torch.Tensor,
75
+ sparse_prompt_embeddings: torch.Tensor,
76
+ dense_prompt_embeddings: torch.Tensor,
77
+ multimask_output: bool,
78
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
79
+ """
80
+ Predict masks given image and prompt embeddings.
81
+
82
+ Arguments:
83
+ image_embeddings (torch.Tensor): the embeddings from the image encoder
84
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
85
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
86
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
87
+ multimask_output (bool): Whether to return multiple masks or a single
88
+ mask.
89
+
90
+ Returns:
91
+ torch.Tensor: batched predicted masks
92
+ torch.Tensor: batched predictions of mask quality
93
+ """
94
+ masks, iou_pred = self.predict_masks(
95
+ image_embeddings=image_embeddings,
96
+ image_pe=image_pe,
97
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
98
+ dense_prompt_embeddings=dense_prompt_embeddings,
99
+ )
100
+
101
+ # Select the correct mask or masks for output
102
+ if multimask_output:
103
+ mask_slice = slice(1, None)
104
+ else:
105
+ mask_slice = slice(0, 1)
106
+ masks = masks[:, mask_slice, :, :]
107
+ iou_pred = iou_pred[:, mask_slice]
108
+
109
+ # Prepare output
110
+ return masks, iou_pred
111
+
112
+ def predict_masks(
113
+ self,
114
+ image_embeddings: torch.Tensor,
115
+ image_pe: torch.Tensor,
116
+ sparse_prompt_embeddings: torch.Tensor,
117
+ dense_prompt_embeddings: torch.Tensor,
118
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
119
+ """Predicts masks. See 'forward' for more details."""
120
+ # Concatenate output tokens
121
+ output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
122
+ output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
123
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
124
+
125
+ # Expand per-image data in batch direction to be per-mask
126
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
127
+ src = src + dense_prompt_embeddings
128
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
129
+ b, c, h, w = src.shape
130
+
131
+ # Run the transformer
132
+ hs, src = self.transformer(src, pos_src, tokens)
133
+ iou_token_out = hs[:, 0, :]
134
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
135
+
136
+ # Upscale mask embeddings and predict masks using the mask tokens
137
+ src = src.transpose(1, 2).view(b, c, h, w)
138
+ upscaled_embedding = self.output_upscaling(src)
139
+ hyper_in_list: List[torch.Tensor] = []
140
+ for i in range(self.num_mask_tokens):
141
+ hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
142
+ hyper_in = torch.stack(hyper_in_list, dim=1)
143
+ b, c, h, w = upscaled_embedding.shape
144
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
145
+
146
+ # Generate mask quality predictions
147
+ iou_pred = self.iou_prediction_head(iou_token_out)
148
+
149
+ return masks, iou_pred
150
+
151
+
152
+ # Lightly adapted from
153
+ # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
154
+ class MLP(nn.Module):
155
+ def __init__(
156
+ self,
157
+ input_dim: int,
158
+ hidden_dim: int,
159
+ output_dim: int,
160
+ num_layers: int,
161
+ sigmoid_output: bool = False,
162
+ ) -> None:
163
+ super().__init__()
164
+ self.num_layers = num_layers
165
+ h = [hidden_dim] * (num_layers - 1)
166
+ self.layers = nn.ModuleList(
167
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
168
+ )
169
+ self.sigmoid_output = sigmoid_output
170
+
171
+ def forward(self, x):
172
+ for i, layer in enumerate(self.layers):
173
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
174
+ if self.sigmoid_output:
175
+ x = F.sigmoid(x)
176
+ return x
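
In `predict_masks`, each mask token is mapped by its hypernetwork MLP to a (transformer_dim // 8)-dimensional weight vector, and the mask logits are a batched matrix product of those vectors with the upscaled per-pixel embedding. A shape sketch with assumed sizes:

```python
# Shape sketch of the hypernetwork step in MaskDecoder.predict_masks: dot each
# token's C/8-dim weight vector with the upscaled per-pixel embedding.
import torch

b, num_tokens, c8, h, w = 2, 4, 32, 256, 256          # transformer_dim=256 -> C/8 = 32
hyper_in = torch.randn(b, num_tokens, c8)             # output of output_hypernetworks_mlps
upscaled = torch.randn(b, c8, h, w)                   # output of output_upscaling

masks = (hyper_in @ upscaled.view(b, c8, h * w)).view(b, -1, h, w)
print(masks.shape)                                    # torch.Size([2, 4, 256, 256])
```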
segment_anything/modeling/prompt_encoder.py ADDED
@@ -0,0 +1,214 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch import nn
10
+
11
+ from typing import Any, Optional, Tuple, Type
12
+
13
+ from .common import LayerNorm2d
14
+
15
+
16
+ class PromptEncoder(nn.Module):
17
+ def __init__(
18
+ self,
19
+ embed_dim: int,
20
+ image_embedding_size: Tuple[int, int],
21
+ input_image_size: Tuple[int, int],
22
+ mask_in_chans: int,
23
+ activation: Type[nn.Module] = nn.GELU,
24
+ ) -> None:
25
+ """
26
+ Encodes prompts for input to SAM's mask decoder.
27
+
28
+ Arguments:
29
+ embed_dim (int): The prompts' embedding dimension
30
+ image_embedding_size (tuple(int, int)): The spatial size of the
31
+ image embedding, as (H, W).
32
+ input_image_size (int): The padded size of the image as input
33
+ to the image encoder, as (H, W).
34
+ mask_in_chans (int): The number of hidden channels used for
35
+ encoding input masks.
36
+ activation (nn.Module): The activation to use when encoding
37
+ input masks.
38
+ """
39
+ super().__init__()
40
+ self.embed_dim = embed_dim
41
+ self.input_image_size = input_image_size
42
+ self.image_embedding_size = image_embedding_size
43
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
44
+
45
+ self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
46
+ point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
47
+ self.point_embeddings = nn.ModuleList(point_embeddings)
48
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
49
+
50
+ self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
51
+ self.mask_downscaling = nn.Sequential(
52
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
53
+ LayerNorm2d(mask_in_chans // 4),
54
+ activation(),
55
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
56
+ LayerNorm2d(mask_in_chans),
57
+ activation(),
58
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
59
+ )
60
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
61
+
62
+ def get_dense_pe(self) -> torch.Tensor:
63
+ """
64
+ Returns the positional encoding used to encode point prompts,
65
+ applied to a dense set of points the shape of the image encoding.
66
+
67
+ Returns:
68
+ torch.Tensor: Positional encoding with shape
69
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
70
+ """
71
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
72
+
73
+ def _embed_points(
74
+ self,
75
+ points: torch.Tensor,
76
+ labels: torch.Tensor,
77
+ pad: bool,
78
+ ) -> torch.Tensor:
79
+ """Embeds point prompts."""
80
+ points = points + 0.5 # Shift to center of pixel
81
+ if pad:
82
+ padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
83
+ padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
84
+ points = torch.cat([points, padding_point], dim=1)
85
+ labels = torch.cat([labels, padding_label], dim=1)
86
+ point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
87
+ point_embedding[labels == -1] = 0.0
88
+ point_embedding[labels == -1] += self.not_a_point_embed.weight
89
+ point_embedding[labels == 0] += self.point_embeddings[0].weight
90
+ point_embedding[labels == 1] += self.point_embeddings[1].weight
91
+ return point_embedding
92
+
93
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
94
+ """Embeds box prompts."""
95
+ boxes = boxes + 0.5 # Shift to center of pixel
96
+ coords = boxes.reshape(-1, 2, 2)
97
+ corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
98
+ corner_embedding[:, 0, :] += self.point_embeddings[2].weight
99
+ corner_embedding[:, 1, :] += self.point_embeddings[3].weight
100
+ return corner_embedding
101
+
102
+ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
103
+ """Embeds mask inputs."""
104
+ mask_embedding = self.mask_downscaling(masks)
105
+ return mask_embedding
106
+
107
+ def _get_batch_size(
108
+ self,
109
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
110
+ boxes: Optional[torch.Tensor],
111
+ masks: Optional[torch.Tensor],
112
+ ) -> int:
113
+ """
114
+ Gets the batch size of the output given the batch size of the input prompts.
115
+ """
116
+ if points is not None:
117
+ return points[0].shape[0]
118
+ elif boxes is not None:
119
+ return boxes.shape[0]
120
+ elif masks is not None:
121
+ return masks.shape[0]
122
+ else:
123
+ return 1
124
+
125
+ def _get_device(self) -> torch.device:
126
+ return self.point_embeddings[0].weight.device
127
+
128
+ def forward(
129
+ self,
130
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
131
+ boxes: Optional[torch.Tensor],
132
+ masks: Optional[torch.Tensor],
133
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
134
+ """
135
+ Embeds different types of prompts, returning both sparse and dense
136
+ embeddings.
137
+
138
+ Arguments:
139
+ points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
140
+ and labels to embed.
141
+ boxes (torch.Tensor or none): boxes to embed
142
+ masks (torch.Tensor or none): masks to embed
143
+
144
+ Returns:
145
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
146
+ BxNx(embed_dim), where N is determined by the number of input points
147
+ and boxes.
148
+ torch.Tensor: dense embeddings for the masks, in the shape
149
+ Bx(embed_dim)x(embed_H)x(embed_W)
150
+ """
151
+ bs = self._get_batch_size(points, boxes, masks)
152
+ sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
153
+ if points is not None:
154
+ coords, labels = points
155
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
156
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
157
+ if boxes is not None:
158
+ box_embeddings = self._embed_boxes(boxes)
159
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
160
+
161
+ if masks is not None:
162
+ dense_embeddings = self._embed_masks(masks)
163
+ else:
164
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
165
+ bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
166
+ )
167
+
168
+ return sparse_embeddings, dense_embeddings
169
+
170
+
171
+ class PositionEmbeddingRandom(nn.Module):
172
+ """
173
+ Positional encoding using random spatial frequencies.
174
+ """
175
+
176
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
177
+ super().__init__()
178
+ if scale is None or scale <= 0.0:
179
+ scale = 1.0
180
+ self.register_buffer(
181
+ "positional_encoding_gaussian_matrix",
182
+ scale * torch.randn((2, num_pos_feats)),
183
+ )
184
+
185
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
186
+ """Positionally encode points that are normalized to [0,1]."""
187
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
188
+ coords = 2 * coords - 1
189
+ coords = coords @ self.positional_encoding_gaussian_matrix
190
+ coords = 2 * np.pi * coords
191
+ # outputs d_1 x ... x d_n x C shape
192
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
193
+
194
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
195
+ """Generate positional encoding for a grid of the specified size."""
196
+ h, w = size
197
+ device: Any = self.positional_encoding_gaussian_matrix.device
198
+ grid = torch.ones((h, w), device=device, dtype=torch.float32)
199
+ y_embed = grid.cumsum(dim=0) - 0.5
200
+ x_embed = grid.cumsum(dim=1) - 0.5
201
+ y_embed = y_embed / h
202
+ x_embed = x_embed / w
203
+
204
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
205
+ return pe.permute(2, 0, 1) # C x H x W
206
+
207
+ def forward_with_coords(
208
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
209
+ ) -> torch.Tensor:
210
+ """Positionally encode points that are not normalized to [0,1]."""
211
+ coords = coords_input.clone()
212
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
213
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
214
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
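
A minimal sketch (values are illustrative) of encoding a single foreground click with the sizes used in `build_sam.py`. With points but no box, a padding point is appended, so the sparse embedding holds two tokens.

```python
# Encode one positive point prompt; sizes mirror the defaults in build_sam.py.
import torch
from segment_anything.modeling import PromptEncoder

prompt_encoder = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)

coords = torch.tensor([[[512.0, 384.0]]])             # (B=1, N=1, 2) in input-image pixels
labels = torch.tensor([[1]])                          # 1 = foreground, 0 = background
sparse, dense = prompt_encoder(points=(coords, labels), boxes=None, masks=None)
print(sparse.shape, dense.shape)                      # (1, 2, 256) and (1, 256, 64, 64)
```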
segment_anything/modeling/sam.py ADDED
@@ -0,0 +1,182 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+ from typing import Any, Dict, List, Tuple
12
+
13
+ from .image_encoder import ImageEncoderViT
14
+ from .mask_decoder import MaskDecoder
15
+ from .prompt_encoder import PromptEncoder
16
+
17
+
18
+ class Sam(nn.Module):
19
+ mask_threshold: float = 0.0
20
+ image_format: str = "RGB"
21
+
22
+ def __init__(
23
+ self,
24
+ image_encoder: ImageEncoderViT,
25
+ prompt_encoder: PromptEncoder,
26
+ mask_decoder: MaskDecoder,
27
+ pixel_mean: List[float] = [123.675, 116.28, 103.53],
28
+ pixel_std: List[float] = [58.395, 57.12, 57.375],
29
+ ) -> None:
30
+ """
31
+ SAM predicts object masks from an image and input prompts.
32
+
33
+ Arguments:
34
+ image_encoder (ImageEncoderViT): The backbone used to encode the
35
+ image into image embeddings that allow for efficient mask prediction.
36
+ prompt_encoder (PromptEncoder): Encodes various types of input prompts.
37
+ mask_decoder (MaskDecoder): Predicts masks from the image embeddings
38
+ and encoded prompts.
39
+ pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
40
+ pixel_std (list(float)): Std values for normalizing pixels in the input image.
41
+ """
42
+ super().__init__()
43
+ self.image_encoder = image_encoder
44
+ self.prompt_encoder = prompt_encoder
45
+ self.mask_decoder = mask_decoder
46
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
47
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
48
+
49
+ @property
50
+ def device(self) -> Any:
51
+ return self.pixel_mean.device
52
+
53
+ @torch.no_grad()
54
+ def forward(
55
+ self,
56
+ batched_input: List[Dict[str, Any]],
57
+ multimask_output: bool,
58
+ ) -> List[Dict[str, torch.Tensor]]:
59
+ """
60
+ Predicts masks end-to-end from provided images and prompts.
61
+ If prompts are not known in advance, using SamPredictor is
62
+ recommended over calling the model directly.
63
+
64
+ Arguments:
65
+ batched_input (list(dict)): A list over input images, each a
66
+ dictionary with the following keys. A prompt key can be
67
+ excluded if it is not present.
68
+ 'image': The image as a torch tensor in 3xHxW format,
69
+ already transformed for input to the model.
70
+ 'original_size': (tuple(int, int)) The original size of
71
+ the image before transformation, as (H, W).
72
+ 'point_coords': (torch.Tensor) Batched point prompts for
73
+ this image, with shape BxNx2. Already transformed to the
74
+ input frame of the model.
75
+ 'point_labels': (torch.Tensor) Batched labels for point prompts,
76
+ with shape BxN.
77
+ 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
78
+ Already transformed to the input frame of the model.
79
+ 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
80
+ in the form Bx1xHxW.
81
+ multimask_output (bool): Whether the model should predict multiple
82
+ disambiguating masks, or return a single mask.
83
+
84
+ Returns:
85
+ (list(dict)): A list over input images, where each element is
86
+ a dictionary with the following keys.
87
+ 'masks': (torch.Tensor) Batched binary mask predictions,
88
+ with shape BxCxHxW, where B is the number of input prompts,
89
+ C is determined by multimask_output, and (H, W) is the
90
+ original size of the image.
91
+ 'iou_predictions': (torch.Tensor) The model's predictions
92
+ of mask quality, in shape BxC.
93
+ 'low_res_logits': (torch.Tensor) Low resolution logits with
94
+ shape BxCxHxW, where H=W=256. Can be passed as mask input
95
+ to subsequent iterations of prediction.
96
+ """
97
+ input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
98
+
99
+ image_embeddings, interm_embeddings = self.image_encoder(input_images)
100
+
101
+ outputs = []
102
+ for image_record, curr_embedding in zip(batched_input, image_embeddings):
103
+ if "point_coords" in image_record:
104
+ points = (image_record["point_coords"], image_record["point_labels"])
105
+ else:
106
+ points = None
107
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
108
+ points=points,
109
+ boxes=image_record.get("boxes", None),
110
+ masks=image_record.get("mask_inputs", None),
111
+ )
112
+ low_res_masks, iou_predictions = self.mask_decoder(
113
+ image_embeddings=curr_embedding.unsqueeze(0),
114
+ image_pe=self.prompt_encoder.get_dense_pe(),
115
+ sparse_prompt_embeddings=sparse_embeddings,
116
+ dense_prompt_embeddings=dense_embeddings,
117
+ multimask_output=multimask_output
118
+ )
119
+
120
+ masks = self.postprocess_masks(
121
+ low_res_masks,
122
+ input_size=image_record["image"].shape[-2:],
123
+ original_size=image_record["original_size"],
124
+ )
125
+ masks = masks > self.mask_threshold
126
+
127
+ outputs.append(
128
+ {
129
+ "masks": masks,
130
+ "iou_predictions": iou_predictions,
131
+ "low_res_logits": low_res_masks,
132
+ "encoder_embedding": curr_embedding.unsqueeze(0),
133
+ "image_pe": self.prompt_encoder.get_dense_pe(),
134
+ "sparse_embeddings":sparse_embeddings,
135
+ "dense_embeddings":dense_embeddings,
136
+ }
137
+ )
138
+
139
+ return outputs, interm_embeddings
140
+
141
+ def postprocess_masks(
142
+ self,
143
+ masks: torch.Tensor,
144
+ input_size: Tuple[int, ...],
145
+ original_size: Tuple[int, ...],
146
+ ) -> torch.Tensor:
147
+ """
148
+ Remove padding and upscale masks to the original image size.
149
+
150
+ Arguments:
151
+ masks (torch.Tensor): Batched masks from the mask_decoder,
152
+ in BxCxHxW format.
153
+ input_size (tuple(int, int)): The size of the image input to the
154
+ model, in (H, W) format. Used to remove padding.
155
+ original_size (tuple(int, int)): The original size of the image
156
+ before resizing for input to the model, in (H, W) format.
157
+
158
+ Returns:
159
+ (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
160
+ is given by original_size.
161
+ """
162
+ masks = F.interpolate(
163
+ masks,
164
+ (self.image_encoder.img_size, self.image_encoder.img_size),
165
+ mode="bilinear",
166
+ align_corners=False,
167
+ )
168
+ masks = masks[..., : input_size[0], : input_size[1]]
169
+ masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
170
+ return masks
171
+
172
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
173
+ """Normalize pixel values and pad to a square input."""
174
+ # Normalize colors
175
+ x = (x - self.pixel_mean) / self.pixel_std
176
+
177
+ # Pad
178
+ h, w = x.shape[-2:]
179
+ padh = self.image_encoder.img_size - h
180
+ padw = self.image_encoder.img_size - w
181
+ x = F.pad(x, (0, padw, 0, padh))
182
+ return x
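
An end-to-end sketch of the `batched_input` format described in the docstring above (random weights and dummy values throughout; everything here is an assumption for illustration). Note that this variant of `Sam.forward` returns `(outputs, interm_embeddings)` rather than outputs alone, because the image encoder also emits early ViT features.

```python
# Assemble batched_input as documented in Sam.forward and run a dummy image
# through an uninitialized ViT-B model.
import torch
from segment_anything import sam_model_registry

sam = sam_model_registry["vit_b"](checkpoint=None)

image = torch.randint(0, 256, (3, 1024, 1024), dtype=torch.float32)  # already in the 1024x1024 input frame
batched_input = [{
    "image": image,
    "original_size": (768, 1024),                     # (H, W) before resizing
    "point_coords": torch.tensor([[[512.0, 512.0]]]), # (B, N, 2) in the input frame
    "point_labels": torch.tensor([[1]]),
}]

outputs, interm_embeddings = sam(batched_input, multimask_output=False)
print(outputs[0]["masks"].shape)                      # torch.Size([1, 1, 768, 1024])
```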
segment_anything/modeling/transformer.py ADDED
@@ -0,0 +1,240 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from torch import Tensor, nn
9
+
10
+ import math
11
+ from typing import Tuple, Type
12
+
13
+ from .common import MLPBlock
14
+
15
+
16
+ class TwoWayTransformer(nn.Module):
17
+ def __init__(
18
+ self,
19
+ depth: int,
20
+ embedding_dim: int,
21
+ num_heads: int,
22
+ mlp_dim: int,
23
+ activation: Type[nn.Module] = nn.ReLU,
24
+ attention_downsample_rate: int = 2,
25
+ ) -> None:
26
+ """
27
+ A transformer decoder that attends to an input image using
28
+ queries whose positional embedding is supplied.
29
+
30
+ Args:
31
+ depth (int): number of layers in the transformer
32
+ embedding_dim (int): the channel dimension for the input embeddings
33
+ num_heads (int): the number of heads for multihead attention. Must
34
+ divide embedding_dim
35
+ mlp_dim (int): the channel dimension internal to the MLP block
36
+ activation (nn.Module): the activation to use in the MLP block
37
+ """
38
+ super().__init__()
39
+ self.depth = depth
40
+ self.embedding_dim = embedding_dim
41
+ self.num_heads = num_heads
42
+ self.mlp_dim = mlp_dim
43
+ self.layers = nn.ModuleList()
44
+
45
+ for i in range(depth):
46
+ self.layers.append(
47
+ TwoWayAttentionBlock(
48
+ embedding_dim=embedding_dim,
49
+ num_heads=num_heads,
50
+ mlp_dim=mlp_dim,
51
+ activation=activation,
52
+ attention_downsample_rate=attention_downsample_rate,
53
+ skip_first_layer_pe=(i == 0),
54
+ )
55
+ )
56
+
57
+ self.final_attn_token_to_image = Attention(
58
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
59
+ )
60
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
61
+
62
+ def forward(
63
+ self,
64
+ image_embedding: Tensor,
65
+ image_pe: Tensor,
66
+ point_embedding: Tensor,
67
+ ) -> Tuple[Tensor, Tensor]:
68
+ """
69
+ Args:
70
+ image_embedding (torch.Tensor): image to attend to. Should be shape
71
+ B x embedding_dim x h x w for any h and w.
72
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
73
+ have the same shape as image_embedding.
74
+ point_embedding (torch.Tensor): the embedding to add to the query points.
75
+ Must have shape B x N_points x embedding_dim for any N_points.
76
+
77
+ Returns:
78
+ torch.Tensor: the processed point_embedding
79
+ torch.Tensor: the processed image_embedding
80
+ """
81
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
82
+ bs, c, h, w = image_embedding.shape
83
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
84
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
85
+
86
+ # Prepare queries
87
+ queries = point_embedding
88
+ keys = image_embedding
89
+
90
+ # Apply transformer blocks and final layernorm
91
+ for layer in self.layers:
92
+ queries, keys = layer(
93
+ queries=queries,
94
+ keys=keys,
95
+ query_pe=point_embedding,
96
+ key_pe=image_pe,
97
+ )
98
+
99
+ # Apply the final attention layer from the points to the image
100
+ q = queries + point_embedding
101
+ k = keys + image_pe
102
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
103
+ queries = queries + attn_out
104
+ queries = self.norm_final_attn(queries)
105
+
106
+ return queries, keys
107
+
108
+
109
+ class TwoWayAttentionBlock(nn.Module):
110
+ def __init__(
111
+ self,
112
+ embedding_dim: int,
113
+ num_heads: int,
114
+ mlp_dim: int = 2048,
115
+ activation: Type[nn.Module] = nn.ReLU,
116
+ attention_downsample_rate: int = 2,
117
+ skip_first_layer_pe: bool = False,
118
+ ) -> None:
119
+ """
120
+ A transformer block with four layers: (1) self-attention of sparse
121
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
122
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
123
+ inputs.
124
+
125
+ Arguments:
126
+ embedding_dim (int): the channel dimension of the embeddings
127
+ num_heads (int): the number of heads in the attention layers
128
+ mlp_dim (int): the hidden dimension of the mlp block
129
+ activation (nn.Module): the activation of the mlp block
130
+ skip_first_layer_pe (bool): skip the PE on the first layer
131
+ """
132
+ super().__init__()
133
+ self.self_attn = Attention(embedding_dim, num_heads)
134
+ self.norm1 = nn.LayerNorm(embedding_dim)
135
+
136
+ self.cross_attn_token_to_image = Attention(
137
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
138
+ )
139
+ self.norm2 = nn.LayerNorm(embedding_dim)
140
+
141
+ self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
142
+ self.norm3 = nn.LayerNorm(embedding_dim)
143
+
144
+ self.norm4 = nn.LayerNorm(embedding_dim)
145
+ self.cross_attn_image_to_token = Attention(
146
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
147
+ )
148
+
149
+ self.skip_first_layer_pe = skip_first_layer_pe
150
+
151
+ def forward(
152
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
153
+ ) -> Tuple[Tensor, Tensor]:
154
+ # Self attention block
155
+ if self.skip_first_layer_pe:
156
+ queries = self.self_attn(q=queries, k=queries, v=queries)
157
+ else:
158
+ q = queries + query_pe
159
+ attn_out = self.self_attn(q=q, k=q, v=queries)
160
+ queries = queries + attn_out
161
+ queries = self.norm1(queries)
162
+
163
+ # Cross attention block, tokens attending to image embedding
164
+ q = queries + query_pe
165
+ k = keys + key_pe
166
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
167
+ queries = queries + attn_out
168
+ queries = self.norm2(queries)
169
+
170
+ # MLP block
171
+ mlp_out = self.mlp(queries)
172
+ queries = queries + mlp_out
173
+ queries = self.norm3(queries)
174
+
175
+ # Cross attention block, image embedding attending to tokens
176
+ q = queries + query_pe
177
+ k = keys + key_pe
178
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
179
+ keys = keys + attn_out
180
+ keys = self.norm4(keys)
181
+
182
+ return queries, keys
183
+
184
+
185
+ class Attention(nn.Module):
186
+ """
187
+ An attention layer that allows for downscaling the size of the embedding
188
+ after projection to queries, keys, and values.
189
+ """
190
+
191
+ def __init__(
192
+ self,
193
+ embedding_dim: int,
194
+ num_heads: int,
195
+ downsample_rate: int = 1,
196
+ ) -> None:
197
+ super().__init__()
198
+ self.embedding_dim = embedding_dim
199
+ self.internal_dim = embedding_dim // downsample_rate
200
+ self.num_heads = num_heads
201
+ assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
202
+
203
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
204
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
205
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
206
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
207
+
208
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
209
+ b, n, c = x.shape
210
+ x = x.reshape(b, n, num_heads, c // num_heads)
211
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
212
+
213
+ def _recombine_heads(self, x: Tensor) -> Tensor:
214
+ b, n_heads, n_tokens, c_per_head = x.shape
215
+ x = x.transpose(1, 2)
216
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
217
+
218
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
219
+ # Input projections
220
+ q = self.q_proj(q)
221
+ k = self.k_proj(k)
222
+ v = self.v_proj(v)
223
+
224
+ # Separate into heads
225
+ q = self._separate_heads(q, self.num_heads)
226
+ k = self._separate_heads(k, self.num_heads)
227
+ v = self._separate_heads(v, self.num_heads)
228
+
229
+ # Attention
230
+ _, _, _, c_per_head = q.shape
231
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
232
+ attn = attn / math.sqrt(c_per_head)
233
+ attn = torch.softmax(attn, dim=-1)
234
+
235
+ # Get output
236
+ out = attn @ v
237
+ out = self._recombine_heads(out)
238
+ out = self.out_proj(out)
239
+
240
+ return out
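
A minimal usage sketch for the `TwoWayTransformer` added above (illustrative only, not part of the commit; it assumes the repo root is on `PYTHONPATH` so `segment_anything` is importable):

```python
# Illustrative sketch: run the TwoWayTransformer on dummy data.
import torch
from segment_anything.modeling.transformer import TwoWayTransformer

transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)

image_embedding = torch.randn(1, 256, 64, 64)   # B x C x H x W
image_pe = torch.randn(1, 256, 64, 64)          # same shape as image_embedding
point_embedding = torch.randn(1, 5, 256)        # B x N_points x C

queries, keys = transformer(image_embedding, image_pe, point_embedding)
print(queries.shape)  # torch.Size([1, 5, 256])   -- processed point embeddings
print(keys.shape)     # torch.Size([1, 4096, 256]) -- processed image tokens (64*64)
```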
segment_anything/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
segment_anything/utils/transforms.py ADDED
@@ -0,0 +1,102 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch.nn import functional as F
10
+ from torchvision.transforms.functional import resize, to_pil_image # type: ignore
11
+
12
+ from copy import deepcopy
13
+ from typing import Tuple
14
+
15
+
16
+ class ResizeLongestSide:
17
+ """
18
+ Resizes images to longest side 'target_length', as well as provides
19
+ methods for resizing coordinates and boxes. Provides methods for
20
+ transforming both numpy array and batched torch tensors.
21
+ """
22
+
23
+ def __init__(self, target_length: int) -> None:
24
+ self.target_length = target_length
25
+
26
+ def apply_image(self, image: np.ndarray) -> np.ndarray:
27
+ """
28
+ Expects a numpy array with shape HxWxC in uint8 format.
29
+ """
30
+ target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
31
+ return np.array(resize(to_pil_image(image), target_size))
32
+
33
+ def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
34
+ """
35
+ Expects a numpy array of length 2 in the final dimension. Requires the
36
+ original image size in (H, W) format.
37
+ """
38
+ old_h, old_w = original_size
39
+ new_h, new_w = self.get_preprocess_shape(
40
+ original_size[0], original_size[1], self.target_length
41
+ )
42
+ coords = deepcopy(coords).astype(float)
43
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
44
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
45
+ return coords
46
+
47
+ def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
48
+ """
49
+ Expects a numpy array shape Bx4. Requires the original image size
50
+ in (H, W) format.
51
+ """
52
+ boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
53
+ return boxes.reshape(-1, 4)
54
+
55
+ def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
56
+ """
57
+ Expects batched images with shape BxCxHxW and float format. This
58
+ transformation may not exactly match apply_image. apply_image is
59
+ the transformation expected by the model.
60
+ """
61
+ # Expects an image in BCHW format. May not exactly match apply_image.
62
+ target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)  # spatial dims (H, W) of the BxCxHxW input
63
+ return F.interpolate(
64
+ image, target_size, mode="bilinear", align_corners=False, antialias=True
65
+ )
66
+
67
+ def apply_coords_torch(
68
+ self, coords: torch.Tensor, original_size: Tuple[int, ...]
69
+ ) -> torch.Tensor:
70
+ """
71
+ Expects a torch tensor with length 2 in the last dimension. Requires the
72
+ original image size in (H, W) format.
73
+ """
74
+ old_h, old_w = original_size
75
+ new_h, new_w = self.get_preprocess_shape(
76
+ original_size[0], original_size[1], self.target_length
77
+ )
78
+ coords = deepcopy(coords).to(torch.float)
79
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
80
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
81
+ return coords
82
+
83
+ def apply_boxes_torch(
84
+ self, boxes: torch.Tensor, original_size: Tuple[int, ...]
85
+ ) -> torch.Tensor:
86
+ """
87
+ Expects a torch tensor with shape Bx4. Requires the original image
88
+ size in (H, W) format.
89
+ """
90
+ boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
91
+ return boxes.reshape(-1, 4)
92
+
93
+ @staticmethod
94
+ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
95
+ """
96
+ Compute the output size given input size and target long side length.
97
+ """
98
+ scale = long_side_length * 1.0 / max(oldh, oldw)
99
+ newh, neww = oldh * scale, oldw * scale
100
+ neww = int(neww + 0.5)
101
+ newh = int(newh + 0.5)
102
+ return (newh, neww)
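
A short usage sketch for `ResizeLongestSide` (illustrative only, not part of the commit; it assumes `segment_anything` is importable from the repo root):

```python
# Illustrative sketch: resize an image and its point prompts with ResizeLongestSide.
import numpy as np
from segment_anything.utils.transforms import ResizeLongestSide

transform = ResizeLongestSide(target_length=1024)

image = np.zeros((480, 640, 3), dtype=np.uint8)          # H x W x C, uint8
resized = transform.apply_image(image)                    # longest side becomes 1024
print(resized.shape)                                       # (768, 1024, 3)

points = np.array([[320.0, 240.0]])                        # (x, y) in the original image
print(transform.apply_coords(points, image.shape[:2]))     # rescaled to the resized image
```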
utils/__pycache__/box_ops.cpython-310.pyc ADDED
Binary file (3.82 kB)
utils/__pycache__/misc.cpython-310.pyc ADDED
Binary file (20.3 kB)
utils/__pycache__/transforms.cpython-310.pyc ADDED
Binary file (3.93 kB)
utils/box_ops.py ADDED
@@ -0,0 +1,140 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Utilities for bounding box manipulation and GIoU.
4
+ """
5
+ import torch
6
+ from torchvision.ops.boxes import box_area
7
+
8
+
9
+ def box_cxcywh_to_xyxy(x):
10
+ x_c, y_c, w, h = x.unbind(-1)
11
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
12
+ return torch.stack(b, dim=-1)
13
+
14
+
15
+ def box_xyxy_to_cxcywh(x):
16
+ x0, y0, x1, y1 = x.unbind(-1)
17
+ b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
18
+ return torch.stack(b, dim=-1)
19
+
20
+
21
+ # modified from torchvision to also return the union
22
+ def box_iou(boxes1, boxes2):
23
+ area1 = box_area(boxes1)
24
+ area2 = box_area(boxes2)
25
+
26
+ # import ipdb; ipdb.set_trace()
27
+ lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
28
+ rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
29
+
30
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
31
+ inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
32
+
33
+ union = area1[:, None] + area2 - inter
34
+
35
+ iou = inter / (union + 1e-6)
36
+ return iou, union
37
+
38
+
39
+ def generalized_box_iou(boxes1, boxes2):
40
+ """
41
+ Generalized IoU from https://giou.stanford.edu/
42
+
43
+ The boxes should be in [x0, y0, x1, y1] format
44
+
45
+ Returns a [N, M] pairwise matrix, where N = len(boxes1)
46
+ and M = len(boxes2)
47
+ """
48
+ # degenerate boxes gives inf / nan results
49
+ # so do an early check
50
+ assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
51
+ assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
52
+ # except:
53
+ # import ipdb; ipdb.set_trace()
54
+ iou, union = box_iou(boxes1, boxes2)
55
+
56
+ lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
57
+ rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
58
+
59
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
60
+ area = wh[:, :, 0] * wh[:, :, 1]
61
+
62
+ return iou - (area - union) / (area + 1e-6)
63
+
64
+
65
+ # modified from torchvision to also return the union
66
+ def box_iou_pairwise(boxes1, boxes2):
67
+ area1 = box_area(boxes1)
68
+ area2 = box_area(boxes2)
69
+
70
+ lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N,2]
71
+ rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N,2]
72
+
73
+ wh = (rb - lt).clamp(min=0) # [N,2]
74
+ inter = wh[:, 0] * wh[:, 1] # [N]
75
+
76
+ union = area1 + area2 - inter
77
+
78
+ iou = inter / union
79
+ return iou, union
80
+
81
+
82
+ def generalized_box_iou_pairwise(boxes1, boxes2):
83
+ """
84
+ Generalized IoU from https://giou.stanford.edu/
85
+
86
+ Input:
87
+ - boxes1, boxes2: N,4
88
+ Output:
89
+ - giou: N, 4
90
+ """
91
+ # degenerate boxes gives inf / nan results
92
+ # so do an early check
93
+ assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
94
+ assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
95
+ assert boxes1.shape == boxes2.shape
96
+ iou, union = box_iou_pairwise(boxes1, boxes2) # N, 4
97
+
98
+ lt = torch.min(boxes1[:, :2], boxes2[:, :2])
99
+ rb = torch.max(boxes1[:, 2:], boxes2[:, 2:])
100
+
101
+ wh = (rb - lt).clamp(min=0) # [N,2]
102
+ area = wh[:, 0] * wh[:, 1]
103
+
104
+ return iou - (area - union) / area
105
+
106
+
107
+ def masks_to_boxes(masks):
108
+ """Compute the bounding boxes around the provided masks
109
+
110
+ The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
111
+
112
+ Returns a [N, 4] tensors, with the boxes in xyxy format
113
+ """
114
+ if masks.numel() == 0:
115
+ return torch.zeros((0, 4), device=masks.device)
116
+
117
+ h, w = masks.shape[-2:]
118
+
119
+ y = torch.arange(0, h, dtype=torch.float)
120
+ x = torch.arange(0, w, dtype=torch.float)
121
+ y, x = torch.meshgrid(y, x)
122
+
123
+ x_mask = masks * x.unsqueeze(0)
124
+ x_max = x_mask.flatten(1).max(-1)[0]
125
+ x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
126
+
127
+ y_mask = masks * y.unsqueeze(0)
128
+ y_max = y_mask.flatten(1).max(-1)[0]
129
+ y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
130
+
131
+ return torch.stack([x_min, y_min, x_max, y_max], 1)
132
+
133
+
134
+ if __name__ == "__main__":
135
+ x = torch.rand(5, 4)
136
+ y = torch.rand(3, 4)
137
+ iou, union = box_iou(x, y)
138
+ import ipdb
139
+
140
+ ipdb.set_trace()
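
A small worked example for the box utilities above (illustrative only, not part of the commit; it assumes the repo root is on `PYTHONPATH`):

```python
# Illustrative sketch: convert cxcywh boxes to xyxy and compute pairwise GIoU.
import torch
from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou

# two predicted boxes given as (cx, cy, w, h)
pred = box_cxcywh_to_xyxy(torch.tensor([[0.5, 0.5, 0.4, 0.4],
                                        [0.2, 0.2, 0.2, 0.2]]))
# three ground-truth boxes already in (x0, y0, x1, y1)
gt = torch.tensor([[0.3, 0.3, 0.7, 0.7],
                   [0.0, 0.0, 0.3, 0.3],
                   [0.6, 0.6, 0.9, 0.9]])

giou = generalized_box_iou(pred, gt)  # shape [2, 3], values in (-1, 1]
print(giou)
```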
utils/datasets/__init__.py ADDED
File without changes
utils/datasets/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (163 Bytes)
utils/datasets/__pycache__/transforms.cpython-310.pyc ADDED
Binary file (10.1 kB)
utils/datasets/transforms.py ADDED
@@ -0,0 +1,311 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Transforms and data augmentation for both image + bbox.
4
+ """
5
+ import os
6
+ import random
7
+
8
+ import PIL
9
+ import torch
10
+ import torchvision.transforms as T
11
+ import torchvision.transforms.functional as F
12
+
13
+ from utils.box_ops import box_xyxy_to_cxcywh
14
+ from utils.misc import interpolate
15
+
16
+
17
+ def crop(image, target, region):
18
+ cropped_image = F.crop(image, *region)
19
+
20
+ target = target.copy()
21
+ i, j, h, w = region
22
+
23
+ # should we do something wrt the original size?
24
+ target["size"] = torch.tensor([h, w])
25
+
26
+ fields = ["labels", "area", "iscrowd", "positive_map"]
27
+
28
+ if "boxes" in target:
29
+ boxes = target["boxes"]
30
+ max_size = torch.as_tensor([w, h], dtype=torch.float32)
31
+ cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
32
+ cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
33
+ cropped_boxes = cropped_boxes.clamp(min=0)
34
+ area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
35
+ target["boxes"] = cropped_boxes.reshape(-1, 4)
36
+ target["area"] = area
37
+ fields.append("boxes")
38
+
39
+ if "masks" in target:
40
+ # FIXME should we update the area here if there are no boxes?
41
+ target["masks"] = target["masks"][:, i : i + h, j : j + w]
42
+ fields.append("masks")
43
+
44
+ # remove elements for which the boxes or masks that have zero area
45
+ if "boxes" in target or "masks" in target:
46
+ # favor boxes selection when defining which elements to keep
47
+ # this is compatible with previous implementation
48
+ if "boxes" in target:
49
+ cropped_boxes = target["boxes"].reshape(-1, 2, 2)
50
+ keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
51
+ else:
52
+ keep = target["masks"].flatten(1).any(1)
53
+
54
+ for field in fields:
55
+ if field in target:
56
+ target[field] = target[field][keep]
57
+
58
+ if os.environ.get("IPDB_SHILONG_DEBUG", None) == "INFO":
59
+ # for debug and visualization only.
60
+ if "strings_positive" in target:
61
+ target["strings_positive"] = [
62
+ _i for _i, _j in zip(target["strings_positive"], keep) if _j
63
+ ]
64
+
65
+ return cropped_image, target
66
+
67
+
68
+ def hflip(image, target):
69
+ flipped_image = F.hflip(image)
70
+
71
+ w, h = image.size
72
+
73
+ target = target.copy()
74
+ if "boxes" in target:
75
+ boxes = target["boxes"]
76
+ boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor(
77
+ [w, 0, w, 0]
78
+ )
79
+ target["boxes"] = boxes
80
+
81
+ if "masks" in target:
82
+ target["masks"] = target["masks"].flip(-1)
83
+
84
+ return flipped_image, target
85
+
86
+
87
+ def resize(image, target, size, max_size=None):
88
+ # size can be min_size (scalar) or (w, h) tuple
89
+
90
+ def get_size_with_aspect_ratio(image_size, size, max_size=None):
91
+ w, h = image_size
92
+ if max_size is not None:
93
+ min_original_size = float(min((w, h)))
94
+ max_original_size = float(max((w, h)))
95
+ if max_original_size / min_original_size * size > max_size:
96
+ size = int(round(max_size * min_original_size / max_original_size))
97
+
98
+ if (w <= h and w == size) or (h <= w and h == size):
99
+ return (h, w)
100
+
101
+ if w < h:
102
+ ow = size
103
+ oh = int(size * h / w)
104
+ else:
105
+ oh = size
106
+ ow = int(size * w / h)
107
+
108
+ return (oh, ow)
109
+
110
+ def get_size(image_size, size, max_size=None):
111
+ if isinstance(size, (list, tuple)):
112
+ return size[::-1]
113
+ else:
114
+ return get_size_with_aspect_ratio(image_size, size, max_size)
115
+
116
+ size = get_size(image.size, size, max_size)
117
+ rescaled_image = F.resize(image, size)
118
+
119
+ if target is None:
120
+ return rescaled_image, None
121
+
122
+ ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
123
+ ratio_width, ratio_height = ratios
124
+
125
+ target = target.copy()
126
+ if "boxes" in target:
127
+ boxes = target["boxes"]
128
+ scaled_boxes = boxes * torch.as_tensor(
129
+ [ratio_width, ratio_height, ratio_width, ratio_height]
130
+ )
131
+ target["boxes"] = scaled_boxes
132
+
133
+ if "area" in target:
134
+ area = target["area"]
135
+ scaled_area = area * (ratio_width * ratio_height)
136
+ target["area"] = scaled_area
137
+
138
+ h, w = size
139
+ target["size"] = torch.tensor([h, w])
140
+
141
+ if "masks" in target:
142
+ target["masks"] = (
143
+ interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5
144
+ )
145
+
146
+ return rescaled_image, target
147
+
148
+
149
+ def pad(image, target, padding):
150
+ # assumes that we only pad on the bottom right corners
151
+ padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
152
+ if target is None:
153
+ return padded_image, None
154
+ target = target.copy()
155
+ # should we do something wrt the original size?
156
+ target["size"] = torch.tensor(padded_image.size[::-1])
157
+ if "masks" in target:
158
+ target["masks"] = torch.nn.functional.pad(target["masks"], (0, padding[0], 0, padding[1]))
159
+ return padded_image, target
160
+
161
+
162
+ class ResizeDebug(object):
163
+ def __init__(self, size):
164
+ self.size = size
165
+
166
+ def __call__(self, img, target):
167
+ return resize(img, target, self.size)
168
+
169
+
170
+ class RandomCrop(object):
171
+ def __init__(self, size):
172
+ self.size = size
173
+
174
+ def __call__(self, img, target):
175
+ region = T.RandomCrop.get_params(img, self.size)
176
+ return crop(img, target, region)
177
+
178
+
179
+ class RandomSizeCrop(object):
180
+ def __init__(self, min_size: int, max_size: int, respect_boxes: bool = False):
181
+ # respect_boxes: True to keep all boxes
182
+ # False to tolerence box filter
183
+ self.min_size = min_size
184
+ self.max_size = max_size
185
+ self.respect_boxes = respect_boxes
186
+
187
+ def __call__(self, img: PIL.Image.Image, target: dict):
188
+ init_boxes = len(target["boxes"])
189
+ max_patience = 10
190
+ for i in range(max_patience):
191
+ w = random.randint(self.min_size, min(img.width, self.max_size))
192
+ h = random.randint(self.min_size, min(img.height, self.max_size))
193
+ region = T.RandomCrop.get_params(img, [h, w])
194
+ result_img, result_target = crop(img, target, region)
195
+ if (
196
+ not self.respect_boxes
197
+ or len(result_target["boxes"]) == init_boxes
198
+ or i == max_patience - 1
199
+ ):
200
+ return result_img, result_target
201
+ return result_img, result_target
202
+
203
+
204
+ class CenterCrop(object):
205
+ def __init__(self, size):
206
+ self.size = size
207
+
208
+ def __call__(self, img, target):
209
+ image_width, image_height = img.size
210
+ crop_height, crop_width = self.size
211
+ crop_top = int(round((image_height - crop_height) / 2.0))
212
+ crop_left = int(round((image_width - crop_width) / 2.0))
213
+ return crop(img, target, (crop_top, crop_left, crop_height, crop_width))
214
+
215
+
216
+ class RandomHorizontalFlip(object):
217
+ def __init__(self, p=0.5):
218
+ self.p = p
219
+
220
+ def __call__(self, img, target):
221
+ if random.random() < self.p:
222
+ return hflip(img, target)
223
+ return img, target
224
+
225
+
226
+ class RandomResize(object):
227
+ def __init__(self, sizes, max_size=None):
228
+ assert isinstance(sizes, (list, tuple))
229
+ self.sizes = sizes
230
+ self.max_size = max_size
231
+
232
+ def __call__(self, img, target=None):
233
+ size = random.choice(self.sizes)
234
+ return resize(img, target, size, self.max_size)
235
+
236
+
237
+ class RandomPad(object):
238
+ def __init__(self, max_pad):
239
+ self.max_pad = max_pad
240
+
241
+ def __call__(self, img, target):
242
+ pad_x = random.randint(0, self.max_pad)
243
+ pad_y = random.randint(0, self.max_pad)
244
+ return pad(img, target, (pad_x, pad_y))
245
+
246
+
247
+ class RandomSelect(object):
248
+ """
249
+ Randomly selects between transforms1 and transforms2,
250
+ with probability p for transforms1 and (1 - p) for transforms2
251
+ """
252
+
253
+ def __init__(self, transforms1, transforms2, p=0.5):
254
+ self.transforms1 = transforms1
255
+ self.transforms2 = transforms2
256
+ self.p = p
257
+
258
+ def __call__(self, img, target):
259
+ if random.random() < self.p:
260
+ return self.transforms1(img, target)
261
+ return self.transforms2(img, target)
262
+
263
+
264
+ class ToTensor(object):
265
+ def __call__(self, img, target):
266
+ return F.to_tensor(img), target
267
+
268
+
269
+ class RandomErasing(object):
270
+ def __init__(self, *args, **kwargs):
271
+ self.eraser = T.RandomErasing(*args, **kwargs)
272
+
273
+ def __call__(self, img, target):
274
+ return self.eraser(img), target
275
+
276
+
277
+ class Normalize(object):
278
+ def __init__(self, mean, std):
279
+ self.mean = mean
280
+ self.std = std
281
+
282
+ def __call__(self, image, target=None):
283
+ image = F.normalize(image, mean=self.mean, std=self.std)
284
+ if target is None:
285
+ return image, None
286
+ target = target.copy()
287
+ h, w = image.shape[-2:]
288
+ if "boxes" in target:
289
+ boxes = target["boxes"]
290
+ boxes = box_xyxy_to_cxcywh(boxes)
291
+ boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
292
+ target["boxes"] = boxes
293
+ return image, target
294
+
295
+
296
+ class Compose(object):
297
+ def __init__(self, transforms):
298
+ self.transforms = transforms
299
+
300
+ def __call__(self, image, target):
301
+ for t in self.transforms:
302
+ image, target = t(image, target)
303
+ return image, target
304
+
305
+ def __repr__(self):
306
+ format_string = self.__class__.__name__ + "("
307
+ for t in self.transforms:
308
+ format_string += "\n"
309
+ format_string += " {0}".format(t)
310
+ format_string += "\n)"
311
+ return format_string
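
A typical train-time pipeline built from the transforms above (illustrative only, not part of the commit; it assumes the repo root is on `PYTHONPATH` so that `utils.box_ops` and `utils.misc`, imported by this module, resolve; the normalization statistics are the usual ImageNet values, used here purely as an example):

```python
# Illustrative sketch: compose image+target transforms and apply them to one sample.
import torch
from PIL import Image
import utils.datasets.transforms as T

pipeline = T.Compose([
    T.RandomHorizontalFlip(p=0.5),
    T.RandomResize([480, 512, 544], max_size=800),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # ImageNet stats (example)
])

img = Image.new("RGB", (640, 480))
target = {
    "boxes": torch.tensor([[100.0, 120.0, 300.0, 360.0]]),  # xyxy, absolute pixels
    "labels": torch.tensor([1]),
    "area": torch.tensor([200.0 * 240.0]),
}

img_t, target_t = pipeline(img, target)
print(img_t.shape)        # 3 x H x W after random resize
print(target_t["boxes"])  # normalized cxcywh after Normalize
```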
utils/misc.py ADDED
@@ -0,0 +1,717 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Misc functions, including distributed helpers.
4
+
5
+ Mostly copy-paste from torchvision references.
6
+ """
7
+ import colorsys
8
+ import datetime
9
+ import functools
10
+ import io
11
+ import json
12
+ import os
13
+ import pickle
14
+ import subprocess
15
+ import time
16
+ from collections import OrderedDict, defaultdict, deque
17
+ from typing import List, Optional
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.distributed as dist
22
+
23
+ # needed due to empty tensor bug in pytorch and torchvision 0.5
24
+ import torchvision
25
+ from torch import Tensor
26
+
27
+ __torchvision_need_compat_flag = float(torchvision.__version__.split(".")[1]) < 7
28
+ if __torchvision_need_compat_flag:
29
+ from torchvision.ops import _new_empty_tensor
30
+ from torchvision.ops.misc import _output_size
31
+
32
+
33
+ class SmoothedValue(object):
34
+ """Track a series of values and provide access to smoothed values over a
35
+ window or the global series average.
36
+ """
37
+
38
+ def __init__(self, window_size=20, fmt=None):
39
+ if fmt is None:
40
+ fmt = "{median:.4f} ({global_avg:.4f})"
41
+ self.deque = deque(maxlen=window_size)
42
+ self.total = 0.0
43
+ self.count = 0
44
+ self.fmt = fmt
45
+
46
+ def update(self, value, n=1):
47
+ self.deque.append(value)
48
+ self.count += n
49
+ self.total += value * n
50
+
51
+ def synchronize_between_processes(self):
52
+ """
53
+ Warning: does not synchronize the deque!
54
+ """
55
+ if not is_dist_avail_and_initialized():
56
+ return
57
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
58
+ dist.barrier()
59
+ dist.all_reduce(t)
60
+ t = t.tolist()
61
+ self.count = int(t[0])
62
+ self.total = t[1]
63
+
64
+ @property
65
+ def median(self):
66
+ d = torch.tensor(list(self.deque))
67
+ if d.shape[0] == 0:
68
+ return 0
69
+ return d.median().item()
70
+
71
+ @property
72
+ def avg(self):
73
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
74
+ return d.mean().item()
75
+
76
+ @property
77
+ def global_avg(self):
78
+ if os.environ.get("SHILONG_AMP", None) == "1":
79
+ eps = 1e-4
80
+ else:
81
+ eps = 1e-6
82
+ return self.total / (self.count + eps)
83
+
84
+ @property
85
+ def max(self):
86
+ return max(self.deque)
87
+
88
+ @property
89
+ def value(self):
90
+ return self.deque[-1]
91
+
92
+ def __str__(self):
93
+ return self.fmt.format(
94
+ median=self.median,
95
+ avg=self.avg,
96
+ global_avg=self.global_avg,
97
+ max=self.max,
98
+ value=self.value,
99
+ )
100
+
101
+
102
+ @functools.lru_cache()
103
+ def _get_global_gloo_group():
104
+ """
105
+ Return a process group based on gloo backend, containing all the ranks
106
+ The result is cached.
107
+ """
108
+
109
+ if dist.get_backend() == "nccl":
110
+ return dist.new_group(backend="gloo")
111
+
112
+ return dist.group.WORLD
113
+
114
+
115
+ def all_gather_cpu(data):
116
+ """
117
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
118
+ Args:
119
+ data: any picklable object
120
+ Returns:
121
+ list[data]: list of data gathered from each rank
122
+ """
123
+
124
+ world_size = get_world_size()
125
+ if world_size == 1:
126
+ return [data]
127
+
128
+ cpu_group = _get_global_gloo_group()
129
+
130
+ buffer = io.BytesIO()
131
+ torch.save(data, buffer)
132
+ data_view = buffer.getbuffer()
133
+ device = "cuda" if cpu_group is None else "cpu"
134
+ tensor = torch.ByteTensor(data_view).to(device)
135
+
136
+ # obtain Tensor size of each rank
137
+ local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
138
+ size_list = [torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)]
139
+ if cpu_group is None:
140
+ dist.all_gather(size_list, local_size)
141
+ else:
142
+ print("gathering on cpu")
143
+ dist.all_gather(size_list, local_size, group=cpu_group)
144
+ size_list = [int(size.item()) for size in size_list]
145
+ max_size = max(size_list)
146
+ assert isinstance(local_size.item(), int)
147
+ local_size = int(local_size.item())
148
+
149
+ # receiving Tensor from all ranks
150
+ # we pad the tensor because torch all_gather does not support
151
+ # gathering tensors of different shapes
152
+ tensor_list = []
153
+ for _ in size_list:
154
+ tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
155
+ if local_size != max_size:
156
+ padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device=device)
157
+ tensor = torch.cat((tensor, padding), dim=0)
158
+ if cpu_group is None:
159
+ dist.all_gather(tensor_list, tensor)
160
+ else:
161
+ dist.all_gather(tensor_list, tensor, group=cpu_group)
162
+
163
+ data_list = []
164
+ for size, tensor in zip(size_list, tensor_list):
165
+ tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
166
+ buffer = io.BytesIO(tensor.cpu().numpy())
167
+ obj = torch.load(buffer)
168
+ data_list.append(obj)
169
+
170
+ return data_list
171
+
172
+
173
+ def all_gather(data):
174
+ """
175
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
176
+ Args:
177
+ data: any picklable object
178
+ Returns:
179
+ list[data]: list of data gathered from each rank
180
+ """
181
+
182
+ if os.getenv("CPU_REDUCE") == "1":
183
+ return all_gather_cpu(data)
184
+
185
+ world_size = get_world_size()
186
+ if world_size == 1:
187
+ return [data]
188
+
189
+ # serialized to a Tensor
190
+ buffer = pickle.dumps(data)
191
+ storage = torch.ByteStorage.from_buffer(buffer)
192
+ tensor = torch.ByteTensor(storage).to("cuda")
193
+
194
+ # obtain Tensor size of each rank
195
+ local_size = torch.tensor([tensor.numel()], device="cuda")
196
+ size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
197
+ dist.all_gather(size_list, local_size)
198
+ size_list = [int(size.item()) for size in size_list]
199
+ max_size = max(size_list)
200
+
201
+ # receiving Tensor from all ranks
202
+ # we pad the tensor because torch all_gather does not support
203
+ # gathering tensors of different shapes
204
+ tensor_list = []
205
+ for _ in size_list:
206
+ tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
207
+ if local_size != max_size:
208
+ padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
209
+ tensor = torch.cat((tensor, padding), dim=0)
210
+ dist.all_gather(tensor_list, tensor)
211
+
212
+ data_list = []
213
+ for size, tensor in zip(size_list, tensor_list):
214
+ buffer = tensor.cpu().numpy().tobytes()[:size]
215
+ data_list.append(pickle.loads(buffer))
216
+
217
+ return data_list
218
+
219
+
220
+ def reduce_dict(input_dict, average=True):
221
+ """
222
+ Args:
223
+ input_dict (dict): all the values will be reduced
224
+ average (bool): whether to do average or sum
225
+ Reduce the values in the dictionary from all processes so that all processes
226
+ have the averaged results. Returns a dict with the same fields as
227
+ input_dict, after reduction.
228
+ """
229
+ world_size = get_world_size()
230
+ if world_size < 2:
231
+ return input_dict
232
+ with torch.no_grad():
233
+ names = []
234
+ values = []
235
+ # sort the keys so that they are consistent across processes
236
+ for k in sorted(input_dict.keys()):
237
+ names.append(k)
238
+ values.append(input_dict[k])
239
+ values = torch.stack(values, dim=0)
240
+ dist.all_reduce(values)
241
+ if average:
242
+ values /= world_size
243
+ reduced_dict = {k: v for k, v in zip(names, values)}
244
+ return reduced_dict
245
+
246
+
247
+ class MetricLogger(object):
248
+ def __init__(self, delimiter="\t"):
249
+ self.meters = defaultdict(SmoothedValue)
250
+ self.delimiter = delimiter
251
+
252
+ def update(self, **kwargs):
253
+ for k, v in kwargs.items():
254
+ if isinstance(v, torch.Tensor):
255
+ v = v.item()
256
+ assert isinstance(v, (float, int))
257
+ self.meters[k].update(v)
258
+
259
+ def __getattr__(self, attr):
260
+ if attr in self.meters:
261
+ return self.meters[attr]
262
+ if attr in self.__dict__:
263
+ return self.__dict__[attr]
264
+ raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
265
+
266
+ def __str__(self):
267
+ loss_str = []
268
+ for name, meter in self.meters.items():
269
+ # print(name, str(meter))
270
+ # import ipdb;ipdb.set_trace()
271
+ if meter.count > 0:
272
+ loss_str.append("{}: {}".format(name, str(meter)))
273
+ return self.delimiter.join(loss_str)
274
+
275
+ def synchronize_between_processes(self):
276
+ for meter in self.meters.values():
277
+ meter.synchronize_between_processes()
278
+
279
+ def add_meter(self, name, meter):
280
+ self.meters[name] = meter
281
+
282
+ def log_every(self, iterable, print_freq, header=None, logger=None):
283
+ if logger is None:
284
+ print_func = print
285
+ else:
286
+ print_func = logger.info
287
+
288
+ i = 0
289
+ if not header:
290
+ header = ""
291
+ start_time = time.time()
292
+ end = time.time()
293
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
294
+ data_time = SmoothedValue(fmt="{avg:.4f}")
295
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
296
+ if torch.cuda.is_available():
297
+ log_msg = self.delimiter.join(
298
+ [
299
+ header,
300
+ "[{0" + space_fmt + "}/{1}]",
301
+ "eta: {eta}",
302
+ "{meters}",
303
+ "time: {time}",
304
+ "data: {data}",
305
+ "max mem: {memory:.0f}",
306
+ ]
307
+ )
308
+ else:
309
+ log_msg = self.delimiter.join(
310
+ [
311
+ header,
312
+ "[{0" + space_fmt + "}/{1}]",
313
+ "eta: {eta}",
314
+ "{meters}",
315
+ "time: {time}",
316
+ "data: {data}",
317
+ ]
318
+ )
319
+ MB = 1024.0 * 1024.0
320
+ for obj in iterable:
321
+ data_time.update(time.time() - end)
322
+ yield obj
323
+ # import ipdb; ipdb.set_trace()
324
+ iter_time.update(time.time() - end)
325
+ if i % print_freq == 0 or i == len(iterable) - 1:
326
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
327
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
328
+ if torch.cuda.is_available():
329
+ print_func(
330
+ log_msg.format(
331
+ i,
332
+ len(iterable),
333
+ eta=eta_string,
334
+ meters=str(self),
335
+ time=str(iter_time),
336
+ data=str(data_time),
337
+ memory=torch.cuda.max_memory_allocated() / MB,
338
+ )
339
+ )
340
+ else:
341
+ print_func(
342
+ log_msg.format(
343
+ i,
344
+ len(iterable),
345
+ eta=eta_string,
346
+ meters=str(self),
347
+ time=str(iter_time),
348
+ data=str(data_time),
349
+ )
350
+ )
351
+ i += 1
352
+ end = time.time()
353
+ total_time = time.time() - start_time
354
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
355
+ print_func(
356
+ "{} Total time: {} ({:.4f} s / it)".format(
357
+ header, total_time_str, total_time / len(iterable)
358
+ )
359
+ )
360
+
361
+
362
+ def get_sha():
363
+ cwd = os.path.dirname(os.path.abspath(__file__))
364
+
365
+ def _run(command):
366
+ return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
367
+
368
+ sha = "N/A"
369
+ diff = "clean"
370
+ branch = "N/A"
371
+ try:
372
+ sha = _run(["git", "rev-parse", "HEAD"])
373
+ subprocess.check_output(["git", "diff"], cwd=cwd)
374
+ diff = _run(["git", "diff-index", "HEAD"])
375
+ diff = "has uncommited changes" if diff else "clean"
376
+ branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
377
+ except Exception:
378
+ pass
379
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
380
+ return message
381
+
382
+
383
+ def collate_fn(batch):
384
+ # import ipdb; ipdb.set_trace()
385
+ batch = list(zip(*batch))
386
+ batch[0] = nested_tensor_from_tensor_list(batch[0])
387
+ return tuple(batch)
388
+
389
+
390
+ def _max_by_axis(the_list):
391
+ # type: (List[List[int]]) -> List[int]
392
+ maxes = the_list[0]
393
+ for sublist in the_list[1:]:
394
+ for index, item in enumerate(sublist):
395
+ maxes[index] = max(maxes[index], item)
396
+ return maxes
397
+
398
+
399
+ class NestedTensor(object):
400
+ def __init__(self, tensors, mask: Optional[Tensor]):
401
+ self.tensors = tensors
402
+ self.mask = mask
403
+ if mask == "auto":
404
+ self.mask = torch.zeros_like(tensors).to(tensors.device)
405
+ if self.mask.dim() == 3:
406
+ self.mask = self.mask.sum(0).to(bool)
407
+ elif self.mask.dim() == 4:
408
+ self.mask = self.mask.sum(1).to(bool)
409
+ else:
410
+ raise ValueError(
411
+ "tensors dim must be 3 or 4 but {}({})".format(
412
+ self.tensors.dim(), self.tensors.shape
413
+ )
414
+ )
415
+
416
+ def imgsize(self):
417
+ res = []
418
+ for i in range(self.tensors.shape[0]):
419
+ mask = self.mask[i]
420
+ maxH = (~mask).sum(0).max()
421
+ maxW = (~mask).sum(1).max()
422
+ res.append(torch.Tensor([maxH, maxW]))
423
+ return res
424
+
425
+ def to(self, device):
426
+ # type: (Device) -> NestedTensor # noqa
427
+ cast_tensor = self.tensors.to(device)
428
+ mask = self.mask
429
+ if mask is not None:
430
+ assert mask is not None
431
+ cast_mask = mask.to(device)
432
+ else:
433
+ cast_mask = None
434
+ return NestedTensor(cast_tensor, cast_mask)
435
+
436
+ def to_img_list_single(self, tensor, mask):
437
+ assert tensor.dim() == 3, "dim of tensor should be 3 but {}".format(tensor.dim())
438
+ maxH = (~mask).sum(0).max()
439
+ maxW = (~mask).sum(1).max()
440
+ img = tensor[:, :maxH, :maxW]
441
+ return img
442
+
443
+ def to_img_list(self):
444
+ """remove the padding and convert to img list
445
+
446
+ Returns:
447
+ Tensor or list[Tensor]: image(s) with the padding removed
448
+ """
449
+ if self.tensors.dim() == 3:
450
+ return self.to_img_list_single(self.tensors, self.mask)
451
+ else:
452
+ res = []
453
+ for i in range(self.tensors.shape[0]):
454
+ tensor_i = self.tensors[i]
455
+ mask_i = self.mask[i]
456
+ res.append(self.to_img_list_single(tensor_i, mask_i))
457
+ return res
458
+
459
+ @property
460
+ def device(self):
461
+ return self.tensors.device
462
+
463
+ def decompose(self):
464
+ return self.tensors, self.mask
465
+
466
+ def __repr__(self):
467
+ return str(self.tensors)
468
+
469
+ @property
470
+ def shape(self):
471
+ return {"tensors.shape": self.tensors.shape, "mask.shape": self.mask.shape}
472
+
473
+
474
+ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
475
+ # TODO make this more general
476
+ if tensor_list[0].ndim == 3:
477
+ if torchvision._is_tracing():
478
+ # nested_tensor_from_tensor_list() does not export well to ONNX
479
+ # call _onnx_nested_tensor_from_tensor_list() instead
480
+ return _onnx_nested_tensor_from_tensor_list(tensor_list)
481
+
482
+ # TODO make it support different-sized images
483
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
484
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
485
+ batch_shape = [len(tensor_list)] + max_size
486
+ b, c, h, w = batch_shape
487
+ dtype = tensor_list[0].dtype
488
+ device = tensor_list[0].device
489
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
490
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
491
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
492
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
493
+ m[: img.shape[1], : img.shape[2]] = False
494
+ else:
495
+ raise ValueError("not supported")
496
+ return NestedTensor(tensor, mask)
497
+
498
+
499
+ # _onnx_nested_tensor_from_tensor_list() is an implementation of
500
+ # nested_tensor_from_tensor_list() that is supported by ONNX tracing.
501
+ @torch.jit.unused
502
+ def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
503
+ max_size = []
504
+ for i in range(tensor_list[0].dim()):
505
+ max_size_i = torch.max(
506
+ torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
507
+ ).to(torch.int64)
508
+ max_size.append(max_size_i)
509
+ max_size = tuple(max_size)
510
+
511
+ # work around for
512
+ # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
513
+ # m[: img.shape[1], :img.shape[2]] = False
514
+ # which is not yet supported in onnx
515
+ padded_imgs = []
516
+ padded_masks = []
517
+ for img in tensor_list:
518
+ padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
519
+ padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
520
+ padded_imgs.append(padded_img)
521
+
522
+ m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
523
+ padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
524
+ padded_masks.append(padded_mask.to(torch.bool))
525
+
526
+ tensor = torch.stack(padded_imgs)
527
+ mask = torch.stack(padded_masks)
528
+
529
+ return NestedTensor(tensor, mask=mask)
530
+
531
+
532
+ def setup_for_distributed(is_master):
533
+ """
534
+ This function disables printing when not in master process
535
+ """
536
+ import builtins as __builtin__
537
+
538
+ builtin_print = __builtin__.print
539
+
540
+ def print(*args, **kwargs):
541
+ force = kwargs.pop("force", False)
542
+ if is_master or force:
543
+ builtin_print(*args, **kwargs)
544
+
545
+ __builtin__.print = print
546
+
547
+
548
+ def is_dist_avail_and_initialized():
549
+ if not dist.is_available():
550
+ return False
551
+ if not dist.is_initialized():
552
+ return False
553
+ return True
554
+
555
+
556
+ def get_world_size():
557
+ if not is_dist_avail_and_initialized():
558
+ return 1
559
+ return dist.get_world_size()
560
+
561
+
562
+ def get_rank():
563
+ if not is_dist_avail_and_initialized():
564
+ return 0
565
+ return dist.get_rank()
566
+
567
+
568
+ def is_main_process():
569
+ return get_rank() == 0
570
+
571
+
572
+ def save_on_master(*args, **kwargs):
573
+ if is_main_process():
574
+ torch.save(*args, **kwargs)
575
+
576
+
577
+ def init_distributed_mode(args):
578
+ if "WORLD_SIZE" in os.environ and os.environ["WORLD_SIZE"] != "": # 'RANK' in os.environ and
579
+ args.rank = int(os.environ["RANK"])
580
+ args.world_size = int(os.environ["WORLD_SIZE"])
581
+ args.gpu = args.local_rank = int(os.environ["LOCAL_RANK"])
582
+
583
+ # launch by torch.distributed.launch
584
+ # Single node
585
+ # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 1 --rank 0 ...
586
+ # Multi nodes
587
+ # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 0 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ...
588
+ # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 1 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ...
589
+ # args.rank = int(os.environ.get('OMPI_COMM_WORLD_RANK'))
590
+ # local_world_size = int(os.environ['GPU_PER_NODE_COUNT'])
591
+ # args.world_size = args.world_size * local_world_size
592
+ # args.gpu = args.local_rank = int(os.environ['LOCAL_RANK'])
593
+ # args.rank = args.rank * local_world_size + args.local_rank
594
+ print(
595
+ "world size: {}, rank: {}, local rank: {}".format(
596
+ args.world_size, args.rank, args.local_rank
597
+ )
598
+ )
599
+ print(json.dumps(dict(os.environ), indent=2))
600
+ elif "SLURM_PROCID" in os.environ:
601
+ args.rank = int(os.environ["SLURM_PROCID"])
602
+ args.gpu = args.local_rank = int(os.environ["SLURM_LOCALID"])
603
+ args.world_size = int(os.environ["SLURM_NPROCS"])
604
+
605
+ print(
606
+ "world size: {}, world rank: {}, local rank: {}, device_count: {}".format(
607
+ args.world_size, args.rank, args.local_rank, torch.cuda.device_count()
608
+ )
609
+ )
610
+ else:
611
+ print("Not using distributed mode")
612
+ args.distributed = False
613
+ args.world_size = 1
614
+ args.rank = 0
615
+ args.local_rank = 0
616
+ return
617
+
618
+ print("world_size:{} rank:{} local_rank:{}".format(args.world_size, args.rank, args.local_rank))
619
+ args.distributed = True
620
+ torch.cuda.set_device(args.local_rank)
621
+ args.dist_backend = "nccl"
622
+ print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
623
+
624
+ torch.distributed.init_process_group(
625
+ backend=args.dist_backend,
626
+ world_size=args.world_size,
627
+ rank=args.rank,
628
+ init_method=args.dist_url,
629
+ )
630
+
631
+ print("Before torch.distributed.barrier()")
632
+ torch.distributed.barrier()
633
+ print("End torch.distributed.barrier()")
634
+ setup_for_distributed(args.rank == 0)
635
+
636
+
637
+ @torch.no_grad()
638
+ def accuracy(output, target, topk=(1,)):
639
+ """Computes the precision@k for the specified values of k"""
640
+ if target.numel() == 0:
641
+ return [torch.zeros([], device=output.device)]
642
+ maxk = max(topk)
643
+ batch_size = target.size(0)
644
+
645
+ _, pred = output.topk(maxk, 1, True, True)
646
+ pred = pred.t()
647
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
648
+
649
+ res = []
650
+ for k in topk:
651
+ correct_k = correct[:k].view(-1).float().sum(0)
652
+ res.append(correct_k.mul_(100.0 / batch_size))
653
+ return res
654
+
655
+
656
+ @torch.no_grad()
657
+ def accuracy_onehot(pred, gt):
658
+ """_summary_
659
+ """Percentage of one-hot predictions that exactly match the ground truth.
660
+ Args:
661
+ pred (_type_): n, c
662
+ gt (_type_): n, c
663
+ """
664
+ tp = ((pred - gt).abs().sum(-1) < 1e-4).float().sum()
665
+ acc = tp / gt.shape[0] * 100
666
+ return acc
667
+
668
+
669
+ def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
670
+ # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
671
+ """
672
+ Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
673
+ This will eventually be supported natively by PyTorch, and this
674
+ class can go away.
675
+ """
676
+ if __torchvision_need_compat_flag:  # the flag is already a boolean (torchvision < 0.7)
677
+ if input.numel() > 0:
678
+ return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
679
+
680
+ output_shape = _output_size(2, input, size, scale_factor)
681
+ output_shape = list(input.shape[:-2]) + list(output_shape)
682
+ return _new_empty_tensor(input, output_shape)
683
+ else:
684
+ return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
685
+
686
+
687
+ class color_sys:
688
+ def __init__(self, num_colors) -> None:
689
+ self.num_colors = num_colors
690
+ colors = []
691
+ for i in np.arange(0.0, 360.0, 360.0 / num_colors):
692
+ hue = i / 360.0
693
+ lightness = (50 + np.random.rand() * 10) / 100.0
694
+ saturation = (90 + np.random.rand() * 10) / 100.0
695
+ colors.append(
696
+ tuple([int(j * 255) for j in colorsys.hls_to_rgb(hue, lightness, saturation)])
697
+ )
698
+ self.colors = colors
699
+
700
+ def __call__(self, idx):
701
+ return self.colors[idx]
702
+
703
+
704
+ def inverse_sigmoid(x, eps=1e-3):
705
+ x = x.clamp(min=0, max=1)
706
+ x1 = x.clamp(min=eps)
707
+ x2 = (1 - x).clamp(min=eps)
708
+ return torch.log(x1 / x2)
709
+
710
+
711
+ def clean_state_dict(state_dict):
712
+ new_state_dict = OrderedDict()
713
+ for k, v in state_dict.items():
714
+ if k[:7] == "module.":
715
+ k = k[7:] # remove `module.`
716
+ new_state_dict[k] = v
717
+ return new_state_dict
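
Two small usage sketches for the helpers above (illustrative only, not part of the commit; they assume the repo root is on `PYTHONPATH`):

```python
# Illustrative sketch: pad variable-sized images into a NestedTensor and track a metric.
import torch
from utils.misc import nested_tensor_from_tensor_list, SmoothedValue

imgs = [torch.randn(3, 480, 640), torch.randn(3, 600, 500)]
batch = nested_tensor_from_tensor_list(imgs)
tensors, mask = batch.decompose()
print(tensors.shape)  # torch.Size([2, 3, 600, 640]) -- padded to the max H and W
print(mask.shape)     # torch.Size([2, 600, 640])    -- True where padding was added

# SmoothedValue keeps a windowed running median/average of a scalar metric.
loss_meter = SmoothedValue(window_size=20)
for step in range(5):
    loss_meter.update(1.0 / (step + 1))
print(str(loss_meter))  # "median (global_avg)" using the default format string
```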
utils/transforms.py ADDED
@@ -0,0 +1,102 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch.nn import functional as F
10
+ from torchvision.transforms.functional import resize, to_pil_image # type: ignore
11
+
12
+ from copy import deepcopy
13
+ from typing import Tuple
14
+
15
+
16
+ class ResizeLongestSide:
17
+ """
18
+ Resizes images to longest side 'target_length', as well as provides
19
+ methods for resizing coordinates and boxes. Provides methods for
20
+ transforming both numpy array and batched torch tensors.
21
+ """
22
+
23
+ def __init__(self, target_length: int) -> None:
24
+ self.target_length = target_length
25
+
26
+ def apply_image(self, image: np.ndarray) -> np.ndarray:
27
+ """
28
+ Expects a numpy array with shape HxWxC in uint8 format.
29
+ """
30
+ target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
31
+ return np.array(resize(to_pil_image(image), target_size))
32
+
33
+ def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
34
+ """
35
+ Expects a numpy array of length 2 in the final dimension. Requires the
36
+ original image size in (H, W) format.
37
+ """
38
+ old_h, old_w = original_size
39
+ new_h, new_w = self.get_preprocess_shape(
40
+ original_size[0], original_size[1], self.target_length
41
+ )
42
+ coords = deepcopy(coords).astype(float)
43
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
44
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
45
+ return coords
46
+
47
+ def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
48
+ """
49
+ Expects a numpy array shape Bx4. Requires the original image size
50
+ in (H, W) format.
51
+ """
52
+ boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
53
+ return boxes.reshape(-1, 4)
54
+
55
+ def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
56
+ """
57
+ Expects batched images with shape BxCxHxW and float format. This
58
+ transformation may not exactly match apply_image. apply_image is
59
+ the transformation expected by the model.
60
+ """
61
+ # Expects an image in BCHW format. May not exactly match apply_image.
62
+ target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)  # spatial dims (H, W) of the BxCxHxW input
63
+ return F.interpolate(
64
+ image, target_size, mode="bilinear", align_corners=False, antialias=True
65
+ )
66
+
67
+ def apply_coords_torch(
68
+ self, coords: torch.Tensor, original_size: Tuple[int, ...]
69
+ ) -> torch.Tensor:
70
+ """
71
+ Expects a torch tensor with length 2 in the last dimension. Requires the
72
+ original image size in (H, W) format.
73
+ """
74
+ old_h, old_w = original_size
75
+ new_h, new_w = self.get_preprocess_shape(
76
+ original_size[0], original_size[1], self.target_length
77
+ )
78
+ coords = deepcopy(coords).to(torch.float)
79
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
80
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
81
+ return coords
82
+
83
+ def apply_boxes_torch(
84
+ self, boxes: torch.Tensor, original_size: Tuple[int, ...]
85
+ ) -> torch.Tensor:
86
+ """
87
+ Expects a torch tensor with shape Bx4. Requires the original image
88
+ size in (H, W) format.
89
+ """
90
+ boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
91
+ return boxes.reshape(-1, 4)
92
+
93
+ @staticmethod
94
+ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
95
+ """
96
+ Compute the output size given input size and target long side length.
97
+ """
98
+ scale = long_side_length * 1.0 / max(oldh, oldw)
99
+ newh, neww = oldh * scale, oldw * scale
100
+ neww = int(neww + 0.5)
101
+ newh = int(newh + 0.5)
102
+ return (newh, neww)