yunyangx committed on
Commit
bd9da36
1 Parent(s): 82174ad

Efficient Track Anything built on SAM 2

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. sam2/.DS_Store +0 -0
  2. sam2/__init__.py +11 -0
  3. sam2/__pycache__/__init__.cpython-312.pyc +0 -0
  4. sam2/__pycache__/automatic_mask_generator.cpython-312.pyc +0 -0
  5. sam2/__pycache__/build_sam.cpython-312.pyc +0 -0
  6. sam2/__pycache__/sam2_image_predictor.cpython-312.pyc +0 -0
  7. sam2/__pycache__/sam2_video_predictor.cpython-312.pyc +0 -0
  8. sam2/automatic_mask_generator.py +434 -0
  9. sam2/build_sam.py +111 -0
  10. sam2/configs/.DS_Store +0 -0
  11. sam2/configs/__init__.py +5 -0
  12. sam2/configs/__pycache__/__init__.cpython-312.pyc +0 -0
  13. sam2/configs/efficientam_s.yaml +123 -0
  14. sam2/configs/efficienttam_s_1.yaml +123 -0
  15. sam2/configs/efficienttam_s_2.yaml +123 -0
  16. sam2/configs/efficienttam_s_512x512.yaml +123 -0
  17. sam2/configs/efficienttam_ti.yaml +123 -0
  18. sam2/configs/efficienttam_ti_1.yaml +123 -0
  19. sam2/configs/efficienttam_ti_2.yaml +123 -0
  20. sam2/configs/efficienttam_ti_512x512.yaml +123 -0
  21. sam2/configs/sam2_hiera_b+.yaml +113 -0
  22. sam2/configs/sam2_hiera_l.yaml +117 -0
  23. sam2/configs/sam2_hiera_s.yaml +116 -0
  24. sam2/configs/sam2_hiera_t.yaml +118 -0
  25. sam2/modeling/__init__.py +5 -0
  26. sam2/modeling/__pycache__/__init__.cpython-312.pyc +0 -0
  27. sam2/modeling/__pycache__/memory_attention.cpython-312.pyc +0 -0
  28. sam2/modeling/__pycache__/memory_encoder.cpython-312.pyc +0 -0
  29. sam2/modeling/__pycache__/position_encoding.cpython-312.pyc +0 -0
  30. sam2/modeling/__pycache__/sam2_base.cpython-312.pyc +0 -0
  31. sam2/modeling/__pycache__/sam2_utils.cpython-312.pyc +0 -0
  32. sam2/modeling/backbones/__init__.py +5 -0
  33. sam2/modeling/backbones/__pycache__/__init__.cpython-312.pyc +0 -0
  34. sam2/modeling/backbones/__pycache__/hieradet.cpython-312.pyc +0 -0
  35. sam2/modeling/backbones/__pycache__/image_encoder.cpython-312.pyc +0 -0
  36. sam2/modeling/backbones/__pycache__/utils.cpython-312.pyc +0 -0
  37. sam2/modeling/backbones/__pycache__/vitdet.cpython-312.pyc +0 -0
  38. sam2/modeling/backbones/hieradet.py +294 -0
  39. sam2/modeling/backbones/image_encoder.py +196 -0
  40. sam2/modeling/backbones/utils.py +125 -0
  41. sam2/modeling/backbones/vitdet.py +307 -0
  42. sam2/modeling/memory_attention.py +170 -0
  43. sam2/modeling/memory_encoder.py +181 -0
  44. sam2/modeling/position_encoding.py +215 -0
  45. sam2/modeling/sam/__init__.py +5 -0
  46. sam2/modeling/sam/__pycache__/__init__.cpython-312.pyc +0 -0
  47. sam2/modeling/sam/__pycache__/mask_decoder.cpython-312.pyc +0 -0
  48. sam2/modeling/sam/__pycache__/prompt_encoder.cpython-312.pyc +0 -0
  49. sam2/modeling/sam/__pycache__/transformer.cpython-312.pyc +0 -0
  50. sam2/modeling/sam/mask_decoder.py +295 -0
sam2/.DS_Store ADDED
Binary file (6.15 kB).
 
sam2/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from hydra import initialize
+
+ from .build_sam import load_model
+
+ initialize("configs", version_base="1.2")
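The package __init__ registers the bundled "configs" directory with Hydra at import time, so load_model can compose a config purely by variant name. A minimal usage sketch, assuming "ti" is one of the names defined in sam2.utils.misc.VARIANTS (that mapping is outside this diff) and that a matching checkpoint has been downloaded locally:

import numpy as np

import sam2  # importing the package runs hydra.initialize("configs", ...)
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Variant name and checkpoint path are illustrative assumptions.
model = sam2.load_model(variant="ti", ckpt_path="checkpoints/efficienttam_ti.pt", device="cuda")

predictor = SAM2ImagePredictor(model)
predictor.set_image(np.zeros((480, 640, 3), dtype=np.uint8))  # any HWC uint8 image
masks, iou_scores, low_res_logits = predictor.predict(
    point_coords=np.array([[320, 240]]),  # one foreground click near the image center
    point_labels=np.array([1]),
)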
sam2/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (328 Bytes).
 
sam2/__pycache__/automatic_mask_generator.cpython-312.pyc ADDED
Binary file (18.9 kB).
 
sam2/__pycache__/build_sam.cpython-312.pyc ADDED
Binary file (10.7 kB).
 
sam2/__pycache__/sam2_image_predictor.cpython-312.pyc ADDED
Binary file (21.8 kB).
 
sam2/__pycache__/sam2_video_predictor.cpython-312.pyc ADDED
Binary file (28.5 kB).
 
sam2/automatic_mask_generator.py ADDED
@@ -0,0 +1,434 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
8
+ from typing import Any, Dict, List, Optional, Tuple
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torchvision.ops.boxes import batched_nms, box_area # type: ignore
13
+
14
+ from sam2.modeling.sam2_base import SAM2Base
15
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
16
+ from sam2.utils.amg import (
17
+ MaskData,
18
+ area_from_rle,
19
+ batch_iterator,
20
+ batched_mask_to_box,
21
+ box_xyxy_to_xywh,
22
+ build_all_layer_point_grids,
23
+ calculate_stability_score,
24
+ coco_encode_rle,
25
+ generate_crop_boxes,
26
+ is_box_near_crop_edge,
27
+ mask_to_rle_pytorch,
28
+ remove_small_regions,
29
+ rle_to_mask,
30
+ uncrop_boxes_xyxy,
31
+ uncrop_masks,
32
+ uncrop_points,
33
+ )
34
+
35
+
36
+ class SAM2AutomaticMaskGenerator:
37
+ def __init__(
38
+ self,
39
+ model: SAM2Base,
40
+ points_per_side: Optional[int] = 32,
41
+ points_per_batch: int = 64,
42
+ pred_iou_thresh: float = 0.8,
43
+ stability_score_thresh: float = 0.95,
44
+ stability_score_offset: float = 1.0,
45
+ mask_threshold: float = 0.0,
46
+ box_nms_thresh: float = 0.7,
47
+ crop_n_layers: int = 0,
48
+ crop_nms_thresh: float = 0.7,
49
+ crop_overlap_ratio: float = 512 / 1500,
50
+ crop_n_points_downscale_factor: int = 1,
51
+ point_grids: Optional[List[np.ndarray]] = None,
52
+ min_mask_region_area: int = 0,
53
+ output_mode: str = "binary_mask",
54
+ use_m2m: bool = False,
55
+ multimask_output: bool = True,
56
+ ) -> None:
57
+ """
58
+ Using a SAM 2 model, generates masks for the entire image.
59
+ Generates a grid of point prompts over the image, then filters
60
+ low quality and duplicate masks. The default settings are chosen
61
+ for SAM 2 with a HieraL backbone.
62
+
63
+ Arguments:
64
+ model (SAM2Base): The SAM 2 model to use for mask prediction.
65
+ points_per_side (int or None): The number of points to be sampled
66
+ along one side of the image. The total number of points is
67
+ points_per_side**2. If None, 'point_grids' must provide explicit
68
+ point sampling.
69
+ points_per_batch (int): Sets the number of points run simultaneously
70
+ by the model. Higher numbers may be faster but use more GPU memory.
71
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
72
+ model's predicted mask quality.
73
+ stability_score_thresh (float): A filtering threshold in [0,1], using
74
+ the stability of the mask under changes to the cutoff used to binarize
75
+ the model's mask predictions.
76
+ stability_score_offset (float): The amount to shift the cutoff when
77
+ calculating the stability score.
78
+ mask_threshold (float): Threshold for binarizing the mask logits
79
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
80
+ suppression to filter duplicate masks.
81
+ crop_n_layers (int): If >0, mask prediction will be run again on
82
+ crops of the image. Sets the number of layers to run, where each
83
+ layer has 2**i_layer number of image crops.
84
+ crop_nms_thresh (float): The box IoU cutoff used by non-maximal
85
+ suppression to filter duplicate masks between different crops.
86
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
87
+ In the first crop layer, crops will overlap by this fraction of
88
+ the image length. Later layers with more crops scale down this overlap.
89
+ crop_n_points_downscale_factor (int): The number of points-per-side
90
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
91
+ point_grids (list(np.ndarray) or None): A list over explicit grids
92
+ of points used for sampling, normalized to [0,1]. The nth grid in the
93
+ list is used in the nth crop layer. Exclusive with points_per_side.
94
+ min_mask_region_area (int): If >0, postprocessing will be applied
95
+ to remove disconnected regions and holes in masks with area smaller
96
+ than min_mask_region_area. Requires opencv.
97
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
98
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
99
+ For large resolutions, 'binary_mask' may consume large amounts of
100
+ memory.
101
+ use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
102
+ multimask_output (bool): Whether to output multimask at each point of the grid.
103
+ """
104
+
105
+ assert (points_per_side is None) != (
106
+ point_grids is None
107
+ ), "Exactly one of points_per_side or point_grid must be provided."
108
+ if points_per_side is not None:
109
+ self.point_grids = build_all_layer_point_grids(
110
+ points_per_side,
111
+ crop_n_layers,
112
+ crop_n_points_downscale_factor,
113
+ )
114
+ elif point_grids is not None:
115
+ self.point_grids = point_grids
116
+ else:
117
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
118
+
119
+ assert output_mode in [
120
+ "binary_mask",
121
+ "uncompressed_rle",
122
+ "coco_rle",
123
+ ], f"Unknown output_mode {output_mode}."
124
+ if output_mode == "coco_rle":
125
+ try:
126
+ from pycocotools import mask as mask_utils # type: ignore # noqa: F401
127
+ except ImportError as e:
128
+ print("Please install pycocotools")
129
+ raise e
130
+
131
+ self.predictor = SAM2ImagePredictor(
132
+ model,
133
+ max_hole_area=min_mask_region_area,
134
+ max_sprinkle_area=min_mask_region_area,
135
+ )
136
+ self.points_per_batch = points_per_batch
137
+ self.pred_iou_thresh = pred_iou_thresh
138
+ self.stability_score_thresh = stability_score_thresh
139
+ self.stability_score_offset = stability_score_offset
140
+ self.mask_threshold = mask_threshold
141
+ self.box_nms_thresh = box_nms_thresh
142
+ self.crop_n_layers = crop_n_layers
143
+ self.crop_nms_thresh = crop_nms_thresh
144
+ self.crop_overlap_ratio = crop_overlap_ratio
145
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
146
+ self.min_mask_region_area = min_mask_region_area
147
+ self.output_mode = output_mode
148
+ self.use_m2m = use_m2m
149
+ self.multimask_output = multimask_output
150
+
151
+ @torch.no_grad()
152
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
153
+ """
154
+ Generates masks for the given image.
155
+
156
+ Arguments:
157
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
158
+
159
+ Returns:
160
+ list(dict(str, any)): A list over records for masks. Each record is
161
+ a dict containing the following keys:
162
+ segmentation (dict(str, any) or np.ndarray): The mask. If
163
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
164
+ is a dictionary containing the RLE.
165
+ bbox (list(float)): The box around the mask, in XYWH format.
166
+ area (int): The area in pixels of the mask.
167
+ predicted_iou (float): The model's own prediction of the mask's
168
+ quality. This is filtered by the pred_iou_thresh parameter.
169
+ point_coords (list(list(float))): The point coordinates input
170
+ to the model to generate this mask.
171
+ stability_score (float): A measure of the mask's quality. This
172
+ is filtered on using the stability_score_thresh parameter.
173
+ crop_box (list(float)): The crop of the image used to generate
174
+ the mask, given in XYWH format.
175
+ """
176
+
177
+ # Generate masks
178
+ mask_data = self._generate_masks(image)
179
+
180
+ # Encode masks
181
+ if self.output_mode == "coco_rle":
182
+ mask_data["segmentations"] = [
183
+ coco_encode_rle(rle) for rle in mask_data["rles"]
184
+ ]
185
+ elif self.output_mode == "binary_mask":
186
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
187
+ else:
188
+ mask_data["segmentations"] = mask_data["rles"]
189
+
190
+ # Write mask records
191
+ curr_anns = []
192
+ for idx in range(len(mask_data["segmentations"])):
193
+ ann = {
194
+ "segmentation": mask_data["segmentations"][idx],
195
+ "area": area_from_rle(mask_data["rles"][idx]),
196
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
197
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
198
+ "point_coords": [mask_data["points"][idx].tolist()],
199
+ "stability_score": mask_data["stability_score"][idx].item(),
200
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
201
+ }
202
+ curr_anns.append(ann)
203
+
204
+ return curr_anns
205
+
206
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
207
+ orig_size = image.shape[:2]
208
+ crop_boxes, layer_idxs = generate_crop_boxes(
209
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
210
+ )
211
+
212
+ # Iterate over image crops
213
+ data = MaskData()
214
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
215
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
216
+ data.cat(crop_data)
217
+
218
+ # Remove duplicate masks between crops
219
+ if len(crop_boxes) > 1:
220
+ # Prefer masks from smaller crops
221
+ scores = 1 / box_area(data["crop_boxes"])
222
+ scores = scores.to(data["boxes"].device)
223
+ keep_by_nms = batched_nms(
224
+ data["boxes"].float(),
225
+ scores,
226
+ torch.zeros_like(data["boxes"][:, 0]), # categories
227
+ iou_threshold=self.crop_nms_thresh,
228
+ )
229
+ data.filter(keep_by_nms)
230
+ data.to_numpy()
231
+ return data
232
+
233
+ def _process_crop(
234
+ self,
235
+ image: np.ndarray,
236
+ crop_box: List[int],
237
+ crop_layer_idx: int,
238
+ orig_size: Tuple[int, ...],
239
+ ) -> MaskData:
240
+ # Crop the image and calculate embeddings
241
+ x0, y0, x1, y1 = crop_box
242
+ cropped_im = image[y0:y1, x0:x1, :]
243
+ cropped_im_size = cropped_im.shape[:2]
244
+ self.predictor.set_image(cropped_im)
245
+
246
+ # Get points for this crop
247
+ points_scale = np.array(cropped_im_size)[None, ::-1]
248
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
249
+
250
+ # Generate masks for this crop in batches
251
+ data = MaskData()
252
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
253
+ batch_data = self._process_batch(
254
+ points, cropped_im_size, crop_box, orig_size, normalize=True
255
+ )
256
+ data.cat(batch_data)
257
+ del batch_data
258
+ self.predictor.reset_predictor()
259
+
260
+ # Remove duplicates within this crop.
261
+ keep_by_nms = batched_nms(
262
+ data["boxes"].float(),
263
+ data["iou_preds"],
264
+ torch.zeros_like(data["boxes"][:, 0]), # categories
265
+ iou_threshold=self.box_nms_thresh,
266
+ )
267
+ data.filter(keep_by_nms)
268
+
269
+ # Return to the original image frame
270
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
271
+ data["points"] = uncrop_points(data["points"], crop_box)
272
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
273
+
274
+ return data
275
+
276
+ def _process_batch(
277
+ self,
278
+ points: np.ndarray,
279
+ im_size: Tuple[int, ...],
280
+ crop_box: List[int],
281
+ orig_size: Tuple[int, ...],
282
+ normalize=False,
283
+ ) -> MaskData:
284
+ orig_h, orig_w = orig_size
285
+
286
+ # Run model on this batch
287
+ points = torch.as_tensor(points, device=self.predictor.device)
288
+ in_points = self.predictor._transforms.transform_coords(
289
+ points, normalize=normalize, orig_hw=im_size
290
+ )
291
+ in_labels = torch.ones(
292
+ in_points.shape[0], dtype=torch.int, device=in_points.device
293
+ )
294
+ masks, iou_preds, low_res_masks = self.predictor._predict(
295
+ in_points[:, None, :],
296
+ in_labels[:, None],
297
+ multimask_output=self.multimask_output,
298
+ return_logits=True,
299
+ )
300
+
301
+ # Serialize predictions and store in MaskData
302
+ data = MaskData(
303
+ masks=masks.flatten(0, 1),
304
+ iou_preds=iou_preds.flatten(0, 1),
305
+ points=points.repeat_interleave(masks.shape[1], dim=0),
306
+ low_res_masks=low_res_masks.flatten(0, 1),
307
+ )
308
+ del masks
309
+
310
+ if not self.use_m2m:
311
+ # Filter by predicted IoU
312
+ if self.pred_iou_thresh > 0.0:
313
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
314
+ data.filter(keep_mask)
315
+
316
+ # Calculate and filter by stability score
317
+ data["stability_score"] = calculate_stability_score(
318
+ data["masks"], self.mask_threshold, self.stability_score_offset
319
+ )
320
+ if self.stability_score_thresh > 0.0:
321
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
322
+ data.filter(keep_mask)
323
+ else:
324
+ # One step refinement using previous mask predictions
325
+ in_points = self.predictor._transforms.transform_coords(
326
+ data["points"], normalize=normalize, orig_hw=im_size
327
+ )
328
+ labels = torch.ones(
329
+ in_points.shape[0], dtype=torch.int, device=in_points.device
330
+ )
331
+ masks, ious = self.refine_with_m2m(
332
+ in_points, labels, data["low_res_masks"], self.points_per_batch
333
+ )
334
+ data["masks"] = masks.squeeze(1)
335
+ data["iou_preds"] = ious.squeeze(1)
336
+
337
+ if self.pred_iou_thresh > 0.0:
338
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
339
+ data.filter(keep_mask)
340
+
341
+ data["stability_score"] = calculate_stability_score(
342
+ data["masks"], self.mask_threshold, self.stability_score_offset
343
+ )
344
+ if self.stability_score_thresh > 0.0:
345
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
346
+ data.filter(keep_mask)
347
+
348
+ # Threshold masks and calculate boxes
349
+ data["masks"] = data["masks"] > self.mask_threshold
350
+ data["boxes"] = batched_mask_to_box(data["masks"])
351
+
352
+ # Filter boxes that touch crop boundaries
353
+ keep_mask = ~is_box_near_crop_edge(
354
+ data["boxes"], crop_box, [0, 0, orig_w, orig_h]
355
+ )
356
+ if not torch.all(keep_mask):
357
+ data.filter(keep_mask)
358
+
359
+ # Compress to RLE
360
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
361
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
362
+ del data["masks"]
363
+
364
+ return data
365
+
366
+ @staticmethod
367
+ def postprocess_small_regions(
368
+ mask_data: MaskData, min_area: int, nms_thresh: float
369
+ ) -> MaskData:
370
+ """
371
+ Removes small disconnected regions and holes in masks, then reruns
372
+ box NMS to remove any new duplicates.
373
+
374
+ Edits mask_data in place.
375
+
376
+ Requires open-cv as a dependency.
377
+ """
378
+ if len(mask_data["rles"]) == 0:
379
+ return mask_data
380
+
381
+ # Filter small disconnected regions and holes
382
+ new_masks = []
383
+ scores = []
384
+ for rle in mask_data["rles"]:
385
+ mask = rle_to_mask(rle)
386
+
387
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
388
+ unchanged = not changed
389
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
390
+ unchanged = unchanged and not changed
391
+
392
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
393
+ # Give score=0 to changed masks and score=1 to unchanged masks
394
+ # so NMS will prefer ones that didn't need postprocessing
395
+ scores.append(float(unchanged))
396
+
397
+ # Recalculate boxes and remove any new duplicates
398
+ masks = torch.cat(new_masks, dim=0)
399
+ boxes = batched_mask_to_box(masks)
400
+ keep_by_nms = batched_nms(
401
+ boxes.float(),
402
+ torch.as_tensor(scores),
403
+ torch.zeros_like(boxes[:, 0]), # categories
404
+ iou_threshold=nms_thresh,
405
+ )
406
+
407
+ # Only recalculate RLEs for masks that have changed
408
+ for i_mask in keep_by_nms:
409
+ if scores[i_mask] == 0.0:
410
+ mask_torch = masks[i_mask].unsqueeze(0)
411
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
412
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
413
+ mask_data.filter(keep_by_nms)
414
+
415
+ return mask_data
416
+
417
+ def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch):
418
+ new_masks = []
419
+ new_iou_preds = []
420
+
421
+ for cur_points, cur_point_labels, low_res_mask in batch_iterator(
422
+ points_per_batch, points, point_labels, low_res_masks
423
+ ):
424
+ best_masks, best_iou_preds, _ = self.predictor._predict(
425
+ cur_points[:, None, :],
426
+ cur_point_labels[:, None],
427
+ mask_input=low_res_mask[:, None, :],
428
+ multimask_output=False,
429
+ return_logits=True,
430
+ )
431
+ new_masks.append(best_masks)
432
+ new_iou_preds.append(best_iou_preds)
433
+ masks = torch.cat(new_masks, dim=0)
434
+ return masks, torch.cat(new_iou_preds, dim=0)
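The generator above is driven the same way as the original SAM automatic mask generator: build a model, wrap it, and call generate() on an HWC uint8 image. A hedged sketch under the same assumptions as before (variant name, checkpoint path, and image file are placeholders):

import numpy as np
from PIL import Image

from sam2 import load_model
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

model = load_model(variant="s", ckpt_path="checkpoints/efficienttam_s.pt")  # assumed names
mask_generator = SAM2AutomaticMaskGenerator(
    model,
    points_per_side=32,           # 32x32 grid of point prompts over the image
    pred_iou_thresh=0.8,          # drop masks the model itself scores as low quality
    stability_score_thresh=0.95,  # drop masks that are unstable under threshold shifts
    output_mode="binary_mask",
)

image = np.array(Image.open("example.jpg").convert("RGB"))  # HWC uint8
records = mask_generator.generate(image)
# Each record holds the mask plus its bbox (XYWH), area, predicted_iou,
# stability_score, the prompting point, and the crop box it came from.
print(len(records), records[0]["bbox"], records[0]["predicted_iou"])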
sam2/build_sam.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+
9
+ import torch
10
+ from hydra import compose
11
+ from hydra.utils import instantiate
12
+ from omegaconf import OmegaConf
13
+
14
+ from .utils.misc import VARIANTS, variant_to_config_mapping
15
+
16
+
17
+ def load_model(
18
+ variant: str,
19
+ ckpt_path=None,
20
+ device="cuda",
21
+ mode="eval",
22
+ hydra_overrides_extra=[],
23
+ apply_postprocessing=True,
24
+ ) -> torch.nn.Module:
25
+ assert variant in VARIANTS, f"only accepted variants are {VARIANTS}"
26
+
27
+ return build_sam2(
28
+ config_file=variant_to_config_mapping[variant],
29
+ ckpt_path=ckpt_path,
30
+ device=device,
31
+ mode=mode,
32
+ hydra_overrides_extra=hydra_overrides_extra,
33
+ apply_postprocessing=apply_postprocessing,
34
+ )
35
+
36
+
37
+ def build_sam2(
38
+ config_file,
39
+ ckpt_path=None,
40
+ device="cuda",
41
+ mode="eval",
42
+ hydra_overrides_extra=[],
43
+ apply_postprocessing=True,
44
+ ):
45
+
46
+ if apply_postprocessing:
47
+ hydra_overrides_extra = hydra_overrides_extra.copy()
48
+ hydra_overrides_extra += [
49
+ # dynamically fall back to multi-mask if the single mask is not stable
50
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
51
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
52
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
53
+ ]
54
+ # Read config and init model
55
+ cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
56
+ OmegaConf.resolve(cfg)
57
+ model = instantiate(cfg.model, _recursive_=True)
58
+ _load_checkpoint(model, ckpt_path)
59
+ model = model.to(device)
60
+ if mode == "eval":
61
+ model.eval()
62
+ return model
63
+
64
+
65
+ def build_sam2_video_predictor(
66
+ config_file,
67
+ ckpt_path=None,
68
+ device="cuda",
69
+ mode="eval",
70
+ hydra_overrides_extra=[],
71
+ apply_postprocessing=True,
72
+ ):
73
+ hydra_overrides = [
74
+ "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
75
+ ]
76
+ if apply_postprocessing:
77
+ hydra_overrides_extra = hydra_overrides_extra.copy()
78
+ hydra_overrides_extra += [
79
+ # dynamically fall back to multi-mask if the single mask is not stable
80
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
81
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
82
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
83
+ # binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
84
+ "++model.binarize_mask_from_pts_for_mem_enc=true",
85
+ # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
86
+ # "++model.fill_hole_area=8",
87
+ ]
88
+ hydra_overrides.extend(hydra_overrides_extra)
89
+
90
+ # Read config and init model
91
+ cfg = compose(config_name=config_file, overrides=hydra_overrides)
92
+ OmegaConf.resolve(cfg)
93
+ model = instantiate(cfg.model, _recursive_=True)
94
+ _load_checkpoint(model, ckpt_path)
95
+ model = model.to(device)
96
+ if mode == "eval":
97
+ model.eval()
98
+ return model
99
+
100
+
101
+ def _load_checkpoint(model, ckpt_path):
102
+ if ckpt_path is not None:
103
+ sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
104
+ missing_keys, unexpected_keys = model.load_state_dict(sd)
105
+ if missing_keys:
106
+ logging.error(missing_keys)
107
+ raise RuntimeError()
108
+ if unexpected_keys:
109
+ logging.error(unexpected_keys)
110
+ raise RuntimeError()
111
+ logging.info("Loaded checkpoint successfully")
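For the tracking use case, build_sam2_video_predictor swaps the model target to SAM2VideoPredictor and applies the same stability overrides. A sketch of how it might be called, assuming the config and checkpoint names below exist (compose resolves the config name against the "configs" directory registered in sam2/__init__.py):

import sam2  # runs hydra.initialize("configs", ...) so compose() can find the YAMLs
from sam2.build_sam import build_sam2_video_predictor

predictor = build_sam2_video_predictor(
    config_file="efficienttam_ti.yaml",          # one of the configs added in this commit
    ckpt_path="checkpoints/efficienttam_ti.pt",  # hypothetical local path
    device="cuda",
    apply_postprocessing=True,  # adds the dynamic multimask-stability overrides above
)

Note that _load_checkpoint expects the weights under a "model" key and raises RuntimeError on any missing or unexpected key, so the checkpoint must match the composed architecture exactly.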
sam2/configs/.DS_Store ADDED
Binary file (6.15 kB).
 
sam2/configs/__init__.py ADDED
@@ -0,0 +1,5 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
sam2/configs/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (181 Bytes).
 
sam2/configs/efficientam_s.yaml ADDED
@@ -0,0 +1,123 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 0
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.vitdet.ViT
11
+ patch_size: 16
12
+ embed_dim: 384
13
+ depth: 12
14
+ num_heads: 6
15
+ mlp_ratio: 4.0
16
+ qkv_bias: true
17
+ drop_path_rate: 0.0
18
+ use_rel_pos: false
19
+ window_size: 14
20
+ window_block_indexes: [0, 1, 3, 4, 6, 7, 9, 10]
21
+ neck:
22
+ _target_: sam2.modeling.backbones.image_encoder.ViTDetNeck
23
+ position_encoding:
24
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
25
+ num_pos_feats: 256
26
+ normalize: true
27
+ scale: null
28
+ temperature: 10000
29
+ d_model: 256
30
+ backbone_channel_list: [384,]
31
+ neck_norm: LN
32
+
33
+ memory_attention:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttention
35
+ d_model: 256
36
+ pos_enc_at_input: true
37
+ layer:
38
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
39
+ activation: relu
40
+ dim_feedforward: 2048
41
+ dropout: 0.1
42
+ pos_enc_at_attn: false
43
+ self_attention:
44
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
45
+ rope_theta: 10000.0
46
+ feat_sizes: [32, 32]
47
+ embedding_dim: 256
48
+ num_heads: 1
49
+ downsample_rate: 1
50
+ dropout: 0.1
51
+ d_model: 256
52
+ pos_enc_at_cross_attn_keys: true
53
+ pos_enc_at_cross_attn_queries: false
54
+ cross_attention:
55
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
56
+ rope_theta: 10000.0
57
+ feat_sizes: [32, 32]
58
+ rope_k_repeat: True
59
+ embedding_dim: 256
60
+ num_heads: 1
61
+ downsample_rate: 1
62
+ dropout: 0.1
63
+ kv_in_dim: 64
64
+ num_layers: 4
65
+
66
+ memory_encoder:
67
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
68
+ out_dim: 64
69
+ position_encoding:
70
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
71
+ num_pos_feats: 64
72
+ normalize: true
73
+ scale: null
74
+ temperature: 10000
75
+ mask_downsampler:
76
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
77
+ kernel_size: 3
78
+ stride: 2
79
+ padding: 1
80
+ fuser:
81
+ _target_: sam2.modeling.memory_encoder.Fuser
82
+ layer:
83
+ _target_: sam2.modeling.memory_encoder.CXBlock
84
+ dim: 256
85
+ kernel_size: 7
86
+ padding: 3
87
+ layer_scale_init_value: 1e-6
88
+ use_dwconv: True # depth-wise convs
89
+ num_layers: 2
90
+
91
+ num_maskmem: 7
92
+ image_size: 1024
93
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
94
+ # SAM decoder
95
+ sigmoid_scale_for_mem_enc: 20.0
96
+ sigmoid_bias_for_mem_enc: -10.0
97
+ use_mask_input_as_output_without_sam: true
98
+ # Memory
99
+ directly_add_no_mem_embed: true
100
+ # use high-resolution feature map in the SAM mask decoder
101
+ # use_high_res_features_in_sam: true
102
+ use_high_res_features_in_sam: false
103
+ # output 3 masks on the first click on initial conditioning frames
104
+ multimask_output_in_sam: true
105
+ # SAM heads
106
+ iou_prediction_use_sigmoid: True
107
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
108
+ use_obj_ptrs_in_encoder: true
109
+ add_tpos_enc_to_obj_ptrs: false
110
+ only_obj_ptrs_in_the_past_for_eval: true
111
+ # object occlusion prediction
112
+ pred_obj_scores: true
113
+ pred_obj_scores_mlp: true
114
+ fixed_no_obj_ptr: true
115
+ # multimask tracking settings
116
+ multimask_output_for_tracking: true
117
+ use_multimask_token_for_obj_ptr: true
118
+ multimask_min_pt_num: 0
119
+ multimask_max_pt_num: 1
120
+ use_mlp_for_obj_ptr_proj: true
121
+ # Compilation flag
122
+ # HieraT does not currently support compilation, should always be set to False
123
+ compile_image_encoder: false
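Since the YAML above is a self-contained Hydra spec (the "# @package _global_" directive lifts "model" to the top level), it can also be composed and instantiated by hand, which is what build_sam2 does internally. A minimal sketch; the override mirrors the apply_postprocessing branch in build_sam.py, and the instantiated weights are random until a checkpoint is loaded:

from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf

import sam2  # registers the "configs" search path via hydra.initialize

cfg = compose(
    config_name="efficientam_s.yaml",
    overrides=["++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true"],
)
OmegaConf.resolve(cfg)
model = instantiate(cfg.model, _recursive_=True).eval()
print(sum(p.numel() for p in model.parameters()))  # parameter count of the composed SAM2Base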
sam2/configs/efficienttam_s_1.yaml ADDED
@@ -0,0 +1,123 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 0
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.vitdet.ViT
11
+ patch_size: 16
12
+ embed_dim: 384
13
+ depth: 12
14
+ num_heads: 6
15
+ mlp_ratio: 4.0
16
+ qkv_bias: true
17
+ drop_path_rate: 0.0
18
+ use_rel_pos: false
19
+ window_size: 14
20
+ window_block_indexes: [0, 1, 3, 4, 6, 7, 9, 10]
21
+ neck:
22
+ _target_: sam2.modeling.backbones.image_encoder.ViTDetNeck
23
+ position_encoding:
24
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
25
+ num_pos_feats: 256
26
+ normalize: true
27
+ scale: null
28
+ temperature: 10000
29
+ d_model: 256
30
+ backbone_channel_list: [384,]
31
+ neck_norm: LN
32
+
33
+ memory_attention:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttention
35
+ d_model: 256
36
+ pos_enc_at_input: true
37
+ layer:
38
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
39
+ activation: relu
40
+ dim_feedforward: 2048
41
+ dropout: 0.1
42
+ pos_enc_at_attn: false
43
+ self_attention:
44
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
45
+ rope_theta: 10000.0
46
+ feat_sizes: [32, 32]
47
+ embedding_dim: 256
48
+ num_heads: 1
49
+ downsample_rate: 1
50
+ dropout: 0.1
51
+ d_model: 256
52
+ pos_enc_at_cross_attn_keys: true
53
+ pos_enc_at_cross_attn_queries: false
54
+ cross_attention:
55
+ _target_: sam2.modeling.sam.transformer.EfficientRoPEAttention1
56
+ rope_theta: 10000.0
57
+ feat_sizes: [32, 32]
58
+ rope_k_repeat: True
59
+ embedding_dim: 256
60
+ num_heads: 1
61
+ downsample_rate: 1
62
+ dropout: 0.1
63
+ kv_in_dim: 64
64
+ num_layers: 4
65
+
66
+ memory_encoder:
67
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
68
+ out_dim: 64
69
+ position_encoding:
70
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
71
+ num_pos_feats: 64
72
+ normalize: true
73
+ scale: null
74
+ temperature: 10000
75
+ mask_downsampler:
76
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
77
+ kernel_size: 3
78
+ stride: 2
79
+ padding: 1
80
+ fuser:
81
+ _target_: sam2.modeling.memory_encoder.Fuser
82
+ layer:
83
+ _target_: sam2.modeling.memory_encoder.CXBlock
84
+ dim: 256
85
+ kernel_size: 7
86
+ padding: 3
87
+ layer_scale_init_value: 1e-6
88
+ use_dwconv: True # depth-wise convs
89
+ num_layers: 2
90
+
91
+ num_maskmem: 7
92
+ image_size: 1024
93
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
94
+ # SAM decoder
95
+ sigmoid_scale_for_mem_enc: 20.0
96
+ sigmoid_bias_for_mem_enc: -10.0
97
+ use_mask_input_as_output_without_sam: true
98
+ # Memory
99
+ directly_add_no_mem_embed: true
100
+ # use high-resolution feature map in the SAM mask decoder
101
+ # use_high_res_features_in_sam: true
102
+ use_high_res_features_in_sam: false
103
+ # output 3 masks on the first click on initial conditioning frames
104
+ multimask_output_in_sam: true
105
+ # SAM heads
106
+ iou_prediction_use_sigmoid: True
107
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
108
+ use_obj_ptrs_in_encoder: true
109
+ add_tpos_enc_to_obj_ptrs: false
110
+ only_obj_ptrs_in_the_past_for_eval: true
111
+ # object occlusion prediction
112
+ pred_obj_scores: true
113
+ pred_obj_scores_mlp: true
114
+ fixed_no_obj_ptr: true
115
+ # multimask tracking settings
116
+ multimask_output_for_tracking: true
117
+ use_multimask_token_for_obj_ptr: true
118
+ multimask_min_pt_num: 0
119
+ multimask_max_pt_num: 1
120
+ use_mlp_for_obj_ptr_proj: true
121
+ # Compilation flag
122
+ # HieraT does not currently support compilation, should always be set to False
123
+ compile_image_encoder: false
sam2/configs/efficienttam_s_2.yaml ADDED
@@ -0,0 +1,123 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 0
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.vitdet.ViT
11
+ patch_size: 16
12
+ embed_dim: 384
13
+ depth: 12
14
+ num_heads: 6
15
+ mlp_ratio: 4.0
16
+ qkv_bias: true
17
+ drop_path_rate: 0.0
18
+ use_rel_pos: false
19
+ window_size: 14
20
+ window_block_indexes: [0, 1, 3, 4, 6, 7, 9, 10]
21
+ neck:
22
+ _target_: sam2.modeling.backbones.image_encoder.ViTDetNeck
23
+ position_encoding:
24
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
25
+ num_pos_feats: 256
26
+ normalize: true
27
+ scale: null
28
+ temperature: 10000
29
+ d_model: 256
30
+ backbone_channel_list: [384,]
31
+ neck_norm: LN
32
+
33
+ memory_attention:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttention
35
+ d_model: 256
36
+ pos_enc_at_input: true
37
+ layer:
38
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
39
+ activation: relu
40
+ dim_feedforward: 2048
41
+ dropout: 0.1
42
+ pos_enc_at_attn: false
43
+ self_attention:
44
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
45
+ rope_theta: 10000.0
46
+ feat_sizes: [32, 32]
47
+ embedding_dim: 256
48
+ num_heads: 1
49
+ downsample_rate: 1
50
+ dropout: 0.1
51
+ d_model: 256
52
+ pos_enc_at_cross_attn_keys: true
53
+ pos_enc_at_cross_attn_queries: false
54
+ cross_attention:
55
+ _target_: sam2.modeling.sam.transformer.EfficientRoPEAttention2
56
+ rope_theta: 10000.0
57
+ feat_sizes: [32, 32]
58
+ rope_k_repeat: True
59
+ embedding_dim: 256
60
+ num_heads: 1
61
+ downsample_rate: 1
62
+ dropout: 0.1
63
+ kv_in_dim: 64
64
+ num_layers: 4
65
+
66
+ memory_encoder:
67
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
68
+ out_dim: 64
69
+ position_encoding:
70
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
71
+ num_pos_feats: 64
72
+ normalize: true
73
+ scale: null
74
+ temperature: 10000
75
+ mask_downsampler:
76
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
77
+ kernel_size: 3
78
+ stride: 2
79
+ padding: 1
80
+ fuser:
81
+ _target_: sam2.modeling.memory_encoder.Fuser
82
+ layer:
83
+ _target_: sam2.modeling.memory_encoder.CXBlock
84
+ dim: 256
85
+ kernel_size: 7
86
+ padding: 3
87
+ layer_scale_init_value: 1e-6
88
+ use_dwconv: True # depth-wise convs
89
+ num_layers: 2
90
+
91
+ num_maskmem: 7
92
+ image_size: 1024
93
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
94
+ # SAM decoder
95
+ sigmoid_scale_for_mem_enc: 20.0
96
+ sigmoid_bias_for_mem_enc: -10.0
97
+ use_mask_input_as_output_without_sam: true
98
+ # Memory
99
+ directly_add_no_mem_embed: true
100
+ # use high-resolution feature map in the SAM mask decoder
101
+ # use_high_res_features_in_sam: true
102
+ use_high_res_features_in_sam: false
103
+ # output 3 masks on the first click on initial conditioning frames
104
+ multimask_output_in_sam: true
105
+ # SAM heads
106
+ iou_prediction_use_sigmoid: True
107
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
108
+ use_obj_ptrs_in_encoder: true
109
+ add_tpos_enc_to_obj_ptrs: false
110
+ only_obj_ptrs_in_the_past_for_eval: true
111
+ # object occlusion prediction
112
+ pred_obj_scores: true
113
+ pred_obj_scores_mlp: true
114
+ fixed_no_obj_ptr: true
115
+ # multimask tracking settings
116
+ multimask_output_for_tracking: true
117
+ use_multimask_token_for_obj_ptr: true
118
+ multimask_min_pt_num: 0
119
+ multimask_max_pt_num: 1
120
+ use_mlp_for_obj_ptr_proj: true
121
+ # Compilation flag
122
+ # HieraT does not currently support compilation, should always be set to False
123
+ compile_image_encoder: false
sam2/configs/efficienttam_s_512x512.yaml ADDED
@@ -0,0 +1,123 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 0
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.vitdet.ViT
11
+ patch_size: 16
12
+ embed_dim: 384
13
+ depth: 12
14
+ num_heads: 6
15
+ mlp_ratio: 4.0
16
+ qkv_bias: true
17
+ drop_path_rate: 0.0
18
+ use_rel_pos: false
19
+ window_size: 14
20
+ window_block_indexes: [0, 1, 3, 4, 6, 7, 9, 10]
21
+ neck:
22
+ _target_: sam2.modeling.backbones.image_encoder.ViTDetNeck
23
+ position_encoding:
24
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
25
+ num_pos_feats: 256
26
+ normalize: true
27
+ scale: null
28
+ temperature: 10000
29
+ d_model: 256
30
+ backbone_channel_list: [384,]
31
+ neck_norm: LN
32
+
33
+ memory_attention:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttention
35
+ d_model: 256
36
+ pos_enc_at_input: true
37
+ layer:
38
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
39
+ activation: relu
40
+ dim_feedforward: 2048
41
+ dropout: 0.1
42
+ pos_enc_at_attn: false
43
+ self_attention:
44
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
45
+ rope_theta: 10000.0
46
+ feat_sizes: [32, 32]
47
+ embedding_dim: 256
48
+ num_heads: 1
49
+ downsample_rate: 1
50
+ dropout: 0.1
51
+ d_model: 256
52
+ pos_enc_at_cross_attn_keys: true
53
+ pos_enc_at_cross_attn_queries: false
54
+ cross_attention:
55
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
56
+ rope_theta: 10000.0
57
+ feat_sizes: [32, 32]
58
+ rope_k_repeat: True
59
+ embedding_dim: 256
60
+ num_heads: 1
61
+ downsample_rate: 1
62
+ dropout: 0.1
63
+ kv_in_dim: 64
64
+ num_layers: 4
65
+
66
+ memory_encoder:
67
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
68
+ out_dim: 64
69
+ position_encoding:
70
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
71
+ num_pos_feats: 64
72
+ normalize: true
73
+ scale: null
74
+ temperature: 10000
75
+ mask_downsampler:
76
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
77
+ kernel_size: 3
78
+ stride: 2
79
+ padding: 1
80
+ fuser:
81
+ _target_: sam2.modeling.memory_encoder.Fuser
82
+ layer:
83
+ _target_: sam2.modeling.memory_encoder.CXBlock
84
+ dim: 256
85
+ kernel_size: 7
86
+ padding: 3
87
+ layer_scale_init_value: 1e-6
88
+ use_dwconv: True # depth-wise convs
89
+ num_layers: 2
90
+
91
+ num_maskmem: 7
92
+ image_size: 512
93
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
94
+ # SAM decoder
95
+ sigmoid_scale_for_mem_enc: 20.0
96
+ sigmoid_bias_for_mem_enc: -10.0
97
+ use_mask_input_as_output_without_sam: true
98
+ # Memory
99
+ directly_add_no_mem_embed: true
100
+ # use high-resolution feature map in the SAM mask decoder
101
+ # use_high_res_features_in_sam: true
102
+ use_high_res_features_in_sam: false
103
+ # output 3 masks on the first click on initial conditioning frames
104
+ multimask_output_in_sam: true
105
+ # SAM heads
106
+ iou_prediction_use_sigmoid: True
107
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
108
+ use_obj_ptrs_in_encoder: true
109
+ add_tpos_enc_to_obj_ptrs: false
110
+ only_obj_ptrs_in_the_past_for_eval: true
111
+ # object occlusion prediction
112
+ pred_obj_scores: true
113
+ pred_obj_scores_mlp: true
114
+ fixed_no_obj_ptr: true
115
+ # multimask tracking settings
116
+ multimask_output_for_tracking: true
117
+ use_multimask_token_for_obj_ptr: true
118
+ multimask_min_pt_num: 0
119
+ multimask_max_pt_num: 1
120
+ use_mlp_for_obj_ptr_proj: true
121
+ # Compilation flag
122
+ # HieraT does not currently support compilation, should always be set to False
123
+ compile_image_encoder: false
sam2/configs/efficienttam_ti.yaml ADDED
@@ -0,0 +1,123 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 0
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.vitdet.ViT
11
+ patch_size: 16
12
+ embed_dim: 192
13
+ depth: 12
14
+ num_heads: 3
15
+ mlp_ratio: 4.0
16
+ qkv_bias: true
17
+ drop_path_rate: 0.0
18
+ use_rel_pos: false
19
+ window_size: 14
20
+ window_block_indexes: [0, 1, 3, 4, 6, 7, 9, 10]
21
+ neck:
22
+ _target_: sam2.modeling.backbones.image_encoder.ViTDetNeck
23
+ position_encoding:
24
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
25
+ num_pos_feats: 256
26
+ normalize: true
27
+ scale: null
28
+ temperature: 10000
29
+ d_model: 256
30
+ backbone_channel_list: [192,]
31
+ neck_norm: LN
32
+
33
+ memory_attention:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttention
35
+ d_model: 256
36
+ pos_enc_at_input: true
37
+ layer:
38
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
39
+ activation: relu
40
+ dim_feedforward: 2048
41
+ dropout: 0.1
42
+ pos_enc_at_attn: false
43
+ self_attention:
44
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
45
+ rope_theta: 10000.0
46
+ feat_sizes: [32, 32]
47
+ embedding_dim: 256
48
+ num_heads: 1
49
+ downsample_rate: 1
50
+ dropout: 0.1
51
+ d_model: 256
52
+ pos_enc_at_cross_attn_keys: true
53
+ pos_enc_at_cross_attn_queries: false
54
+ cross_attention:
55
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
56
+ rope_theta: 10000.0
57
+ feat_sizes: [32, 32]
58
+ rope_k_repeat: True
59
+ embedding_dim: 256
60
+ num_heads: 1
61
+ downsample_rate: 1
62
+ dropout: 0.1
63
+ kv_in_dim: 64
64
+ num_layers: 4
65
+
66
+ memory_encoder:
67
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
68
+ out_dim: 64
69
+ position_encoding:
70
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
71
+ num_pos_feats: 64
72
+ normalize: true
73
+ scale: null
74
+ temperature: 10000
75
+ mask_downsampler:
76
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
77
+ kernel_size: 3
78
+ stride: 2
79
+ padding: 1
80
+ fuser:
81
+ _target_: sam2.modeling.memory_encoder.Fuser
82
+ layer:
83
+ _target_: sam2.modeling.memory_encoder.CXBlock
84
+ dim: 256
85
+ kernel_size: 7
86
+ padding: 3
87
+ layer_scale_init_value: 1e-6
88
+ use_dwconv: True # depth-wise convs
89
+ num_layers: 2
90
+
91
+ num_maskmem: 7
92
+ image_size: 1024
93
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
94
+ # SAM decoder
95
+ sigmoid_scale_for_mem_enc: 20.0
96
+ sigmoid_bias_for_mem_enc: -10.0
97
+ use_mask_input_as_output_without_sam: true
98
+ # Memory
99
+ directly_add_no_mem_embed: true
100
+ # use high-resolution feature map in the SAM mask decoder
101
+ # use_high_res_features_in_sam: true
102
+ use_high_res_features_in_sam: false
103
+ # output 3 masks on the first click on initial conditioning frames
104
+ multimask_output_in_sam: true
105
+ # SAM heads
106
+ iou_prediction_use_sigmoid: True
107
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
108
+ use_obj_ptrs_in_encoder: true
109
+ add_tpos_enc_to_obj_ptrs: false
110
+ only_obj_ptrs_in_the_past_for_eval: true
111
+ # object occlusion prediction
112
+ pred_obj_scores: true
113
+ pred_obj_scores_mlp: true
114
+ fixed_no_obj_ptr: true
115
+ # multimask tracking settings
116
+ multimask_output_for_tracking: true
117
+ use_multimask_token_for_obj_ptr: true
118
+ multimask_min_pt_num: 0
119
+ multimask_max_pt_num: 1
120
+ use_mlp_for_obj_ptr_proj: true
121
+ # Compilation flag
122
+ # HieraT does not currently support compilation, should always be set to False
123
+ compile_image_encoder: false
sam2/configs/efficienttam_ti_1.yaml ADDED
@@ -0,0 +1,123 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 0
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.vitdet.ViT
11
+ patch_size: 16
12
+ embed_dim: 192
13
+ depth: 12
14
+ num_heads: 3
15
+ mlp_ratio: 4.0
16
+ qkv_bias: true
17
+ drop_path_rate: 0.0
18
+ use_rel_pos: false
19
+ window_size: 14
20
+ window_block_indexes: [0, 1, 3, 4, 6, 7, 9, 10]
21
+ neck:
22
+ _target_: sam2.modeling.backbones.image_encoder.ViTDetNeck
23
+ position_encoding:
24
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
25
+ num_pos_feats: 256
26
+ normalize: true
27
+ scale: null
28
+ temperature: 10000
29
+ d_model: 256
30
+ backbone_channel_list: [192,]
31
+ neck_norm: LN
32
+
33
+ memory_attention:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttention
35
+ d_model: 256
36
+ pos_enc_at_input: true
37
+ layer:
38
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
39
+ activation: relu
40
+ dim_feedforward: 2048
41
+ dropout: 0.1
42
+ pos_enc_at_attn: false
43
+ self_attention:
44
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
45
+ rope_theta: 10000.0
46
+ feat_sizes: [32, 32]
47
+ embedding_dim: 256
48
+ num_heads: 1
49
+ downsample_rate: 1
50
+ dropout: 0.1
51
+ d_model: 256
52
+ pos_enc_at_cross_attn_keys: true
53
+ pos_enc_at_cross_attn_queries: false
54
+ cross_attention:
55
+ _target_: sam2.modeling.sam.transformer.EfficientRoPEAttention1
56
+ rope_theta: 10000.0
57
+ feat_sizes: [32, 32]
58
+ rope_k_repeat: True
59
+ embedding_dim: 256
60
+ num_heads: 1
61
+ downsample_rate: 1
62
+ dropout: 0.1
63
+ kv_in_dim: 64
64
+ num_layers: 4
65
+
66
+ memory_encoder:
67
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
68
+ out_dim: 64
69
+ position_encoding:
70
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
71
+ num_pos_feats: 64
72
+ normalize: true
73
+ scale: null
74
+ temperature: 10000
75
+ mask_downsampler:
76
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
77
+ kernel_size: 3
78
+ stride: 2
79
+ padding: 1
80
+ fuser:
81
+ _target_: sam2.modeling.memory_encoder.Fuser
82
+ layer:
83
+ _target_: sam2.modeling.memory_encoder.CXBlock
84
+ dim: 256
85
+ kernel_size: 7
86
+ padding: 3
87
+ layer_scale_init_value: 1e-6
88
+ use_dwconv: True # depth-wise convs
89
+ num_layers: 2
90
+
91
+ num_maskmem: 7
92
+ image_size: 1024
93
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
94
+ # SAM decoder
95
+ sigmoid_scale_for_mem_enc: 20.0
96
+ sigmoid_bias_for_mem_enc: -10.0
97
+ use_mask_input_as_output_without_sam: true
98
+ # Memory
99
+ directly_add_no_mem_embed: true
100
+ # use high-resolution feature map in the SAM mask decoder
101
+ # use_high_res_features_in_sam: true
102
+ use_high_res_features_in_sam: false
103
+ # output 3 masks on the first click on initial conditioning frames
104
+ multimask_output_in_sam: true
105
+ # SAM heads
106
+ iou_prediction_use_sigmoid: True
107
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
108
+ use_obj_ptrs_in_encoder: true
109
+ add_tpos_enc_to_obj_ptrs: false
110
+ only_obj_ptrs_in_the_past_for_eval: true
111
+ # object occlusion prediction
112
+ pred_obj_scores: true
113
+ pred_obj_scores_mlp: true
114
+ fixed_no_obj_ptr: true
115
+ # multimask tracking settings
116
+ multimask_output_for_tracking: true
117
+ use_multimask_token_for_obj_ptr: true
118
+ multimask_min_pt_num: 0
119
+ multimask_max_pt_num: 1
120
+ use_mlp_for_obj_ptr_proj: true
121
+ # Compilation flag
122
+ # HieraT does not currently support compilation, should always be set to False
123
+ compile_image_encoder: false
sam2/configs/efficienttam_ti_2.yaml ADDED
@@ -0,0 +1,123 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 0
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.vitdet.ViT
11
+ patch_size: 16
12
+ embed_dim: 192
13
+ depth: 12
14
+ num_heads: 3
15
+ mlp_ratio: 4.0
16
+ qkv_bias: true
17
+ drop_path_rate: 0.0
18
+ use_rel_pos: false
19
+ window_size: 14
20
+ window_block_indexes: [0, 1, 3, 4, 6, 7, 9, 10]
21
+ neck:
22
+ _target_: sam2.modeling.backbones.image_encoder.ViTDetNeck
23
+ position_encoding:
24
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
25
+ num_pos_feats: 256
26
+ normalize: true
27
+ scale: null
28
+ temperature: 10000
29
+ d_model: 256
30
+ backbone_channel_list: [192,]
31
+ neck_norm: LN
32
+
33
+ memory_attention:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttention
35
+ d_model: 256
36
+ pos_enc_at_input: true
37
+ layer:
38
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
39
+ activation: relu
40
+ dim_feedforward: 2048
41
+ dropout: 0.1
42
+ pos_enc_at_attn: false
43
+ self_attention:
44
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
45
+ rope_theta: 10000.0
46
+ feat_sizes: [32, 32]
47
+ embedding_dim: 256
48
+ num_heads: 1
49
+ downsample_rate: 1
50
+ dropout: 0.1
51
+ d_model: 256
52
+ pos_enc_at_cross_attn_keys: true
53
+ pos_enc_at_cross_attn_queries: false
54
+ cross_attention:
55
+ _target_: sam2.modeling.sam.transformer.EfficientRoPEAttention2
56
+ rope_theta: 10000.0
57
+ feat_sizes: [32, 32]
58
+ rope_k_repeat: True
59
+ embedding_dim: 256
60
+ num_heads: 1
61
+ downsample_rate: 1
62
+ dropout: 0.1
63
+ kv_in_dim: 64
64
+ num_layers: 4
65
+
66
+ memory_encoder:
67
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
68
+ out_dim: 64
69
+ position_encoding:
70
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
71
+ num_pos_feats: 64
72
+ normalize: true
73
+ scale: null
74
+ temperature: 10000
75
+ mask_downsampler:
76
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
77
+ kernel_size: 3
78
+ stride: 2
79
+ padding: 1
80
+ fuser:
81
+ _target_: sam2.modeling.memory_encoder.Fuser
82
+ layer:
83
+ _target_: sam2.modeling.memory_encoder.CXBlock
84
+ dim: 256
85
+ kernel_size: 7
86
+ padding: 3
87
+ layer_scale_init_value: 1e-6
88
+ use_dwconv: True # depth-wise convs
89
+ num_layers: 2
90
+
91
+ num_maskmem: 7
92
+ image_size: 1024
93
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
94
+ # SAM decoder
95
+ sigmoid_scale_for_mem_enc: 20.0
96
+ sigmoid_bias_for_mem_enc: -10.0
97
+ use_mask_input_as_output_without_sam: true
98
+ # Memory
99
+ directly_add_no_mem_embed: true
100
+ # use high-resolution feature map in the SAM mask decoder
101
+ # use_high_res_features_in_sam: true
102
+ use_high_res_features_in_sam: false
103
+ # output 3 masks on the first click on initial conditioning frames
104
+ multimask_output_in_sam: true
105
+ # SAM heads
106
+ iou_prediction_use_sigmoid: True
107
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
108
+ use_obj_ptrs_in_encoder: true
109
+ add_tpos_enc_to_obj_ptrs: false
110
+ only_obj_ptrs_in_the_past_for_eval: true
111
+ # object occlusion prediction
112
+ pred_obj_scores: true
113
+ pred_obj_scores_mlp: true
114
+ fixed_no_obj_ptr: true
115
+ # multimask tracking settings
116
+ multimask_output_for_tracking: true
117
+ use_multimask_token_for_obj_ptr: true
118
+ multimask_min_pt_num: 0
119
+ multimask_max_pt_num: 1
120
+ use_mlp_for_obj_ptr_proj: true
121
+ # Compilation flag
122
+ # HieraT does not currently support compilation, should always be set to False
123
+ compile_image_encoder: false
sam2/configs/efficienttam_ti_512x512.yaml ADDED
@@ -0,0 +1,123 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 0
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.vitdet.ViT
11
+ patch_size: 16
12
+ embed_dim: 192
13
+ depth: 12
14
+ num_heads: 3
15
+ mlp_ratio: 4.0
16
+ qkv_bias: true
17
+ drop_path_rate: 0.0
18
+ use_rel_pos: false
19
+ window_size: 14
20
+ window_block_indexes: [0, 1, 3, 4, 6, 7, 9, 10]
21
+ neck:
22
+ _target_: sam2.modeling.backbones.image_encoder.ViTDetNeck
23
+ position_encoding:
24
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
25
+ num_pos_feats: 256
26
+ normalize: true
27
+ scale: null
28
+ temperature: 10000
29
+ d_model: 256
30
+ backbone_channel_list: [192,]
31
+ neck_norm: LN
32
+
33
+ memory_attention:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttention
35
+ d_model: 256
36
+ pos_enc_at_input: true
37
+ layer:
38
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
39
+ activation: relu
40
+ dim_feedforward: 2048
41
+ dropout: 0.1
42
+ pos_enc_at_attn: false
43
+ self_attention:
44
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
45
+ rope_theta: 10000.0
46
+ feat_sizes: [32, 32]
47
+ embedding_dim: 256
48
+ num_heads: 1
49
+ downsample_rate: 1
50
+ dropout: 0.1
51
+ d_model: 256
52
+ pos_enc_at_cross_attn_keys: true
53
+ pos_enc_at_cross_attn_queries: false
54
+ cross_attention:
55
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
56
+ rope_theta: 10000.0
57
+ feat_sizes: [32, 32]
58
+ rope_k_repeat: True
59
+ embedding_dim: 256
60
+ num_heads: 1
61
+ downsample_rate: 1
62
+ dropout: 0.1
63
+ kv_in_dim: 64
64
+ num_layers: 4
65
+
66
+ memory_encoder:
67
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
68
+ out_dim: 64
69
+ position_encoding:
70
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
71
+ num_pos_feats: 64
72
+ normalize: true
73
+ scale: null
74
+ temperature: 10000
75
+ mask_downsampler:
76
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
77
+ kernel_size: 3
78
+ stride: 2
79
+ padding: 1
80
+ fuser:
81
+ _target_: sam2.modeling.memory_encoder.Fuser
82
+ layer:
83
+ _target_: sam2.modeling.memory_encoder.CXBlock
84
+ dim: 256
85
+ kernel_size: 7
86
+ padding: 3
87
+ layer_scale_init_value: 1e-6
88
+ use_dwconv: True # depth-wise convs
89
+ num_layers: 2
90
+
91
+ num_maskmem: 7
92
+ image_size: 512
93
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
94
+ # SAM decoder
95
+ sigmoid_scale_for_mem_enc: 20.0
96
+ sigmoid_bias_for_mem_enc: -10.0
97
+ use_mask_input_as_output_without_sam: true
98
+ # Memory
99
+ directly_add_no_mem_embed: true
100
+ # use high-resolution feature map in the SAM mask decoder
101
+ # use_high_res_features_in_sam: true
102
+ use_high_res_features_in_sam: false
103
+ # output 3 masks on the first click on initial conditioning frames
104
+ multimask_output_in_sam: true
105
+ # SAM heads
106
+ iou_prediction_use_sigmoid: True
107
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
108
+ use_obj_ptrs_in_encoder: true
109
+ add_tpos_enc_to_obj_ptrs: false
110
+ only_obj_ptrs_in_the_past_for_eval: true
111
+ # object occlusion prediction
112
+ pred_obj_scores: true
113
+ pred_obj_scores_mlp: true
114
+ fixed_no_obj_ptr: true
115
+ # multimask tracking settings
116
+ multimask_output_for_tracking: true
117
+ use_multimask_token_for_obj_ptr: true
118
+ multimask_min_pt_num: 0
119
+ multimask_max_pt_num: 1
120
+ use_mlp_for_obj_ptr_proj: true
121
+ # Compilation flag
122
+ # HieraT does not currently support compilation, should always be set to False
123
+ compile_image_encoder: false
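The config above is a plain Hydra tree: every block with a `_target_` can be instantiated recursively. A minimal sketch of building the pieces it describes, assuming Hydra has already been initialized to point at sam2/configs; the compose call and variable names are illustrative, not a documented entry point of this repository.

import torch
from hydra import compose
from hydra.utils import instantiate

# Resolve the _target_ tree of efficienttam_ti_512x512.yaml.
cfg = compose(config_name="efficienttam_ti_512x512")
model = instantiate(cfg.model, _recursive_=True)      # full SAM2Base
encoder = instantiate(cfg.model.image_encoder)        # or just the ViT-Tiny image encoder

with torch.inference_mode():
    out = encoder(torch.zeros(1, 3, 512, 512))        # image_size: 512 in this config
print(out["vision_features"].shape)                   # 512 / patch 16 = 32 -> (1, 256, 32, 32)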
sam2/configs/sam2_hiera_b+.yaml ADDED
@@ -0,0 +1,113 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 112
12
+ num_heads: 2
13
+ neck:
14
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
15
+ position_encoding:
16
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
17
+ num_pos_feats: 256
18
+ normalize: true
19
+ scale: null
20
+ temperature: 10000
21
+ d_model: 256
22
+ backbone_channel_list: [896, 448, 224, 112]
23
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
24
+ fpn_interp_model: nearest
25
+
26
+ memory_attention:
27
+ _target_: sam2.modeling.memory_attention.MemoryAttention
28
+ d_model: 256
29
+ pos_enc_at_input: true
30
+ layer:
31
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
32
+ activation: relu
33
+ dim_feedforward: 2048
34
+ dropout: 0.1
35
+ pos_enc_at_attn: false
36
+ self_attention:
37
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
38
+ rope_theta: 10000.0
39
+ feat_sizes: [32, 32]
40
+ embedding_dim: 256
41
+ num_heads: 1
42
+ downsample_rate: 1
43
+ dropout: 0.1
44
+ d_model: 256
45
+ pos_enc_at_cross_attn_keys: true
46
+ pos_enc_at_cross_attn_queries: false
47
+ cross_attention:
48
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
49
+ rope_theta: 10000.0
50
+ feat_sizes: [32, 32]
51
+ rope_k_repeat: True
52
+ embedding_dim: 256
53
+ num_heads: 1
54
+ downsample_rate: 1
55
+ dropout: 0.1
56
+ kv_in_dim: 64
57
+ num_layers: 4
58
+
59
+ memory_encoder:
60
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
61
+ out_dim: 64
62
+ position_encoding:
63
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
64
+ num_pos_feats: 64
65
+ normalize: true
66
+ scale: null
67
+ temperature: 10000
68
+ mask_downsampler:
69
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
70
+ kernel_size: 3
71
+ stride: 2
72
+ padding: 1
73
+ fuser:
74
+ _target_: sam2.modeling.memory_encoder.Fuser
75
+ layer:
76
+ _target_: sam2.modeling.memory_encoder.CXBlock
77
+ dim: 256
78
+ kernel_size: 7
79
+ padding: 3
80
+ layer_scale_init_value: 1e-6
81
+ use_dwconv: True # depth-wise convs
82
+ num_layers: 2
83
+
84
+ num_maskmem: 7
85
+ image_size: 1024
86
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
87
+ sigmoid_scale_for_mem_enc: 20.0
88
+ sigmoid_bias_for_mem_enc: -10.0
89
+ use_mask_input_as_output_without_sam: true
90
+ # Memory
91
+ directly_add_no_mem_embed: true
92
+ # use high-resolution feature map in the SAM mask decoder
93
+ use_high_res_features_in_sam: true
94
+ # output 3 masks on the first click on initial conditioning frames
95
+ multimask_output_in_sam: true
96
+ # SAM heads
97
+ iou_prediction_use_sigmoid: True
98
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
99
+ use_obj_ptrs_in_encoder: true
100
+ add_tpos_enc_to_obj_ptrs: false
101
+ only_obj_ptrs_in_the_past_for_eval: true
102
+ # object occlusion prediction
103
+ pred_obj_scores: true
104
+ pred_obj_scores_mlp: true
105
+ fixed_no_obj_ptr: true
106
+ # multimask tracking settings
107
+ multimask_output_for_tracking: true
108
+ use_multimask_token_for_obj_ptr: true
109
+ multimask_min_pt_num: 0
110
+ multimask_max_pt_num: 1
111
+ use_mlp_for_obj_ptr_proj: true
112
+ # Compilation flag
113
+ compile_image_encoder: False
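One detail worth noting in this config: backbone_channel_list is not a free choice. The Hiera trunk starts at embed_dim 112 and doubles its width at each of the three stage transitions (dim_mul defaults to 2.0 in hieradet.py further down), and the FPN neck lists those widths from lowest to highest resolution. A quick check of the arithmetic:

# Per-stage widths of the Hiera B+ trunk, in the order the FpnNeck expects them.
embed_dim, dim_mul, num_stages = 112, 2.0, 4
stage_dims = [int(embed_dim * dim_mul**i) for i in range(num_stages)]  # [112, 224, 448, 896]
assert stage_dims[::-1] == [896, 448, 224, 112]   # matches backbone_channel_list above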
sam2/configs/sam2_hiera_l.yaml ADDED
@@ -0,0 +1,117 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 144
12
+ num_heads: 2
13
+ stages: [2, 6, 36, 4]
14
+ global_att_blocks: [23, 33, 43]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ window_spec: [8, 4, 16, 8]
17
+ neck:
18
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
19
+ position_encoding:
20
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
21
+ num_pos_feats: 256
22
+ normalize: true
23
+ scale: null
24
+ temperature: 10000
25
+ d_model: 256
26
+ backbone_channel_list: [1152, 576, 288, 144]
27
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
28
+ fpn_interp_model: nearest
29
+
30
+ memory_attention:
31
+ _target_: sam2.modeling.memory_attention.MemoryAttention
32
+ d_model: 256
33
+ pos_enc_at_input: true
34
+ layer:
35
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
36
+ activation: relu
37
+ dim_feedforward: 2048
38
+ dropout: 0.1
39
+ pos_enc_at_attn: false
40
+ self_attention:
41
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
42
+ rope_theta: 10000.0
43
+ feat_sizes: [32, 32]
44
+ embedding_dim: 256
45
+ num_heads: 1
46
+ downsample_rate: 1
47
+ dropout: 0.1
48
+ d_model: 256
49
+ pos_enc_at_cross_attn_keys: true
50
+ pos_enc_at_cross_attn_queries: false
51
+ cross_attention:
52
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
53
+ rope_theta: 10000.0
54
+ feat_sizes: [32, 32]
55
+ rope_k_repeat: True
56
+ embedding_dim: 256
57
+ num_heads: 1
58
+ downsample_rate: 1
59
+ dropout: 0.1
60
+ kv_in_dim: 64
61
+ num_layers: 4
62
+
63
+ memory_encoder:
64
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
65
+ out_dim: 64
66
+ position_encoding:
67
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
68
+ num_pos_feats: 64
69
+ normalize: true
70
+ scale: null
71
+ temperature: 10000
72
+ mask_downsampler:
73
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
74
+ kernel_size: 3
75
+ stride: 2
76
+ padding: 1
77
+ fuser:
78
+ _target_: sam2.modeling.memory_encoder.Fuser
79
+ layer:
80
+ _target_: sam2.modeling.memory_encoder.CXBlock
81
+ dim: 256
82
+ kernel_size: 7
83
+ padding: 3
84
+ layer_scale_init_value: 1e-6
85
+ use_dwconv: True # depth-wise convs
86
+ num_layers: 2
87
+
88
+ num_maskmem: 7
89
+ image_size: 1024
90
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ # use high-resolution feature map in the SAM mask decoder
97
+ use_high_res_features_in_sam: true
98
+ # output 3 masks on the first click on initial conditioning frames
99
+ multimask_output_in_sam: true
100
+ # SAM heads
101
+ iou_prediction_use_sigmoid: True
102
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103
+ use_obj_ptrs_in_encoder: true
104
+ add_tpos_enc_to_obj_ptrs: false
105
+ only_obj_ptrs_in_the_past_for_eval: true
106
+ # object occlusion prediction
107
+ pred_obj_scores: true
108
+ pred_obj_scores_mlp: true
109
+ fixed_no_obj_ptr: true
110
+ # multimask tracking settings
111
+ multimask_output_for_tracking: true
112
+ use_multimask_token_for_obj_ptr: true
113
+ multimask_min_pt_num: 0
114
+ multimask_max_pt_num: 1
115
+ use_mlp_for_obj_ptr_proj: true
116
+ # Compilation flag
117
+ compile_image_encoder: False
sam2/configs/sam2_hiera_s.yaml ADDED
@@ -0,0 +1,116 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 11, 2]
14
+ global_att_blocks: [7, 10, 13]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [32, 32]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [32, 32]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ sigmoid_scale_for_mem_enc: 20.0
91
+ sigmoid_bias_for_mem_enc: -10.0
92
+ use_mask_input_as_output_without_sam: true
93
+ # Memory
94
+ directly_add_no_mem_embed: true
95
+ # use high-resolution feature map in the SAM mask decoder
96
+ use_high_res_features_in_sam: true
97
+ # output 3 masks on the first click on initial conditioning frames
98
+ multimask_output_in_sam: true
99
+ # SAM heads
100
+ iou_prediction_use_sigmoid: True
101
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
102
+ use_obj_ptrs_in_encoder: true
103
+ add_tpos_enc_to_obj_ptrs: false
104
+ only_obj_ptrs_in_the_past_for_eval: true
105
+ # object occlusion prediction
106
+ pred_obj_scores: true
107
+ pred_obj_scores_mlp: true
108
+ fixed_no_obj_ptr: true
109
+ # multimask tracking settings
110
+ multimask_output_for_tracking: true
111
+ use_multimask_token_for_obj_ptr: true
112
+ multimask_min_pt_num: 0
113
+ multimask_max_pt_num: 1
114
+ use_mlp_for_obj_ptr_proj: true
115
+ # Compilation flag
116
+ compile_image_encoder: False
sam2/configs/sam2_hiera_t.yaml ADDED
@@ -0,0 +1,118 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 7, 2]
14
+ global_att_blocks: [5, 7, 9]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [32, 32]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [32, 32]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ # SAM decoder
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ # use high-resolution feature map in the SAM mask decoder
97
+ use_high_res_features_in_sam: true
98
+ # output 3 masks on the first click on initial conditioning frames
99
+ multimask_output_in_sam: true
100
+ # SAM heads
101
+ iou_prediction_use_sigmoid: True
102
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103
+ use_obj_ptrs_in_encoder: true
104
+ add_tpos_enc_to_obj_ptrs: false
105
+ only_obj_ptrs_in_the_past_for_eval: true
106
+ # object occlusion prediction
107
+ pred_obj_scores: true
108
+ pred_obj_scores_mlp: true
109
+ fixed_no_obj_ptr: true
110
+ # multimask tracking settings
111
+ multimask_output_for_tracking: true
112
+ use_multimask_token_for_obj_ptr: true
113
+ multimask_min_pt_num: 0
114
+ multimask_max_pt_num: 1
115
+ use_mlp_for_obj_ptr_proj: true
116
+ # Compilation flag
117
+ # HieraT does not currently support compilation, should always be set to False
118
+ compile_image_encoder: False
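For this tiny variant, stages: [1, 2, 7, 2] and global_att_blocks: [5, 7, 9] interact with the windowing logic in hieradet.py below: blocks listed in global_att_blocks get window_size 0 (full attention) while the rest use the per-stage window_spec. A small sketch, mirroring the stage_ends arithmetic used by Hiera, of where those global blocks land:

# Which Hiera-T blocks go global, and in which stage they sit (0-based block indices).
stages = [1, 2, 7, 2]
global_att_blocks = [5, 7, 9]
stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]   # [0, 2, 9, 11]
for b in global_att_blocks:
    stage = next(s for s, end in enumerate(stage_ends) if b <= end) + 1
    print(f"block {b}: stage {stage}")   # all three fall inside the large third stage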
sam2/modeling/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
sam2/modeling/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (182 Bytes). View file
 
sam2/modeling/__pycache__/memory_attention.cpython-312.pyc ADDED
Binary file (7.27 kB). View file
 
sam2/modeling/__pycache__/memory_encoder.cpython-312.pyc ADDED
Binary file (7.85 kB). View file
 
sam2/modeling/__pycache__/position_encoding.cpython-312.pyc ADDED
Binary file (14.4 kB). View file
 
sam2/modeling/__pycache__/sam2_base.cpython-312.pyc ADDED
Binary file (29.2 kB). View file
 
sam2/modeling/__pycache__/sam2_utils.cpython-312.pyc ADDED
Binary file (11.7 kB). View file
 
sam2/modeling/backbones/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
sam2/modeling/backbones/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (192 Bytes). View file
 
sam2/modeling/backbones/__pycache__/hieradet.cpython-312.pyc ADDED
Binary file (12 kB). View file
 
sam2/modeling/backbones/__pycache__/image_encoder.cpython-312.pyc ADDED
Binary file (8 kB). View file
 
sam2/modeling/backbones/__pycache__/utils.cpython-312.pyc ADDED
Binary file (5.68 kB). View file
 
sam2/modeling/backbones/__pycache__/vitdet.cpython-312.pyc ADDED
Binary file (12.9 kB). View file
 
sam2/modeling/backbones/hieradet.py ADDED
@@ -0,0 +1,294 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from functools import partial
8
+ from typing import List, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from sam2.modeling.backbones.utils import (
15
+ PatchEmbed,
16
+ window_partition,
17
+ window_unpartition,
18
+ )
19
+ from sam2.modeling.sam2_utils import MLP, DropPath
20
+
21
+
22
+ def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
23
+ if pool is None:
24
+ return x
25
+ # (B, H, W, C) -> (B, C, H, W)
26
+ x = x.permute(0, 3, 1, 2)
27
+ x = pool(x)
28
+ # (B, C, H', W') -> (B, H', W', C)
29
+ x = x.permute(0, 2, 3, 1)
30
+ if norm:
31
+ x = norm(x)
32
+
33
+ return x
34
+
35
+
36
+ class MultiScaleAttention(nn.Module):
37
+ def __init__(
38
+ self,
39
+ dim: int,
40
+ dim_out: int,
41
+ num_heads: int,
42
+ q_pool: nn.Module = None,
43
+ ):
44
+ super().__init__()
45
+
46
+ self.dim = dim
47
+ self.dim_out = dim_out
48
+
49
+ self.num_heads = num_heads
50
+ head_dim = dim_out // num_heads
51
+ self.scale = head_dim**-0.5
52
+
53
+ self.q_pool = q_pool
54
+ self.qkv = nn.Linear(dim, dim_out * 3)
55
+ self.proj = nn.Linear(dim_out, dim_out)
56
+
57
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
58
+ B, H, W, _ = x.shape
59
+ # qkv with shape (B, H * W, 3, nHead, C)
60
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
61
+ # q, k, v with shape (B, H * W, nheads, C)
62
+ q, k, v = torch.unbind(qkv, 2)
63
+
64
+ # Q pooling (for downsample at stage changes)
65
+ if self.q_pool:
66
+ q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
67
+ H, W = q.shape[1:3] # downsampled shape
68
+ q = q.reshape(B, H * W, self.num_heads, -1)
69
+
70
+ # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
71
+ x = F.scaled_dot_product_attention(
72
+ q.transpose(1, 2),
73
+ k.transpose(1, 2),
74
+ v.transpose(1, 2),
75
+ )
76
+ # Transpose back
77
+ x = x.transpose(1, 2)
78
+ x = x.reshape(B, H, W, -1)
79
+
80
+ x = self.proj(x)
81
+
82
+ return x
83
+
84
+
85
+ class MultiScaleBlock(nn.Module):
86
+ def __init__(
87
+ self,
88
+ dim: int,
89
+ dim_out: int,
90
+ num_heads: int,
91
+ mlp_ratio: float = 4.0,
92
+ drop_path: float = 0.0,
93
+ norm_layer: Union[nn.Module, str] = "LayerNorm",
94
+ q_stride: Tuple[int, int] = None,
95
+ act_layer: nn.Module = nn.GELU,
96
+ window_size: int = 0,
97
+ ):
98
+ super().__init__()
99
+
100
+ if isinstance(norm_layer, str):
101
+ norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
102
+
103
+ self.dim = dim
104
+ self.dim_out = dim_out
105
+ self.norm1 = norm_layer(dim)
106
+
107
+ self.window_size = window_size
108
+
109
+ self.pool, self.q_stride = None, q_stride
110
+ if self.q_stride:
111
+ self.pool = nn.MaxPool2d(
112
+ kernel_size=q_stride, stride=q_stride, ceil_mode=False
113
+ )
114
+
115
+ self.attn = MultiScaleAttention(
116
+ dim,
117
+ dim_out,
118
+ num_heads=num_heads,
119
+ q_pool=self.pool,
120
+ )
121
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
122
+
123
+ self.norm2 = norm_layer(dim_out)
124
+ self.mlp = MLP(
125
+ dim_out,
126
+ int(dim_out * mlp_ratio),
127
+ dim_out,
128
+ num_layers=2,
129
+ activation=act_layer,
130
+ )
131
+
132
+ if dim != dim_out:
133
+ self.proj = nn.Linear(dim, dim_out)
134
+
135
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
136
+ shortcut = x # B, H, W, C
137
+ x = self.norm1(x)
138
+
139
+ # Skip connection
140
+ if self.dim != self.dim_out:
141
+ shortcut = do_pool(self.proj(x), self.pool)
142
+
143
+ # Window partition
144
+ window_size = self.window_size
145
+ if window_size > 0:
146
+ H, W = x.shape[1], x.shape[2]
147
+ x, pad_hw = window_partition(x, window_size)
148
+
149
+ # Window Attention + Q Pooling (if stage change)
150
+ x = self.attn(x)
151
+ if self.q_stride:
152
+ # Shapes have changed due to Q pooling
153
+ window_size = self.window_size // self.q_stride[0]
154
+ H, W = shortcut.shape[1:3]
155
+
156
+ pad_h = (window_size - H % window_size) % window_size
157
+ pad_w = (window_size - W % window_size) % window_size
158
+ pad_hw = (H + pad_h, W + pad_w)
159
+
160
+ # Reverse window partition
161
+ if self.window_size > 0:
162
+ x = window_unpartition(x, window_size, pad_hw, (H, W))
163
+
164
+ x = shortcut + self.drop_path(x)
165
+ # MLP
166
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
167
+ return x
168
+
169
+
170
+ class Hiera(nn.Module):
171
+ """
172
+ Reference: https://arxiv.org/abs/2306.00989
173
+ """
174
+
175
+ def __init__(
176
+ self,
177
+ embed_dim: int = 96, # initial embed dim
178
+ num_heads: int = 1, # initial number of heads
179
+ drop_path_rate: float = 0.0, # stochastic depth
180
+ q_pool: int = 3, # number of q_pool stages
181
+ q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
182
+ stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
183
+ dim_mul: float = 2.0, # dim_mul factor at stage shift
184
+ head_mul: float = 2.0, # head_mul factor at stage shift
185
+ window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
186
+ # window size per stage, when not using global att.
187
+ window_spec: Tuple[int, ...] = (
188
+ 8,
189
+ 4,
190
+ 14,
191
+ 7,
192
+ ),
193
+ # global attn in these blocks
194
+ global_att_blocks: Tuple[int, ...] = (
195
+ 12,
196
+ 16,
197
+ 20,
198
+ ),
199
+ return_interm_layers=True, # return feats from every stage
200
+ ):
201
+ super().__init__()
202
+
203
+ assert len(stages) == len(window_spec)
204
+ self.window_spec = window_spec
205
+
206
+ depth = sum(stages)
207
+ self.q_stride = q_stride
208
+ self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
209
+ assert 0 <= q_pool <= len(self.stage_ends[:-1])
210
+ self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
211
+ self.return_interm_layers = return_interm_layers
212
+
213
+ self.patch_embed = PatchEmbed(
214
+ embed_dim=embed_dim,
215
+ )
216
+ # Which blocks have global att?
217
+ self.global_att_blocks = global_att_blocks
218
+
219
+ # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
220
+ self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
221
+ self.pos_embed = nn.Parameter(
222
+ torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
223
+ )
224
+ self.pos_embed_window = nn.Parameter(
225
+ torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
226
+ )
227
+
228
+ dpr = [
229
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
230
+ ] # stochastic depth decay rule
231
+
232
+ cur_stage = 1
233
+ self.blocks = nn.ModuleList()
234
+
235
+ for i in range(depth):
236
+ dim_out = embed_dim
237
+ # lags by a block, so first block of
238
+ # next stage uses an initial window size
239
+ # of previous stage and final window size of current stage
240
+ window_size = self.window_spec[cur_stage - 1]
241
+
242
+ if self.global_att_blocks is not None:
243
+ window_size = 0 if i in self.global_att_blocks else window_size
244
+
245
+ if i - 1 in self.stage_ends:
246
+ dim_out = int(embed_dim * dim_mul)
247
+ num_heads = int(num_heads * head_mul)
248
+ cur_stage += 1
249
+
250
+ block = MultiScaleBlock(
251
+ dim=embed_dim,
252
+ dim_out=dim_out,
253
+ num_heads=num_heads,
254
+ drop_path=dpr[i],
255
+ q_stride=self.q_stride if i in self.q_pool_blocks else None,
256
+ window_size=window_size,
257
+ )
258
+
259
+ embed_dim = dim_out
260
+ self.blocks.append(block)
261
+
262
+ self.channel_list = (
263
+ [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
264
+ if return_interm_layers
265
+ else [self.blocks[-1].dim_out]
266
+ )
267
+
268
+ def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
269
+ h, w = hw
270
+ window_embed = self.pos_embed_window
271
+ pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
272
+ pos_embed = pos_embed + window_embed.tile(
273
+ [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
274
+ )
275
+ pos_embed = pos_embed.permute(0, 2, 3, 1)
276
+ return pos_embed
277
+
278
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
279
+ x = self.patch_embed(x)
280
+ # x: (B, H, W, C)
281
+
282
+ # Add pos embed
283
+ x = x + self._get_pos_embed(x.shape[1:3])
284
+
285
+ outputs = []
286
+ for i, blk in enumerate(self.blocks):
287
+ x = blk(x)
288
+ if (i == self.stage_ends[-1]) or (
289
+ i in self.stage_ends and self.return_interm_layers
290
+ ):
291
+ feats = x.permute(0, 3, 1, 2)
292
+ outputs.append(feats)
293
+
294
+ return outputs
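A rough shape check for the trunk above, using the class defaults (which correspond to the B+ sizing from the config earlier); the shapes in the comment are what the patch stride of 4 and the three q-pooling transitions imply, and the snippet is an illustrative sketch rather than a test from the repository.

import torch
from sam2.modeling.backbones.hieradet import Hiera

trunk = Hiera(embed_dim=112, num_heads=2)   # defaults: stages (2, 3, 16, 3), q_stride (2, 2)
with torch.inference_mode():
    feats = trunk(torch.zeros(1, 3, 1024, 1024))
print([tuple(f.shape) for f in feats])
# Patch embed stride 4 -> 256x256 tokens, halved at each stage transition:
# [(1, 112, 256, 256), (1, 224, 128, 128), (1, 448, 64, 64), (1, 896, 32, 32)]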
sam2/modeling/backbones/image_encoder.py ADDED
@@ -0,0 +1,196 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from sam2.modeling.sam2_utils import LayerNorm2d
14
+
15
+ class ImageEncoder(nn.Module):
16
+ def __init__(
17
+ self,
18
+ trunk: nn.Module,
19
+ neck: nn.Module,
20
+ scalp: int = 0,
21
+ ):
22
+ super().__init__()
23
+ self.trunk = trunk
24
+ self.neck = neck
25
+ self.scalp = scalp
26
+ assert (
27
+ self.trunk.channel_list == self.neck.backbone_channel_list
28
+ ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"
29
+
30
+ def forward(self, sample: torch.Tensor):
31
+ # Forward through backbone
32
+ features, pos = self.neck(self.trunk(sample))
33
+ if self.scalp > 0:
34
+ # Discard the lowest resolution features
35
+ features, pos = features[: -self.scalp], pos[: -self.scalp]
36
+
37
+ src = features[-1]
38
+ output = {
39
+ "vision_features": src,
40
+ "vision_pos_enc": pos,
41
+ "backbone_fpn": features,
42
+ }
43
+ return output
44
+
45
+
46
+ class FpnNeck(nn.Module):
47
+ """
48
+ A modified variant of Feature Pyramid Network (FPN) neck
49
+ (we remove output conv and also do bicubic interpolation similar to ViT
50
+ pos embed interpolation)
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ position_encoding: nn.Module,
56
+ d_model: int,
57
+ backbone_channel_list: List[int],
58
+ kernel_size: int = 1,
59
+ stride: int = 1,
60
+ padding: int = 0,
61
+ fpn_interp_model: str = "bilinear",
62
+ fuse_type: str = "sum",
63
+ fpn_top_down_levels: Optional[List[int]] = None,
64
+ ):
65
+ """Initialize the neck
66
+ :param trunk: the backbone
67
+ :param position_encoding: the positional encoding to use
68
+ :param d_model: the dimension of the model
69
+ :param neck_norm: the normalization to use
70
+ """
71
+ super().__init__()
72
+ self.position_encoding = position_encoding
73
+ self.convs = nn.ModuleList()
74
+ self.backbone_channel_list = backbone_channel_list
75
+ for dim in backbone_channel_list:
76
+ current = nn.Sequential()
77
+ current.add_module(
78
+ "conv",
79
+ nn.Conv2d(
80
+ in_channels=dim,
81
+ out_channels=d_model,
82
+ kernel_size=kernel_size,
83
+ stride=stride,
84
+ padding=padding,
85
+ ),
86
+ )
87
+
88
+ self.convs.append(current)
89
+ self.fpn_interp_model = fpn_interp_model
90
+ assert fuse_type in ["sum", "avg"]
91
+ self.fuse_type = fuse_type
92
+
93
+ # levels to have top-down features in its outputs
94
+ # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
95
+ # have top-down propagation, while outputs of level 0 and level 1 have only
96
+ # lateral features from the same backbone level.
97
+ if fpn_top_down_levels is None:
98
+ # default is to have top-down features on all levels
99
+ fpn_top_down_levels = range(len(self.convs))
100
+ self.fpn_top_down_levels = list(fpn_top_down_levels)
101
+
102
+ def forward(self, xs: List[torch.Tensor]):
103
+
104
+ out = [None] * len(self.convs)
105
+ pos = [None] * len(self.convs)
106
+ assert len(xs) == len(self.convs)
107
+ # fpn forward pass
108
+ # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
109
+ prev_features = None
110
+ # forward in top-down order (from low to high resolution)
111
+ n = len(self.convs) - 1
112
+ for i in range(n, -1, -1):
113
+ x = xs[i]
114
+ lateral_features = self.convs[n - i](x)
115
+ if i in self.fpn_top_down_levels and prev_features is not None:
116
+ top_down_features = F.interpolate(
117
+ prev_features.to(dtype=torch.float32),
118
+ scale_factor=2.0,
119
+ mode=self.fpn_interp_model,
120
+ align_corners=(
121
+ None if self.fpn_interp_model == "nearest" else False
122
+ ),
123
+ antialias=False,
124
+ )
125
+ prev_features = lateral_features + top_down_features
126
+ if self.fuse_type == "avg":
127
+ prev_features /= 2
128
+ else:
129
+ prev_features = lateral_features
130
+ x_out = prev_features
131
+ out[i] = x_out
132
+ pos[i] = self.position_encoding(x_out).to(x_out.dtype)
133
+
134
+ return out, pos
135
+
136
+ class ViTDetNeck(nn.Module):
137
+ def __init__(
138
+ self,
139
+ position_encoding: nn.Module,
140
+ d_model: int,
141
+ backbone_channel_list: List[int],
142
+ kernel_size: int = 1,
143
+ stride: int = 1,
144
+ padding: int = 0,
145
+ neck_norm=None,
146
+ ):
147
+ """Initialize the neck
148
+
149
+ :param trunk: the backbone
150
+ :param position_encoding: the positional encoding to use
151
+ :param d_model: the dimension of the model
152
+ :param neck_norm: the normalization to use
153
+ """
154
+ super().__init__()
155
+ self.backbone_channel_list = backbone_channel_list
156
+ self.position_encoding = position_encoding
157
+ self.convs = nn.ModuleList()
158
+ use_bias = neck_norm is None
159
+ for dim in self.backbone_channel_list:
160
+ current = nn.Sequential()
161
+ current.add_module(
162
+ "conv_1x1",
163
+ nn.Conv2d(
164
+ in_channels=dim,
165
+ out_channels=d_model,
166
+ kernel_size=1,
167
+ bias=use_bias,
168
+ ),
169
+ )
170
+ if neck_norm is not None:
171
+ current.add_module("norm_0", LayerNorm2d(d_model))
172
+ current.add_module(
173
+ "conv_3x3",
174
+ nn.Conv2d(
175
+ in_channels=d_model,
176
+ out_channels=d_model,
177
+ kernel_size=3,
178
+ padding=1,
179
+ bias=use_bias,
180
+ ),
181
+ )
182
+ if neck_norm is not None:
183
+ current.add_module("norm_1", LayerNorm2d(d_model))
184
+ self.convs.append(current)
185
+
186
+ def forward(self, xs: List[torch.Tensor]):
187
+ out = [None] * len(self.convs)
188
+ pos = [None] * len(self.convs)
189
+ assert len(xs) == len(self.convs)
190
+
191
+ x = xs[0]
192
+ x_out = self.convs[0](x)
193
+ out[0] = x_out
194
+ pos[0] = self.position_encoding(x_out).to(x_out.dtype)
195
+
196
+ return out, pos
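ViTDetNeck pairs with the single-scale ViT trunk used in the EfficientTAM configs (one entry in backbone_channel_list), while FpnNeck pairs with the multi-scale Hiera trunk. A minimal wiring sketch for the single-scale case, with hyperparameters copied from efficienttam_ti_512x512.yaml; treat it as illustrative.

import torch
from sam2.modeling.backbones.image_encoder import ImageEncoder, ViTDetNeck
from sam2.modeling.backbones.vitdet import ViT
from sam2.modeling.position_encoding import PositionEmbeddingSine

trunk = ViT(patch_size=16, embed_dim=192, depth=12, num_heads=3,
            window_size=14, window_block_indexes=(0, 1, 3, 4, 6, 7, 9, 10))
neck = ViTDetNeck(
    position_encoding=PositionEmbeddingSine(
        num_pos_feats=256, normalize=True, scale=None, temperature=10000
    ),
    d_model=256,
    backbone_channel_list=[192],   # must equal trunk.channel_list (checked in ImageEncoder)
    neck_norm="LN",
)
encoder = ImageEncoder(trunk=trunk, neck=neck, scalp=0)
with torch.inference_mode():
    out = encoder(torch.zeros(1, 3, 512, 512))
print(out["vision_features"].shape, len(out["backbone_fpn"]))   # (1, 256, 32, 32) 1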
sam2/modeling/backbones/utils.py ADDED
@@ -0,0 +1,125 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Some utilities for backbones, in particular for windowing"""
8
+
9
+ from typing import Tuple
10
+ import math
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+
17
+ def window_partition(x, window_size):
18
+ """
19
+ Partition into non-overlapping windows with padding if needed.
20
+ Args:
21
+ x (tensor): input tokens with [B, H, W, C].
22
+ window_size (int): window size.
23
+ Returns:
24
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
25
+ (Hp, Wp): padded height and width before partition
26
+ """
27
+ B, H, W, C = x.shape
28
+
29
+ pad_h = (window_size - H % window_size) % window_size
30
+ pad_w = (window_size - W % window_size) % window_size
31
+ if pad_h > 0 or pad_w > 0:
32
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
33
+ Hp, Wp = H + pad_h, W + pad_w
34
+
35
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
36
+ windows = (
37
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
38
+ )
39
+ return windows, (Hp, Wp)
40
+
41
+
42
+ def window_unpartition(windows, window_size, pad_hw, hw):
43
+ """
44
+ Window unpartition into original sequences and removing padding.
45
+ Args:
46
+ x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
47
+ window_size (int): window size.
48
+ pad_hw (Tuple): padded height and width (Hp, Wp).
49
+ hw (Tuple): original height and width (H, W) before padding.
50
+ Returns:
51
+ x: unpartitioned sequences with [B, H, W, C].
52
+ """
53
+ Hp, Wp = pad_hw
54
+ H, W = hw
55
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
56
+ x = windows.view(
57
+ B, Hp // window_size, Wp // window_size, window_size, window_size, -1
58
+ )
59
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
60
+
61
+ if Hp > H or Wp > W:
62
+ x = x[:, :H, :W, :].contiguous()
63
+ return x
64
+
65
+
66
+ class PatchEmbed(nn.Module):
67
+ """
68
+ Image to Patch Embedding.
69
+ """
70
+
71
+ def __init__(
72
+ self,
73
+ kernel_size: Tuple[int, ...] = (7, 7),
74
+ stride: Tuple[int, ...] = (4, 4),
75
+ padding: Tuple[int, ...] = (3, 3),
76
+ in_chans: int = 3,
77
+ embed_dim: int = 768,
78
+ ):
79
+ """
80
+ Args:
81
+ kernel_size (Tuple): kernel size of the projection layer.
82
+ stride (Tuple): stride of the projection layer.
83
+ padding (Tuple): padding size of the projection layer.
84
+ in_chans (int): Number of input image channels.
85
+ embed_dim (int): embed_dim (int): Patch embedding dimension.
86
+ """
87
+ super().__init__()
88
+ self.proj = nn.Conv2d(
89
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
90
+ )
91
+
92
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
93
+ x = self.proj(x)
94
+ # B C H W -> B H W C
95
+ x = x.permute(0, 2, 3, 1)
96
+ return x
97
+
98
+ def get_abs_pos(abs_pos, has_cls_token, hw):
99
+ """
100
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
101
+ dimension for the original embeddings.
102
+ Args:
103
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
104
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
105
+ hw (Tuple): size of input image tokens.
106
+ Returns:
107
+ Absolute positional embeddings after processing with shape (1, H, W, C)
108
+ """
109
+ h, w = hw
110
+ if has_cls_token:
111
+ abs_pos = abs_pos[:, 1:]
112
+ xy_num = abs_pos.shape[1]
113
+ size = int(math.sqrt(xy_num))
114
+ assert size * size == xy_num
115
+
116
+ if size != h or size != w:
117
+ new_abs_pos = F.interpolate(
118
+ abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
119
+ size=(h, w),
120
+ mode="bicubic",
121
+ align_corners=False,
122
+ )
123
+ return new_abs_pos.permute(0, 2, 3, 1)
124
+ else:
125
+ return abs_pos.reshape(1, h, w, -1)
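A quick round-trip check for the two windowing helpers above, using a spatial size that is not a multiple of the window so that padding kicks in; illustrative only.

import torch
from sam2.modeling.backbones.utils import window_partition, window_unpartition

x = torch.randn(2, 30, 30, 192)              # 30 is not divisible by 14 -> padded to 42x42
windows, pad_hw = window_partition(x, 14)
print(windows.shape, pad_hw)                 # torch.Size([18, 14, 14, 192]) (42, 42)
y = window_unpartition(windows, 14, pad_hw, (30, 30))
assert torch.equal(x, y)                     # padding is cropped away, values are untouched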
sam2/modeling/backbones/vitdet.py ADDED
@@ -0,0 +1,307 @@
1
+ """ViTDet backbone adapted from Detectron2"""
2
+
3
+ from functools import partial
4
+ from typing import List, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from sam2.modeling.backbones.utils import (
11
+ PatchEmbed,
12
+ window_partition,
13
+ window_unpartition,
14
+ get_abs_pos,
15
+ )
16
+
17
+ from sam2.modeling.sam2_utils import DropPath, MLP, LayerScale
18
+
19
+ from functools import partial
20
+
21
+
22
+ class Attention(nn.Module):
23
+ """Multi-head Attention block with relative position embeddings."""
24
+
25
+ def __init__(
26
+ self,
27
+ dim,
28
+ num_heads=8,
29
+ qkv_bias=True,
30
+ use_rel_pos=False,
31
+ rel_pos_zero_init=True,
32
+ input_size=None,
33
+ ):
34
+ """
35
+ Args:
36
+ dim (int): Number of input channels.
37
+ num_heads (int): Number of attention heads.
38
+ qkv_bias (bool: If True, add a learnable bias to query, key, value.
39
+ rel_pos (bool): If True, add relative positional embeddings to the attention map.
40
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
41
+ input_size (int or None): Input resolution for calculating the relative positional
42
+ parameter size.
43
+ attn_type: Type of attention operation, e.g. "vanilla", "vanilla-xformer".
44
+ """
45
+ super().__init__()
46
+ self.num_heads = num_heads
47
+ head_dim = dim // num_heads
48
+ self.scale = head_dim**-0.5
49
+
50
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
51
+ self.proj = nn.Linear(dim, dim)
52
+
53
+ self.use_rel_pos = use_rel_pos
54
+
55
+ def forward(self, x):
56
+ B, H, W, _ = x.shape
57
+ # qkv with shape (3, B, nHead, H * W, C)
58
+ qkv = (
59
+ self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
60
+ )
61
+ # q, k, v with shape (B * nHead, H * W, C)
62
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
63
+
64
+ q = q.view(B, self.num_heads, H * W, -1)
65
+ k = k.view(B, self.num_heads, H * W, -1)
66
+ v = v.view(B, self.num_heads, H * W, -1)
67
+ with torch.backends.cuda.sdp_kernel(
68
+ enable_flash=True,
69
+ enable_math=True,
70
+ enable_mem_efficient=True,
71
+ ):
72
+ x = F.scaled_dot_product_attention(q, k, v)
73
+ x = (
74
+ x.view(B, self.num_heads, H, W, -1)
75
+ .permute(0, 2, 3, 1, 4)
76
+ .reshape(B, H, W, -1)
77
+ )
78
+ x = self.proj(x)
79
+
80
+ return x
81
+
82
+
83
+ class Block(nn.Module):
84
+ """Transformer blocks with support of window attention"""
85
+
86
+ def __init__(
87
+ self,
88
+ dim,
89
+ num_heads,
90
+ mlp_ratio=4.0,
91
+ qkv_bias=True,
92
+ drop_path=0.0,
93
+ norm_layer=nn.LayerNorm,
94
+ act_layer=nn.GELU,
95
+ use_rel_pos=False,
96
+ rel_pos_zero_init=True,
97
+ window_size=0,
98
+ input_size=None,
99
+ dropout=0.0,
100
+ init_values=None,
101
+ ):
102
+ """
103
+ Args:
104
+ dim (int): Number of input channels.
105
+ num_heads (int): Number of attention heads in each ViT block.
106
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
107
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
108
+ drop_path (float): Stochastic depth rate.
109
+ norm_layer (nn.Module): Normalization layer.
110
+ act_layer (nn.Module): Activation layer.
111
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
112
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
113
+ window_size (int): Window size for window attention blocks. If it equals 0, then not
114
+ use window attention.
115
+ input_size (int or None): Input resolution for calculating the relative positional
116
+ parameter size.
117
+ dropout (float): Dropout rate.
118
+ """
119
+ super().__init__()
120
+ self.norm1 = norm_layer(dim)
121
+ self.attn = Attention(
122
+ dim,
123
+ num_heads=num_heads,
124
+ qkv_bias=qkv_bias,
125
+ use_rel_pos=use_rel_pos,
126
+ rel_pos_zero_init=rel_pos_zero_init,
127
+ input_size=input_size if window_size == 0 else (window_size, window_size),
128
+ )
129
+ self.ls1 = (
130
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
131
+ )
132
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
133
+
134
+ self.norm2 = norm_layer(dim)
135
+ self.mlp = MLP(
136
+ dim,
137
+ int(dim * mlp_ratio),
138
+ dim,
139
+ num_layers=2,
140
+ activation=act_layer,
141
+ )
142
+ # self.mlp = Mlp2(
143
+ # in_features=dim,
144
+ # hidden_features=int(dim * mlp_ratio),
145
+ # act_layer=act_layer,
146
+ # drop=(dropout, 0.0),
147
+ # )
148
+ self.ls2 = (
149
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
150
+ )
151
+ self.dropout = nn.Dropout(dropout)
152
+ self.window_size = window_size
153
+
154
+ def forward(self, x):
155
+ shortcut = x
156
+ x = self.norm1(x)
157
+ # Window partition
158
+ if self.window_size > 0:
159
+ H, W = x.shape[1], x.shape[2]
160
+ x, pad_hw = window_partition(x, self.window_size)
161
+
162
+ x = self.ls1(self.attn(x))
163
+ # Reverse window partition
164
+ if self.window_size > 0:
165
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
166
+
167
+ x = shortcut + self.dropout(self.drop_path(x))
168
+ x = x + self.dropout(self.drop_path(self.ls2(self.mlp(self.norm2(x)))))
169
+
170
+ return x
171
+
172
+
173
+ class ViT(nn.Module):
174
+ """
175
+ This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
176
+ "Exploring Plain Vision Transformer Backbones for Object Detection",
177
+ https://arxiv.org/abs/2203.16527
178
+ """
179
+
180
+ def __init__(
181
+ self,
182
+ img_size=1024,
183
+ patch_size=16,
184
+ in_chans=3,
185
+ embed_dim=768,
186
+ depth=12,
187
+ num_heads=12,
188
+ mlp_ratio=4.0,
189
+ qkv_bias=True,
190
+ drop_path_rate=0.0,
191
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
192
+ act_layer=nn.GELU,
193
+ use_abs_pos=True,
194
+ use_rel_pos=False,
195
+ rel_pos_zero_init=True,
196
+ window_size=14,
197
+ window_block_indexes=(0, 1, 3, 4, 6, 7, 9, 10),
198
+ use_act_checkpoint=False,
199
+ pretrain_img_size=224,
200
+ pretrain_use_cls_token=True,
201
+ dropout=0.0,
202
+ weights_path=None,
203
+ return_interm_layers=False,
204
+ init_values=None,
205
+ ):
206
+ """
207
+ Args:
208
+ img_size (int): Input image size. Only relevant for rel pos.
209
+ patch_size (int): Patch size.
210
+ in_chans (int): Number of input image channels.
211
+ embed_dim (int): Patch embedding dimension.
212
+ depth (int): Depth of ViT.
213
+ num_heads (int): Number of attention heads in each ViT block.
214
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
215
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
216
+ drop_path_rate (float): Stochastic depth rate.
217
+ norm_layer (nn.Module): Normalization layer.
218
+ act_layer (nn.Module): Activation layer.
219
+ use_abs_pos (bool): If True, use absolute positional embeddings.
220
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
221
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
222
+ window_size (int): Window size for window attention blocks.
223
+ window_block_indexes (list): Indexes for blocks using window attention.
224
+ residual_block_indexes (list): Indexes for blocks using conv propagation.
225
+ use_act_checkpoint (bool): If True, use activation checkpointing.
226
+ pretrain_img_size (int): input image size for pretraining models.
227
+ pretrain_use_cls_token (bool): If True, pretrainig models use class token.
228
+ dropout (float): Dropout rate. Applied in residual blocks of attn, mlp and inside the mlp.
229
+ path (str or None): Path to the pretrained weights.
230
+ return_interm_layers (bool): Whether to return intermediate layers (all global attention blocks).
231
+ freezing (BackboneFreezingType): Type of freezing.
232
+ """
233
+ super().__init__()
234
+ self.pretrain_use_cls_token = pretrain_use_cls_token
235
+
236
+ self.patch_embed = PatchEmbed(
237
+ kernel_size=(patch_size, patch_size),
238
+ stride=(patch_size, patch_size),
239
+ padding=(0, 0),
240
+ in_chans=in_chans,
241
+ embed_dim=embed_dim,
242
+ )
243
+
244
+ if use_abs_pos:
245
+ # Initialize absolute positional embedding with pretrain image size.
246
+ num_patches = (pretrain_img_size // patch_size) * (
247
+ pretrain_img_size // patch_size
248
+ )
249
+ num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
250
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
251
+ else:
252
+ self.pos_embed = None
253
+
254
+ # stochastic depth decay rule
255
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
256
+
257
+ self.blocks = nn.ModuleList()
258
+ self.full_attn_ids = []
259
+ cur_stage = 1
260
+ for i in range(depth):
261
+ block = Block(
262
+ dim=embed_dim,
263
+ num_heads=num_heads,
264
+ mlp_ratio=mlp_ratio,
265
+ qkv_bias=qkv_bias,
266
+ drop_path=dpr[i],
267
+ norm_layer=norm_layer,
268
+ act_layer=act_layer,
269
+ use_rel_pos=use_rel_pos,
270
+ rel_pos_zero_init=rel_pos_zero_init,
271
+ window_size=window_size if i in window_block_indexes else 0,
272
+ input_size=(img_size // patch_size, img_size // patch_size),
273
+ dropout=dropout,
274
+ init_values=init_values,
275
+ )
276
+ if i not in window_block_indexes:
277
+ self.full_attn_ids.append(i)
278
+ cur_stage += 1
279
+
280
+ self.blocks.append(block)
281
+
282
+ self.return_interm_layers = return_interm_layers
283
+ self.channel_list = (
284
+ [embed_dim] * len(self.full_attn_ids)
285
+ if return_interm_layers
286
+ else [embed_dim]
287
+ )
288
+
289
+
290
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
291
+
292
+ x = self.patch_embed(x)
293
+ if self.pos_embed is not None:
294
+ x = x + get_abs_pos(
295
+ self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
296
+ )
297
+
298
+ outputs = []
299
+ for i, blk in enumerate(self.blocks):
300
+ x = blk(x)
301
+ if (i == self.full_attn_ids[-1]) or (
302
+ self.return_interm_layers and i in self.full_attn_ids
303
+ ):
304
+ feats = x.permute(0, 3, 1, 2)
305
+ outputs.append(feats)
306
+
307
+ return outputs
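When use_abs_pos is on, the position table is sized for the 224-pixel pretraining grid (14x14 patches plus a cls slot) and get_abs_pos bicubically resizes it to whatever grid the current input produces. A small sketch of that resizing for the 512-pixel EfficientTAM setting:

import torch
from sam2.modeling.backbones.utils import get_abs_pos

num_pretrain_tokens = (224 // 16) ** 2                    # 196 positions from pretraining
pos_embed = torch.zeros(1, num_pretrain_tokens + 1, 192)  # +1 for the cls-token slot
resized = get_abs_pos(pos_embed, has_cls_token=True, hw=(512 // 16, 512 // 16))
print(resized.shape)                                      # torch.Size([1, 32, 32, 192])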
sam2/modeling/memory_attention.py ADDED
@@ -0,0 +1,170 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional
8
+
9
+ import torch
10
+ from torch import Tensor, nn
11
+
12
+ from sam2.modeling.sam2_utils import get_activation_fn, get_clones
13
+ from sam2.modeling.sam.transformer import RoPEAttention
14
+ from sam2.modeling.sam.transformer import EfficientRoPEAttention1
15
+ from sam2.modeling.sam.transformer import EfficientRoPEAttention2
16
+
17
+
18
+ class MemoryAttentionLayer(nn.Module):
19
+
20
+ def __init__(
21
+ self,
22
+ activation: str,
23
+ cross_attention: nn.Module,
24
+ d_model: int,
25
+ dim_feedforward: int,
26
+ dropout: float,
27
+ pos_enc_at_attn: bool,
28
+ pos_enc_at_cross_attn_keys: bool,
29
+ pos_enc_at_cross_attn_queries: bool,
30
+ self_attention: nn.Module,
31
+ ):
32
+ super().__init__()
33
+ self.d_model = d_model
34
+ self.dim_feedforward = dim_feedforward
35
+ self.dropout_value = dropout
36
+ self.self_attn = self_attention
37
+ self.cross_attn_image = cross_attention
38
+
39
+ # Implementation of Feedforward model
40
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
41
+ self.dropout = nn.Dropout(dropout)
42
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
43
+
44
+ self.norm1 = nn.LayerNorm(d_model)
45
+ self.norm2 = nn.LayerNorm(d_model)
46
+ self.norm3 = nn.LayerNorm(d_model)
47
+ self.dropout1 = nn.Dropout(dropout)
48
+ self.dropout2 = nn.Dropout(dropout)
49
+ self.dropout3 = nn.Dropout(dropout)
50
+
51
+ self.activation_str = activation
52
+ self.activation = get_activation_fn(activation)
53
+
54
+ # Where to add pos enc
55
+ self.pos_enc_at_attn = pos_enc_at_attn
56
+ self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
57
+ self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
58
+
59
+ def _forward_sa(self, tgt, query_pos):
60
+ # Self-Attention
61
+ tgt2 = self.norm1(tgt)
62
+ q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
63
+ tgt2 = self.self_attn(q, k, v=tgt2)
64
+ tgt = tgt + self.dropout1(tgt2)
65
+ return tgt
66
+
67
+ def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
68
+ kwds = {}
69
+ if num_k_exclude_rope > 0:
70
+ assert isinstance(self.cross_attn_image, (RoPEAttention, EfficientRoPEAttention1, EfficientRoPEAttention2))
71
+ kwds = {"num_k_exclude_rope": num_k_exclude_rope}
72
+
73
+ # Cross-Attention
74
+ tgt2 = self.norm2(tgt)
75
+ tgt2 = self.cross_attn_image(
76
+ q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
77
+ k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
78
+ v=memory,
79
+ **kwds,
80
+ )
81
+ tgt = tgt + self.dropout2(tgt2)
82
+ return tgt
83
+
84
+ def forward(
85
+ self,
86
+ tgt,
87
+ memory,
88
+ pos: Optional[Tensor] = None,
89
+ query_pos: Optional[Tensor] = None,
90
+ num_k_exclude_rope: int = 0,
91
+ ) -> torch.Tensor:
92
+
93
+ # Self-Attn, Cross-Attn
94
+ tgt = self._forward_sa(tgt, query_pos)
95
+ tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
96
+ # MLP
97
+ tgt2 = self.norm3(tgt)
98
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
99
+ tgt = tgt + self.dropout3(tgt2)
100
+ return tgt
101
+
102
+
103
+ class MemoryAttention(nn.Module):
104
+ def __init__(
105
+ self,
106
+ d_model: int,
107
+ pos_enc_at_input: bool,
108
+ layer: nn.Module,
109
+ num_layers: int,
110
+ batch_first: bool = True, # Do layers expect batch first input?
111
+ ):
112
+ super().__init__()
113
+ self.d_model = d_model
114
+ self.layers = get_clones(layer, num_layers)
115
+ self.num_layers = num_layers
116
+ self.norm = nn.LayerNorm(d_model)
117
+ self.pos_enc_at_input = pos_enc_at_input
118
+ self.batch_first = batch_first
119
+
120
+ def forward(
121
+ self,
122
+ curr: torch.Tensor, # self-attention inputs
123
+ memory: torch.Tensor, # cross-attention inputs
124
+ curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs
125
+ memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
126
+ num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
127
+ ):
128
+ if isinstance(curr, list):
129
+ assert isinstance(curr_pos, list)
130
+ assert len(curr) == len(curr_pos) == 1
131
+ curr, curr_pos = (
132
+ curr[0],
133
+ curr_pos[0],
134
+ )
135
+
136
+ assert (
137
+ curr.shape[1] == memory.shape[1]
138
+ ), "Batch size must be the same for curr and memory"
139
+
140
+ output = curr
141
+ if self.pos_enc_at_input and curr_pos is not None:
142
+ output = output + 0.1 * curr_pos
143
+
144
+ if self.batch_first:
145
+ # Convert to batch first
146
+ output = output.transpose(0, 1)
147
+ curr_pos = curr_pos.transpose(0, 1)
148
+ memory = memory.transpose(0, 1)
149
+ memory_pos = memory_pos.transpose(0, 1)
150
+
151
+ for layer in self.layers:
152
+ kwds = {}
153
+ if isinstance(layer.cross_attn_image, (RoPEAttention, EfficientRoPEAttention1, EfficientRoPEAttention2)):
154
+ kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
155
+
156
+ output = layer(
157
+ tgt=output,
158
+ memory=memory,
159
+ pos=memory_pos,
160
+ query_pos=curr_pos,
161
+ **kwds,
162
+ )
163
+ normed_output = self.norm(output)
164
+
165
+ if self.batch_first:
166
+ # Convert back to seq first
167
+ normed_output = normed_output.transpose(0, 1)
168
+ curr_pos = curr_pos.transpose(0, 1)
169
+
170
+ return normed_output
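MemoryAttentionLayer is attention-module agnostic: it only needs self_attention and cross_attention callables that accept q, k, v. The sketch below exercises its pre-norm self-attention, cross-attention, and feed-forward flow with a plain nn.MultiheadAttention stand-in; PlainAttention is my own illustrative wrapper, not part of the repository, and the real layers use RoPEAttention with the config hyperparameters (including kv_in_dim 64 for the memory keys/values).

import torch
import torch.nn as nn
from sam2.modeling.memory_attention import MemoryAttentionLayer

class PlainAttention(nn.Module):
    """Illustrative stand-in for RoPEAttention: projects k/v to d_model, then plain MHA."""
    def __init__(self, dim, kv_dim=None, num_heads=1):
        super().__init__()
        self.kv_proj = nn.Linear(kv_dim or dim, dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
    def forward(self, q, k, v):
        k, v = self.kv_proj(k), self.kv_proj(v)
        return self.attn(q, k, v, need_weights=False)[0]

layer = MemoryAttentionLayer(
    activation="relu", d_model=256, dim_feedforward=2048, dropout=0.1,
    pos_enc_at_attn=False, pos_enc_at_cross_attn_keys=True, pos_enc_at_cross_attn_queries=False,
    self_attention=PlainAttention(256),
    cross_attention=PlainAttention(256, kv_dim=64),
)
tgt = torch.randn(1, 1024, 256)        # current-frame tokens (32x32 features, batch-first here)
memory = torch.randn(1, 7 * 1024, 64)  # 7 memory frames of 64-dim spatial memory features
out = layer(tgt=tgt, memory=memory, pos=torch.randn_like(memory), query_pos=torch.zeros_like(tgt))
print(out.shape)                       # torch.Size([1, 1024, 256])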
sam2/modeling/memory_encoder.py ADDED
@@ -0,0 +1,181 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import math
+ from typing import Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from sam2.modeling.sam2_utils import DropPath, LayerNorm2d, get_clones
+
+
+ class MaskDownSampler(nn.Module):
+     """
+     Progressively downsample a mask by total_stride, each time by stride.
+     Note that LayerNorm is applied per *token*, like in ViT.
+
+     With each downsample (by a factor stride**2), channel capacity increases by the same factor.
+     In the end, we linearly project to embed_dim channels.
+     """
+
+     def __init__(
+         self,
+         embed_dim=256,
+         kernel_size=4,
+         stride=4,
+         padding=0,
+         total_stride=16,
+         activation=nn.GELU,
+     ):
+         super().__init__()
+         num_layers = int(math.log2(total_stride) // math.log2(stride))
+         assert stride**num_layers == total_stride
+         self.encoder = nn.Sequential()
+         mask_in_chans, mask_out_chans = 1, 1
+         for _ in range(num_layers):
+             mask_out_chans = mask_in_chans * (stride**2)
+             self.encoder.append(
+                 nn.Conv2d(
+                     mask_in_chans,
+                     mask_out_chans,
+                     kernel_size=kernel_size,
+                     stride=stride,
+                     padding=padding,
+                 )
+             )
+             self.encoder.append(LayerNorm2d(mask_out_chans))
+             self.encoder.append(activation())
+             mask_in_chans = mask_out_chans
+
+         self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
+
+     def forward(self, x):
+         return self.encoder(x)
+
+
+ # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
+ class CXBlock(nn.Module):
+     r"""ConvNeXt Block. There are two equivalent implementations:
+     (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+     (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+     We use (2) as we find it slightly faster in PyTorch
+
+     Args:
+         dim (int): Number of input channels.
+         drop_path (float): Stochastic depth rate. Default: 0.0
+         layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+     """
+
+     def __init__(
+         self,
+         dim,
+         kernel_size=7,
+         padding=3,
+         drop_path=0.0,
+         layer_scale_init_value=1e-6,
+         use_dwconv=True,
+     ):
+         super().__init__()
+         self.dwconv = nn.Conv2d(
+             dim,
+             dim,
+             kernel_size=kernel_size,
+             padding=padding,
+             groups=dim if use_dwconv else 1,
+         )  # depthwise conv
+         self.norm = LayerNorm2d(dim, eps=1e-6)
+         self.pwconv1 = nn.Linear(
+             dim, 4 * dim
+         )  # pointwise/1x1 convs, implemented with linear layers
+         self.act = nn.GELU()
+         self.pwconv2 = nn.Linear(4 * dim, dim)
+         self.gamma = (
+             nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+             if layer_scale_init_value > 0
+             else None
+         )
+         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+     def forward(self, x):
+         input = x
+         x = self.dwconv(x)
+         x = self.norm(x)
+         x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+         x = self.pwconv1(x)
+         x = self.act(x)
+         x = self.pwconv2(x)
+         if self.gamma is not None:
+             x = self.gamma * x
+         x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
+
+         x = input + self.drop_path(x)
+         return x
+
+
+ class Fuser(nn.Module):
+     def __init__(self, layer, num_layers, dim=None, input_projection=False):
+         super().__init__()
+         self.proj = nn.Identity()
+         self.layers = get_clones(layer, num_layers)
+
+         if input_projection:
+             assert dim is not None
+             self.proj = nn.Conv2d(dim, dim, kernel_size=1)
+
+     def forward(self, x):
+         # normally x: (N, C, H, W)
+         x = self.proj(x)
+         for layer in self.layers:
+             x = layer(x)
+         return x
+
+
+ class MemoryEncoder(nn.Module):
+     def __init__(
+         self,
+         out_dim,
+         mask_downsampler,
+         fuser,
+         position_encoding,
+         in_dim=256,  # in_dim of pix_feats
+     ):
+         super().__init__()
+
+         self.mask_downsampler = mask_downsampler
+
+         self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
+         self.fuser = fuser
+         self.position_encoding = position_encoding
+         self.out_proj = nn.Identity()
+         if out_dim != in_dim:
+             self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
+
+     def forward(
+         self,
+         pix_feat: torch.Tensor,
+         masks: torch.Tensor,
+         skip_mask_sigmoid: bool = False,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         ## Process masks
+         # sigmoid, so that less domain shift from gt masks which are bool
+         if not skip_mask_sigmoid:
+             masks = F.sigmoid(masks)
+         masks = self.mask_downsampler(masks)
+
+         ## Fuse pix_feats and downsampled masks
+         # in case the visual features are on CPU, cast them to CUDA
+         pix_feat = pix_feat.to(masks.device)
+
+         x = self.pix_feat_proj(pix_feat)
+         x = x + masks
+         x = self.fuser(x)
+         x = self.out_proj(x)
+
+         pos = self.position_encoding(x).to(x.dtype)
+
+         return {"vision_features": x, "vision_pos_enc": [pos]}
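As a quick sanity check of the mask path above (not part of this commit), the sketch below runs MaskDownSampler with its default strides, assuming the sam2 package from this commit is importable: two stride-4 convolutions reduce each side by total_stride=16, and a final 1x1 convolution projects to embed_dim channels, which is what MemoryEncoder then adds to the projected pixel features.

import torch

from sam2.modeling.memory_encoder import MaskDownSampler

# Defaults: kernel_size=4, stride=4, total_stride=16 -> two downsampling stages.
downsampler = MaskDownSampler(embed_dim=64)

mask_logits = torch.randn(2, 1, 1024, 1024)  # (N, 1, H, W) mask logits, illustrative size
out = downsampler(mask_logits)
assert out.shape == (2, 64, 64, 64)  # H and W shrink by 16x, channels -> embed_dim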
sam2/modeling/position_encoding.py ADDED
@@ -0,0 +1,215 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import math
+ from typing import Any, Optional, Tuple
+
+ import numpy as np
+ import torch
+ from torch import nn
+
+
+ class PositionEmbeddingSine(nn.Module):
+     """
+     This is a more standard version of the position embedding, very similar to the one
+     used by the Attention is all you need paper, generalized to work on images.
+     """
+
+     def __init__(
+         self,
+         num_pos_feats,
+         temperature: int = 10000,
+         normalize: bool = True,
+         scale: Optional[float] = None,
+     ):
+         super().__init__()
+         assert num_pos_feats % 2 == 0, "Expecting even model width"
+         self.num_pos_feats = num_pos_feats // 2
+         self.temperature = temperature
+         self.normalize = normalize
+         if scale is not None and normalize is False:
+             raise ValueError("normalize should be True if scale is passed")
+         if scale is None:
+             scale = 2 * math.pi
+         self.scale = scale
+
+         self.cache = {}
+
+     def _encode_xy(self, x, y):
+         # The positions are expected to be normalized
+         assert len(x) == len(y) and x.ndim == y.ndim == 1
+         x_embed = x * self.scale
+         y_embed = y * self.scale
+
+         dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+         pos_x = x_embed[:, None] / dim_t
+         pos_y = y_embed[:, None] / dim_t
+         pos_x = torch.stack(
+             (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
+         ).flatten(1)
+         pos_y = torch.stack(
+             (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
+         ).flatten(1)
+         return pos_x, pos_y
+
+     @torch.no_grad()
+     def encode_boxes(self, x, y, w, h):
+         pos_x, pos_y = self._encode_xy(x, y)
+         pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
+         return pos
+
+     encode = encode_boxes  # Backwards compatibility
+
+     @torch.no_grad()
+     def encode_points(self, x, y, labels):
+         (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
+         assert bx == by and nx == ny and bx == bl and nx == nl
+         pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
+         pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
+         pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
+         return pos
+
+     @torch.no_grad()
+     def forward(self, x: torch.Tensor):
+         cache_key = (x.shape[-2], x.shape[-1])
+         if cache_key in self.cache:
+             return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
+         y_embed = (
+             torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
+             .view(1, -1, 1)
+             .repeat(x.shape[0], 1, x.shape[-1])
+         )
+         x_embed = (
+             torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
+             .view(1, 1, -1)
+             .repeat(x.shape[0], x.shape[-2], 1)
+         )
+
+         if self.normalize:
+             eps = 1e-6
+             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+         dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+         pos_x = x_embed[:, :, :, None] / dim_t
+         pos_y = y_embed[:, :, :, None] / dim_t
+         pos_x = torch.stack(
+             (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+         ).flatten(3)
+         pos_y = torch.stack(
+             (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+         ).flatten(3)
+         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+         self.cache[cache_key] = pos[0]
+         return pos
+
+
+ class PositionEmbeddingRandom(nn.Module):
+     """
+     Positional encoding using random spatial frequencies.
+     """
+
+     def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+         super().__init__()
+         if scale is None or scale <= 0.0:
+             scale = 1.0
+         self.register_buffer(
+             "positional_encoding_gaussian_matrix",
+             scale * torch.randn((2, num_pos_feats)),
+         )
+
+     def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+         """Positionally encode points that are normalized to [0,1]."""
+         # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+         coords = 2 * coords - 1
+         coords = coords @ self.positional_encoding_gaussian_matrix
+         coords = 2 * np.pi * coords
+         # outputs d_1 x ... x d_n x C shape
+         return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+     def forward(self, size: Tuple[int, int]) -> torch.Tensor:
+         """Generate positional encoding for a grid of the specified size."""
+         h, w = size
+         device: Any = self.positional_encoding_gaussian_matrix.device
+         grid = torch.ones((h, w), device=device, dtype=torch.float32)
+         y_embed = grid.cumsum(dim=0) - 0.5
+         x_embed = grid.cumsum(dim=1) - 0.5
+         y_embed = y_embed / h
+         x_embed = x_embed / w
+
+         pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
+         return pe.permute(2, 0, 1)  # C x H x W
+
+     def forward_with_coords(
+         self, coords_input: torch.Tensor, image_size: Tuple[int, int]
+     ) -> torch.Tensor:
+         """Positionally encode points that are not normalized to [0,1]."""
+         coords = coords_input.clone()
+         coords[:, :, 0] = coords[:, :, 0] / image_size[1]
+         coords[:, :, 1] = coords[:, :, 1] / image_size[0]
+         return self._pe_encoding(coords.to(torch.float))  # B x N x C
+
+
+ # Rotary Positional Encoding, adapted from:
+ # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
+ # 2. https://github.com/naver-ai/rope-vit
+ # 3. https://github.com/lucidrains/rotary-embedding-torch
+
+
+ def init_t_xy(end_x: int, end_y: int):
+     t = torch.arange(end_x * end_y, dtype=torch.float32)
+     t_x = (t % end_x).float()
+     t_y = torch.div(t, end_x, rounding_mode="floor").float()
+     return t_x, t_y
+
+
+ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
+     freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+     freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+
+     t_x, t_y = init_t_xy(end_x, end_y)
+     freqs_x = torch.outer(t_x, freqs_x)
+     freqs_y = torch.outer(t_y, freqs_y)
+     freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
+     freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
+     return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
+
+
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+     ndim = x.ndim
+     assert 0 <= 1 < ndim
+     assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
+     shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
+     return freqs_cis.view(*shape)
+
+
+ def apply_rotary_enc(
+     xq: torch.Tensor,
+     xk: torch.Tensor,
+     freqs_cis: torch.Tensor,
+     repeat_freqs_k: bool = False,
+ ):
+     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+     xk_ = (
+         torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+         if xk.shape[-2] != 0
+         else None
+     )
+     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+     if xk_ is None:
+         # no keys to rotate, due to dropout
+         return xq_out.type_as(xq).to(xq.device), xk
+     # repeat freqs along seq_len dim to match k seq_len
+     if repeat_freqs_k:
+         r = xk_.shape[-2] // xq_.shape[-2]
+         freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
+     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+     return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
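The rotary helpers above operate on an axial 2D grid of tokens. The sketch below (illustrative only, not part of this commit, assuming the module path added here) shows the expected shapes: compute_axial_cis yields one complex frequency per token for each pair of head channels, and apply_rotary_enc rotates query/key tensors laid out as (batch, heads, tokens, head_dim).

import torch

from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis

head_dim, end_x, end_y = 64, 16, 16  # 16x16 token grid, 64-dim attention heads
freqs_cis = compute_axial_cis(dim=head_dim, end_x=end_x, end_y=end_y)
assert freqs_cis.shape == (end_x * end_y, head_dim // 2)  # complex-valued frequencies

q = torch.randn(1, 2, end_x * end_y, head_dim)  # (batch, heads, tokens, head_dim)
k = torch.randn(1, 2, end_x * end_y, head_dim)
q_rot, k_rot = apply_rotary_enc(q, k, freqs_cis=freqs_cis)
assert q_rot.shape == q.shape and k_rot.shape == k.shape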
sam2/modeling/sam/__init__.py ADDED
@@ -0,0 +1,5 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
sam2/modeling/sam/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (186 Bytes). View file
 
sam2/modeling/sam/__pycache__/mask_decoder.cpython-312.pyc ADDED
Binary file (12.7 kB). View file
 
sam2/modeling/sam/__pycache__/prompt_encoder.cpython-312.pyc ADDED
Binary file (9.48 kB). View file
 
sam2/modeling/sam/__pycache__/transformer.cpython-312.pyc ADDED
Binary file (23.2 kB). View file
 
sam2/modeling/sam/mask_decoder.py ADDED
@@ -0,0 +1,295 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import List, Optional, Tuple, Type
+
+ import torch
+ from torch import nn
+
+ from sam2.modeling.sam2_utils import MLP, LayerNorm2d
+
+
+ class MaskDecoder(nn.Module):
+     def __init__(
+         self,
+         *,
+         transformer_dim: int,
+         transformer: nn.Module,
+         num_multimask_outputs: int = 3,
+         activation: Type[nn.Module] = nn.GELU,
+         iou_head_depth: int = 3,
+         iou_head_hidden_dim: int = 256,
+         use_high_res_features: bool = False,
+         iou_prediction_use_sigmoid=False,
+         dynamic_multimask_via_stability=False,
+         dynamic_multimask_stability_delta=0.05,
+         dynamic_multimask_stability_thresh=0.98,
+         pred_obj_scores: bool = False,
+         pred_obj_scores_mlp: bool = False,
+         use_multimask_token_for_obj_ptr: bool = False,
+     ) -> None:
+         """
+         Predicts masks given an image and prompt embeddings, using a
+         transformer architecture.
+
+         Arguments:
+           transformer_dim (int): the channel dimension of the transformer
+           transformer (nn.Module): the transformer used to predict masks
+           num_multimask_outputs (int): the number of masks to predict
+             when disambiguating masks
+           activation (nn.Module): the type of activation to use when
+             upscaling masks
+           iou_head_depth (int): the depth of the MLP used to predict
+             mask quality
+           iou_head_hidden_dim (int): the hidden dimension of the MLP
+             used to predict mask quality
+         """
+         super().__init__()
+         self.transformer_dim = transformer_dim
+         self.transformer = transformer
+
+         self.num_multimask_outputs = num_multimask_outputs
+
+         self.iou_token = nn.Embedding(1, transformer_dim)
+         self.num_mask_tokens = num_multimask_outputs + 1
+         self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+
+         self.pred_obj_scores = pred_obj_scores
+         if self.pred_obj_scores:
+             self.obj_score_token = nn.Embedding(1, transformer_dim)
+         self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
+
+         self.output_upscaling = nn.Sequential(
+             nn.ConvTranspose2d(
+                 transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
+             ),
+             LayerNorm2d(transformer_dim // 4),
+             activation(),
+             nn.ConvTranspose2d(
+                 transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
+             ),
+             activation(),
+         )
+         self.use_high_res_features = use_high_res_features
+         if use_high_res_features:
+             self.conv_s0 = nn.Conv2d(
+                 transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
+             )
+             self.conv_s1 = nn.Conv2d(
+                 transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
+             )
+
+         self.output_hypernetworks_mlps = nn.ModuleList(
+             [
+                 MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
+                 for i in range(self.num_mask_tokens)
+             ]
+         )
+
+         self.iou_prediction_head = MLP(
+             transformer_dim,
+             iou_head_hidden_dim,
+             self.num_mask_tokens,
+             iou_head_depth,
+             sigmoid_output=iou_prediction_use_sigmoid,
+         )
+         if self.pred_obj_scores:
+             self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
+             if pred_obj_scores_mlp:
+                 self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
+
+         # When outputting a single mask, optionally we can dynamically fall back to the best
+         # multimask output token if the single mask output token gives low stability scores.
+         self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
+         self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
+         self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
+
+     def forward(
+         self,
+         image_embeddings: torch.Tensor,
+         image_pe: torch.Tensor,
+         sparse_prompt_embeddings: torch.Tensor,
+         dense_prompt_embeddings: torch.Tensor,
+         multimask_output: bool,
+         repeat_image: bool,
+         high_res_features: Optional[List[torch.Tensor]] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Predict masks given image and prompt embeddings.
+
+         Arguments:
+           image_embeddings (torch.Tensor): the embeddings from the image encoder
+           image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
+           sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
+           dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
+           multimask_output (bool): Whether to return multiple masks or a single
+             mask.
+
+         Returns:
+           torch.Tensor: batched predicted masks
+           torch.Tensor: batched predictions of mask quality
+           torch.Tensor: batched SAM token for mask output
+         """
+         masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
+             image_embeddings=image_embeddings,
+             image_pe=image_pe,
+             sparse_prompt_embeddings=sparse_prompt_embeddings,
+             dense_prompt_embeddings=dense_prompt_embeddings,
+             repeat_image=repeat_image,
+             high_res_features=high_res_features,
+         )
+
+         # Select the correct mask or masks for output
+         if multimask_output:
+             masks = masks[:, 1:, :, :]
+             iou_pred = iou_pred[:, 1:]
+         elif self.dynamic_multimask_via_stability and not self.training:
+             masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
+         else:
+             masks = masks[:, 0:1, :, :]
+             iou_pred = iou_pred[:, 0:1]
+
+         if multimask_output and self.use_multimask_token_for_obj_ptr:
+             sam_tokens_out = mask_tokens_out[:, 1:]  # [b, 3, c] shape
+         else:
+             # Take the mask output token. Here we *always* use the token for single mask output.
+             # At test time, even if we track after 1-click (and using multimask_output=True),
+             # we still take the single mask token here. The rationale is that we always track
+             # after multiple clicks during training, so the past tokens seen during training
+             # are always the single mask token (and we'll let it be the object-memory token).
+             sam_tokens_out = mask_tokens_out[:, 0:1]  # [b, 1, c] shape
+
+         # Prepare output
+         return masks, iou_pred, sam_tokens_out, object_score_logits
+
+     def predict_masks(
+         self,
+         image_embeddings: torch.Tensor,
+         image_pe: torch.Tensor,
+         sparse_prompt_embeddings: torch.Tensor,
+         dense_prompt_embeddings: torch.Tensor,
+         repeat_image: bool,
+         high_res_features: Optional[List[torch.Tensor]] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Predicts masks. See 'forward' for more details."""
+         # Concatenate output tokens
+         s = 0
+         if self.pred_obj_scores:
+             output_tokens = torch.cat(
+                 [
+                     self.obj_score_token.weight,
+                     self.iou_token.weight,
+                     self.mask_tokens.weight,
+                 ],
+                 dim=0,
+             )
+             s = 1
+         else:
+             output_tokens = torch.cat(
+                 [self.iou_token.weight, self.mask_tokens.weight], dim=0
+             )
+         output_tokens = output_tokens.unsqueeze(0).expand(
+             sparse_prompt_embeddings.size(0), -1, -1
+         )
+         tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+
+         # Expand per-image data in batch direction to be per-mask
+         if repeat_image:
+             src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+         else:
+             assert image_embeddings.shape[0] == tokens.shape[0]
+             src = image_embeddings
+         src = src + dense_prompt_embeddings
+         assert (
+             image_pe.size(0) == 1
+         ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
+         pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+         b, c, h, w = src.shape
+
+         # Run the transformer
+         hs, src = self.transformer(src, pos_src, tokens)
+         iou_token_out = hs[:, s, :]
+         mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
+
+         # Upscale mask embeddings and predict masks using the mask tokens
+         src = src.transpose(1, 2).view(b, c, h, w)
+         if not self.use_high_res_features:
+             upscaled_embedding = self.output_upscaling(src)
+         else:
+             dc1, ln1, act1, dc2, act2 = self.output_upscaling
+             feat_s0, feat_s1 = high_res_features
+             upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
+             upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
+
+         hyper_in_list: List[torch.Tensor] = []
+         for i in range(self.num_mask_tokens):
+             hyper_in_list.append(
+                 self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
+             )
+         hyper_in = torch.stack(hyper_in_list, dim=1)
+         b, c, h, w = upscaled_embedding.shape
+         masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
+
+         # Generate mask quality predictions
+         iou_pred = self.iou_prediction_head(iou_token_out)
+         if self.pred_obj_scores:
+             assert s == 1
+             object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
+         else:
+             # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
+             object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
+
+         return masks, iou_pred, mask_tokens_out, object_score_logits
+
+     def _get_stability_scores(self, mask_logits):
+         """
+         Compute stability scores of the mask logits based on the IoU between upper and
+         lower thresholds, similar to https://github.com/fairinternal/onevision/pull/568.
+         """
+         mask_logits = mask_logits.flatten(-2)
+         stability_delta = self.dynamic_multimask_stability_delta
+         area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
+         area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
+         stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
+         return stability_scores
+
+     def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
+         """
+         When outputting a single mask, if the stability score from the current single-mask
+         output (based on output token 0) falls below a threshold, we instead select from
+         multi-mask outputs (based on output token 1~3) the mask with the highest predicted
+         IoU score. This is intended to ensure a valid mask for both clicking and tracking.
+         """
+         # The best mask from multimask output tokens (1~3)
+         multimask_logits = all_mask_logits[:, 1:, :, :]
+         multimask_iou_scores = all_iou_scores[:, 1:]
+         best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
+         batch_inds = torch.arange(
+             multimask_iou_scores.size(0), device=all_iou_scores.device
+         )
+         best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
+         best_multimask_logits = best_multimask_logits.unsqueeze(1)
+         best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
+         best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
+
+         # The mask from singlemask output token 0 and its stability score
+         singlemask_logits = all_mask_logits[:, 0:1, :, :]
+         singlemask_iou_scores = all_iou_scores[:, 0:1]
+         stability_scores = self._get_stability_scores(singlemask_logits)
+         is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
+
+         # Dynamically fall back to best multimask output upon low stability scores.
+         mask_logits_out = torch.where(
+             is_stable[..., None, None].expand_as(singlemask_logits),
+             singlemask_logits,
+             best_multimask_logits,
+         )
+         iou_scores_out = torch.where(
+             is_stable.expand_as(singlemask_iou_scores),
+             singlemask_iou_scores,
+             best_multimask_iou_scores,
+         )
+         return mask_logits_out, iou_scores_out
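To make the dynamic fallback above concrete, the standalone sketch below (not part of this commit) mirrors _get_stability_scores on a made-up 2x2 logit map: the score is the IoU between the mask thresholded at +delta and at -delta, and single-mask outputs whose score falls below dynamic_multimask_stability_thresh are replaced by the best multimask candidate.

import torch


def stability_score(mask_logits: torch.Tensor, delta: float = 0.05) -> torch.Tensor:
    """IoU between the +delta-thresholded and -delta-thresholded masks."""
    mask_logits = mask_logits.flatten(-2)                     # (..., H*W)
    area_i = torch.sum(mask_logits > delta, dim=-1).float()   # tighter mask
    area_u = torch.sum(mask_logits > -delta, dim=-1).float()  # looser mask
    return torch.where(area_u > 0, area_i / area_u, 1.0)


logits = torch.tensor([[[0.20, 0.04], [-0.04, -3.00]]])  # one 2x2 mask, illustrative values
# area_i = 1 (only 0.20 > 0.05); area_u = 3 (0.20, 0.04, -0.04 > -0.05) -> score = 1/3
print(stability_score(logits))  # tensor([0.3333])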