nebula committed
Commit 078145b
1 Parent(s): 7b4e230
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ SAM_counting_anything__ArXiv_.pdf filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 Vision-Intelligence-and-Robots-Group
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,54 @@
- ---
- title: Counting Anything
- emoji: 🐠
- colorFrom: gray
- colorTo: indigo
- sdk: gradio
- sdk_version: 3.27.0
- app_file: app.py
- pinned: false
- license: apache-2.0
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # count-anything
+ An empirical study on few-shot counting using Segment Anything (SAM)
+
+ Meta AI recently released the Segment Anything Model [[SAM]](https://github.com/facebookresearch/segment-anything), which has garnered attention for its impressive performance in class-agnostic segmentation. In this study, we explore the use of SAM for the challenging task of few-shot object counting, which involves counting objects of an unseen category given only a few example bounding boxes. We compare SAM's performance with other few-shot counting methods and find that it is currently unsatisfactory without further fine-tuning, particularly for small and crowded objects.
+
+ ![image](example.png)
+ ## Install
+ Install the Python dependencies. We use conda with Python 3.10.4 and PyTorch 1.13.1:
+ > conda env create -f env.yaml
+
+ ## Dataset preparation
+ - For FSC-147:
+ Images can be downloaded from here: https://drive.google.com/file/d/1ymDYrGs9DSRicfZbSCDiOu0ikGDh5k6S/view?usp=sharing
+
+ - For COCO val2017:
+ Images can be downloaded from here: https://cocodataset.org/
+ ## Comparison Results
+
+ ### FSC
+
+ ![image](resultFSC.png)
+
+ ### COCO
+
+ ![image](resultcoco.png)
+ ## Test
+ Download the [ViT-H SAM model](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth)
+
+ - For FSC-147:
+ ```
+ python test_FSC.py --data_path <FSC-147 dataset path> --model_path <path to ViT-H SAM model>
+ ```
+
+ - For COCO val2017:
+ ```
+ python test_coco.py --data_path <coco val2017 dataset path> --model_path <path to ViT-H SAM model>
+ ```
+
+ ## Visualize
+ You can run [vis_FSC.ipynb](vis_FSC.ipynb) for FSC-147 or [vis_coco.ipynb](vis_coco.ipynb) for COCO.
+
+ ## Acknowledgement
+ We thank facebookresearch for their Segment Anything model [[project]](https://github.com/facebookresearch/segment-anything), cvlab-stonybrook for their Learning To Count Everything [[project]](https://github.com/cvlab-stonybrook/LearningToCountEverything), and the COCO [[dataset]](https://cocodataset.org/).
+
+ ## Citation
+ If you find the code useful, please cite:
+ ```
+ @article{ma2023countanything,
+   title={CAN SAM COUNT ANYTHING? AN EMPIRICAL STUDY ON SAM COUNTING},
+   author={Ma, Zhiheng and Hong, Xiaopeng and Shangguan, Qinnan},
+   journal={arXiv preprint arXiv:2304.xxxxx},
+   year={2023}
+ }
+ ```
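Editor's note: for readers skimming the diff, here is a minimal sketch of how the pieces added in this commit fit together for a single image. It mirrors `app.py` and `test_FSC.py` above; the checkpoint path, image path, and example boxes are placeholders, not values taken from the repository.

```python
# Minimal counting sketch, assuming the repository root is the working directory
# and the ViT-H SAM checkpoint has been downloaded locally.
import cv2
from segment_anything import sam_model_registry
from automatic_mask_generator import SamAutomaticMaskGenerator

sam = sam_model_registry["vit_h"](checkpoint="./sam_vit_h_4b8939.pth")  # assumed local path
sam.to(device="cuda")
mask_generator = SamAutomaticMaskGenerator(model=sam, min_mask_region_area=25)

# HWC uint8 RGB image, as expected by SamAutomaticMaskGenerator.generate
image = cv2.cvtColor(cv2.imread("some_image.jpg"), cv2.COLOR_BGR2RGB)  # placeholder image
ref_boxes = [[10, 20, 60, 90], [100, 40, 150, 110]]  # a few example boxes in xyxy format (placeholders)

masks = mask_generator.generate(image, ref_boxes)  # one record per counted instance
print("Predicted count:", len(masks))
```

Each record returned by `generate` carries the mask, its `bbox` in XYWH format, and quality scores, so the predicted count is simply the number of records.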
SAM_counting_anything__ArXiv_.pdf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:49bcb63df39d9ad072ae11393355e29757169608fa4540388698cf23a2c7110a
+ size 6071349
app.py ADDED
@@ -0,0 +1,88 @@
+ from PIL import Image, ImageDraw
+ import cv2
+ import gradio as gr
+ import torch
+ from segment_anything import sam_model_registry
+ from automatic_mask_generator import SamAutomaticMaskGenerator
+
+ device = 'cuda'
+ sam = sam_model_registry['vit_h'](checkpoint='./sam_vit_h_4b8939.pth')
+ sam.to(device=device)
+
+
+ mask_generator = SamAutomaticMaskGenerator(
+     model=sam,
+     min_mask_region_area=25
+ )
+
+ def binarize(x):
+     return (x != 0).astype('uint8') * 255
+
+ def draw_box(boxes=[], img=None):
+     if len(boxes) == 0 and img is None:
+         return None
+
+     if img is None:
+         img = Image.new('RGB', (512, 512), (255, 255, 255))
+     colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
+     draw = ImageDraw.Draw(img)
+     # print(boxes)
+     for bid, box in enumerate(boxes):
+         draw.rectangle([box[0], box[1], box[2], box[3]], outline=colors[bid % len(colors)], width=4)
+     return img
+
+
+ def draw_pred_box(boxes=[], img=None):
+     if len(boxes) == 0 and img is None:
+         return None
+
+     if img is None:
+         img = Image.new('RGB', (512, 512), (255, 255, 255))
+     colors = "green"
+     draw = ImageDraw.Draw(img)
+     # print(boxes)
+     for bid, box in enumerate(boxes):
+         draw.rectangle([box[0], box[1], box[2], box[3]], outline=colors, width=4)
+     return img
+
+
+ def debug(input_img):
+     mask = input_img["mask"]
+     mask = mask[..., 0]
+     contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+     boxes = []
+     for contour in contours:
+         y1, y2 = contour[:, 0, 1].min(), contour[:, 0, 1].max()
+         x1, x2 = contour[:, 0, 0].min(), contour[:, 0, 0].max()
+         boxes.append([x1, y1, x2, y2])
+     draw_image = draw_box(boxes, Image.fromarray(input_img["image"]))
+
+     masks = mask_generator.generate(input_img["image"], boxes)
+     pred_cnt = len(masks)
+     pred_bboxes = []
+     for i in masks:
+         x0, y0, w, h = i['bbox']
+         pred_bboxes.append([x0, y0, x0+w, y0+h])
+     pred_image = draw_pred_box(pred_bboxes, Image.fromarray(input_img["image"]))
+     return [draw_image, pred_image, "Count: {}".format(pred_cnt)]
+
+ description = """<p style="text-align: center; font-weight: bold;">
+ <span style="font-size: 28px">Count Anything</span>
+ <br>
+ <span style="font-size: 18px" id="paper-info">
+ [<a href=" " target="_blank">Project Page</a>]
+ [<a href=" " target="_blank">Paper</a>]
+ [<a href="https://github.com/Vision-Intelligence-and-Robots-Group/count-anything" target="_blank">GitHub</a>]
+ </span>
+ </p>
+ """
+
+ run = gr.Interface(
+     debug,
+     gr.Image(shape=[512, 512], source="upload", tool="sketch").style(height=500, width=500),
+     [gr.Image(), gr.Image(), gr.Text()],
+     description = description
+ )
+
+ run.launch()
automatic_mask_generator.py ADDED
@@ -0,0 +1,496 @@
1
+ import numpy as np
2
+ import torch
3
+ from torchvision.ops.boxes import batched_nms, box_area # type: ignore
4
+
5
+ from typing import Any, Dict, List, Optional, Tuple
6
+ import torch.nn.functional as F
7
+ from collections import defaultdict
8
+
9
+ from segment_anything.modeling import Sam
10
+ from segment_anything.predictor import SamPredictor
11
+ from segment_anything.utils.amg import (
12
+ MaskData,
13
+ area_from_rle,
14
+ batch_iterator,
15
+ batched_mask_to_box,
16
+ box_xyxy_to_xywh,
17
+ build_all_layer_point_grids,
18
+ calculate_stability_score,
19
+ coco_encode_rle,
20
+ generate_crop_boxes,
21
+ is_box_near_crop_edge,
22
+ mask_to_rle_pytorch,
23
+ remove_small_regions,
24
+ rle_to_mask,
25
+ uncrop_boxes_xyxy,
26
+ uncrop_masks,
27
+ uncrop_points,
28
+ )
29
+
30
+
31
+ def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
32
+ x0, y0, _, _ = crop_box
33
+ offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
34
+ # Check if boxes has a channel dimension
35
+ if len(boxes.shape) == 3:
36
+ offset = offset.unsqueeze(1)
37
+ return boxes + offset
38
+
39
+ def pre_process_ref_box(ref_box, crop_box, layer_idx):
40
+ if layer_idx == 0:
41
+ return ref_box
42
+ else:
43
+ new_bbox = []
44
+ x0, y0, x1, y1 = crop_box
45
+ for ref in ref_box:
46
+ x0_r, y0_r, x1_r, y1_r = ref
47
+ area = (y1_r - y0_r) * (x1_r - x0_r)
48
+ x_0_new = max(x0, x0_r)
49
+ y_0_new = max(y0, y0_r)
50
+ x_1_new = min(x1, x1_r)
51
+ y_1_new = min(y1, y1_r)
52
+ crop_area = (y_1_new - y_0_new) * (x_1_new - x_0_new)
53
+ if crop_area / area > 0.7:
54
+ new_bbox.append([x_0_new, y_0_new, x_1_new, y_1_new])
55
+
56
+ return new_bbox
57
+
58
+
59
+
60
+
61
+ class SamAutomaticMaskGenerator:
62
+ def __init__(
63
+ self,
64
+ model: Sam,
65
+ points_per_side: Optional[int] = 32,
66
+ points_per_batch: int = 64,
67
+ pred_iou_thresh: float = 0.88,
68
+ stability_score_thresh: float = 0.95,
69
+ stability_score_offset: float = 1.0,
70
+ box_nms_thresh: float = 0.7,
71
+ crop_n_layers: int = 0,
72
+ crop_nms_thresh: float = 0.7,
73
+ crop_overlap_ratio: float = 512 / 1500,
74
+ crop_n_points_downscale_factor: int = 1,
75
+ point_grids: Optional[List[np.ndarray]] = None,
76
+ min_mask_region_area: int = 0,
77
+ output_mode: str = "binary_mask",
78
+ ) -> None:
79
+ """
80
+ Using a SAM model, generates masks for the entire image.
81
+ Generates a grid of point prompts over the image, then filters
82
+ low quality and duplicate masks. The default settings are chosen
83
+ for SAM with a ViT-H backbone.
84
+
85
+ Arguments:
86
+ model (Sam): The SAM model to use for mask prediction.
87
+ points_per_side (int or None): The number of points to be sampled
88
+ along one side of the image. The total number of points is
89
+ points_per_side**2. If None, 'point_grids' must provide explicit
90
+ point sampling.
91
+ points_per_batch (int): Sets the number of points run simultaneously
92
+ by the model. Higher numbers may be faster but use more GPU memory.
93
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
94
+ model's predicted mask quality.
95
+ stability_score_thresh (float): A filtering threshold in [0,1], using
96
+ the stability of the mask under changes to the cutoff used to binarize
97
+ the model's mask predictions.
98
+ stability_score_offset (float): The amount to shift the cutoff when
99
+ calculating the stability score.
100
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
101
+ suppression to filter duplicate masks.
102
+ crop_n_layers (int): If >0, mask prediction will be run again on
103
+ crops of the image. Sets the number of layers to run, where each
104
+ layer has 2**i_layer number of image crops.
105
+ crop_nms_thresh (float): The box IoU cutoff used by non-maximal
106
+ suppression to filter duplicate masks between different crops.
107
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
108
+ In the first crop layer, crops will overlap by this fraction of
109
+ the image length. Later layers with more crops scale down this overlap.
110
+ crop_n_points_downscale_factor (int): The number of points-per-side
111
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
112
+ point_grids (list(np.ndarray) or None): A list over explicit grids
113
+ of points used for sampling, normalized to [0,1]. The nth grid in the
114
+ list is used in the nth crop layer. Exclusive with points_per_side.
115
+ min_mask_region_area (int): If >0, postprocessing will be applied
116
+ to remove disconnected regions and holes in masks with area smaller
117
+ than min_mask_region_area. Requires opencv.
118
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
119
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
120
+ For large resolutions, 'binary_mask' may consume large amounts of
121
+ memory.
122
+ """
123
+
124
+ assert (points_per_side is None) != (
125
+ point_grids is None
126
+ ), "Exactly one of points_per_side or point_grid must be provided."
127
+ if points_per_side is not None:
128
+ self.point_grids = build_all_layer_point_grids(
129
+ points_per_side,
130
+ crop_n_layers,
131
+ crop_n_points_downscale_factor,
132
+ )
133
+ elif point_grids is not None:
134
+ self.point_grids = point_grids
135
+ else:
136
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
137
+
138
+ assert output_mode in [
139
+ "binary_mask",
140
+ "uncompressed_rle",
141
+ "coco_rle",
142
+ ], f"Unknown output_mode {output_mode}."
143
+ if output_mode == "coco_rle":
144
+ from pycocotools import mask as mask_utils # type: ignore # noqa: F401
145
+
146
+ if min_mask_region_area > 0:
147
+ import cv2 # type: ignore # noqa: F401
148
+
149
+ self.predictor = SamPredictor(model)
150
+ self.points_per_batch = points_per_batch
151
+ self.pred_iou_thresh = pred_iou_thresh
152
+ self.stability_score_thresh = stability_score_thresh
153
+ self.stability_score_offset = stability_score_offset
154
+ self.box_nms_thresh = box_nms_thresh
155
+ self.crop_n_layers = crop_n_layers
156
+ self.crop_nms_thresh = crop_nms_thresh
157
+ self.crop_overlap_ratio = crop_overlap_ratio
158
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
159
+ self.min_mask_region_area = min_mask_region_area
160
+ self.output_mode = output_mode
161
+
162
+ self.prototype = defaultdict(list)
163
+
164
+ @torch.no_grad()
165
+ def generate(self, image: np.ndarray, ref_bbox) -> List[Dict[str, Any]]:
166
+ """
167
+ Generates masks for the given image.
168
+
169
+ Arguments:
170
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
171
+
172
+
173
+ Returns:
174
+ list(dict(str, any)): A list over records for masks. Each record is
175
+ a dict containing the following keys:
176
+ segmentation (dict(str, any) or np.ndarray): The mask. If
177
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
178
+ is a dictionary containing the RLE.
179
+ bbox (list(float)): The box around the mask, in XYWH format.
180
+ area (int): The area in pixels of the mask.
181
+ predicted_iou (float): The model's own prediction of the mask's
182
+ quality. This is filtered by the pred_iou_thresh parameter.
183
+ point_coords (list(list(float))): The point coordinates input
184
+ to the model to generate this mask.
185
+ stability_score (float): A measure of the mask's quality. This
186
+ is filtered on using the stability_score_thresh parameter.
187
+ crop_box (list(float)): The crop of the image used to generate
188
+ the mask, given in XYWH format.
189
+ """
190
+
191
+ # Generate masks
192
+ mask_data = self._generate_masks(image, ref_bbox)
193
+
194
+ # Filter small disconnected regions and holes in masks
195
+ if self.min_mask_region_area > 0:
196
+ mask_data = self.postprocess_small_regions(
197
+ mask_data,
198
+ self.min_mask_region_area,
199
+ max(self.box_nms_thresh, self.crop_nms_thresh),
200
+ )
201
+
202
+ # Encode masks
203
+ if self.output_mode == "coco_rle":
204
+ mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
205
+ elif self.output_mode == "binary_mask":
206
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
207
+ else:
208
+ mask_data["segmentations"] = mask_data["rles"]
209
+
210
+ # Write mask records
211
+ curr_anns = []
212
+ for idx in range(len(mask_data["segmentations"])):
213
+ ann = {
214
+ "segmentation": mask_data["segmentations"][idx],
215
+ "area": area_from_rle(mask_data["rles"][idx]),
216
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
217
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
218
+ "point_coords": [mask_data["points"][idx].tolist()],
219
+ "stability_score": mask_data["stability_score"][idx].item(),
220
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
221
+ }
222
+ curr_anns.append(ann)
223
+
224
+ return curr_anns
225
+
226
+ def _generate_masks(self, image: np.ndarray, ref_box) -> MaskData:
227
+ orig_size = image.shape[:2]
228
+ crop_boxes, layer_idxs = generate_crop_boxes(
229
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
230
+ )
231
+
232
+ # Iterate over image crops
233
+ # data = MaskData()
234
+ data_dic = defaultdict(MaskData)
235
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
236
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size, ref_box)
237
+ data_dic[layer_idx].cat(crop_data)
238
+
239
+ data = MaskData()
240
+ for layer_idx in data_dic.keys():
241
+ proto_fea = torch.concat(self.prototype[layer_idx], dim=0)
242
+ if len(proto_fea) > 1:
243
+ cos_dis = proto_fea @ proto_fea.t()
244
+ sim_thresh = torch.min(cos_dis)
245
+ else:
246
+ sim_thresh = 0.7
247
+ sub_data = data_dic[layer_idx]
248
+ fea = sub_data['fea']
249
+ cos_dis = torch.max(fea @ proto_fea.t(), dim=1)[0]
250
+ sub_data.filter(cos_dis>=sim_thresh)
251
+ data.cat(sub_data)
252
+
253
+ self.prototype = defaultdict(list)
254
+
255
+
256
+ # Remove duplicate masks between crops
257
+ if len(crop_boxes) > 1:
258
+ # Prefer masks from smaller crops
259
+ scores = 1 / box_area(data["crop_boxes"])
260
+ scores = scores.to(data["boxes"].device)
261
+ keep_by_nms = batched_nms(
262
+ data["boxes"].float(),
263
+ scores,
264
+ torch.zeros(len(data["boxes"])), # categories
265
+ iou_threshold=self.crop_nms_thresh,
266
+ )
267
+ data.filter(keep_by_nms)
268
+
269
+ data.to_numpy()
270
+ return data
271
+
272
+ def _process_crop(
273
+ self,
274
+ image: np.ndarray,
275
+ crop_box: List[int],
276
+ crop_layer_idx: int,
277
+ orig_size: Tuple[int, ...],
278
+ ref_box
279
+ ) -> MaskData:
280
+ # Crop the image and calculate embeddings
281
+ x0, y0, x1, y1 = crop_box
282
+ cropped_im = image[y0:y1, x0:x1, :]
283
+ cropped_im_size = cropped_im.shape[:2]
284
+ self.predictor.set_image(cropped_im)
285
+
286
+ ref_box = pre_process_ref_box(ref_box, crop_box, crop_layer_idx)
287
+ if len(ref_box) > 0:
288
+ ref_box = torch.tensor(ref_box, device=self.predictor.device)
289
+ transformed_boxes = self.predictor.transform.apply_boxes_torch(ref_box, cropped_im_size)
290
+ masks, iou_preds, low_res_masks = self.predictor.predict_torch(
291
+ point_coords=None,
292
+ point_labels=None,
293
+ boxes=transformed_boxes,
294
+ multimask_output=False
295
+ )
296
+ feature = self.predictor.get_image_embedding()
297
+
298
+ low_res_masks = F.interpolate(low_res_masks, size=feature.shape[-2:], mode='bilinear', align_corners=False)
299
+
300
+ feature = feature.flatten(2, 3)
301
+ low_res_masks = low_res_masks.flatten(2, 3)
302
+ masks_low_res = (low_res_masks > self.predictor.model.mask_threshold).float()
303
+ topk_idx = torch.topk(low_res_masks, 1)[1]
304
+ masks_low_res.scatter_(2, topk_idx, 1.0)
305
+
306
+
307
+ prototype_fea = (feature * masks_low_res).sum(dim=2) / masks_low_res.sum(dim=2)
308
+ prototype_fea = F.normalize(prototype_fea, dim=1)
309
+ self.prototype[crop_layer_idx].append(prototype_fea)
310
+
311
+
312
+ if crop_layer_idx == 0:  # add reference grounding
313
+ x = ref_box[:, 0] + (ref_box[:, 2] - ref_box[:, 0]) / 2
314
+ y = ref_box[:, 1] + (ref_box[:, 3] - ref_box[:, 1]) / 2
315
+ points = torch.stack([x, y], dim=1)
316
+ data = MaskData(
317
+ masks=masks.flatten(0, 1),
318
+ iou_preds= torch.ones_like(iou_preds.flatten(0, 1)),
319
+ fea = prototype_fea,
320
+ points=points.cpu(),
321
+ stability_score = torch.ones_like(iou_preds.flatten(0, 1)),
322
+ )
323
+ data["boxes"] = batched_mask_to_box(data["masks"])
324
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
325
+ del data["masks"]
326
+ else:
327
+ data = MaskData()
328
+
329
+
330
+
331
+ # Get points for this crop
332
+ points_scale = np.array(cropped_im_size)[None, ::-1]
333
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
334
+
335
+ # Generate masks for this crop in batches
336
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
337
+ batch_data = self._process_batch(points, cropped_im_size,
338
+ crop_box, orig_size)
339
+ data.cat(batch_data)
340
+ del batch_data
341
+ self.predictor.reset_image()
342
+
343
+ # Remove duplicates within this crop.
344
+ keep_by_nms = batched_nms(
345
+ data["boxes"].float(),
346
+ data["iou_preds"],
347
+ torch.zeros(len(data["boxes"])), # categories
348
+ iou_threshold=self.box_nms_thresh,
349
+ )
350
+ data.filter(keep_by_nms)
351
+
352
+ # Return to the original image frame
353
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
354
+ data["points"] = uncrop_points(data["points"], crop_box)
355
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
356
+
357
+ return data
358
+
359
+ def _process_batch(
360
+ self,
361
+ points: np.ndarray,
362
+ im_size: Tuple[int, ...],
363
+ crop_box: List[int],
364
+ orig_size: Tuple[int, ...],
365
+ ) -> MaskData:
366
+ orig_h, orig_w = orig_size
367
+
368
+ # Run model on this batch
369
+ transformed_points = self.predictor.transform.apply_coords(points, im_size)
370
+ in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
371
+ in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
372
+ masks, iou_preds, low_res_masks = self.predictor.predict_torch(
373
+ in_points[:, None, :],
374
+ in_labels[:, None],
375
+ multimask_output=True,
376
+ return_logits=True,
377
+ )
378
+
379
+ feature = self.predictor.get_image_embedding()
380
+ low_res_masks=low_res_masks.flatten(0, 1)
381
+ low_res_masks = F.interpolate(low_res_masks[:, None, :, :], size=feature.shape[-2:],
382
+ mode='bilinear', align_corners=False)
383
+ # low_res_masks = low_res_masks > self.predictor.model.mask_threshold
384
+
385
+ # fea = feature.flatten(2, 3)
386
+ # low_res_masks = low_res_masks.flatten(2, 3)
387
+ # topk_idx = torch.topk(low_res_masks, 4)[1]
388
+ # fea = fea.expand(topk_idx.shape[0], -1, -1)
389
+ # topk_idx = topk_idx.expand(-1, fea.shape[1], -1)
390
+ # fea = fea.gather(2, topk_idx)
391
+
392
+
393
+ feature = feature.flatten(2, 3)
394
+ low_res_masks = low_res_masks.flatten(2, 3)
395
+ masks_low_res = (low_res_masks > self.predictor.model.mask_threshold).float()
396
+ topk_idx = torch.topk(low_res_masks, 1)[1]
397
+ masks_low_res.scatter_(2, topk_idx, 1.0)
398
+ pool_fea = (feature * masks_low_res).sum(dim=2) / masks_low_res.sum(dim=2)
399
+ pool_fea = F.normalize(pool_fea, dim=1)
400
+
401
+ # k_val = torch.topk(torch.flatten(low_res_masks, start_dim=2, end_dim=3), k=4, dim=-1)[0][:, :, -1, None]
402
+ # low_res_masks = (low_res_masks >= k_val).float()
403
+ # low_res_masks = low_res_masks.float()
404
+ # pool_fea = (feature * low_res_masks).sum(dim=(2, 3)) / low_res_masks.sum(dim=(2, 3))
405
+
406
+
407
+
408
+ # Serialize predictions and store in MaskData
409
+ data = MaskData(
410
+ masks=masks.flatten(0, 1),
411
+ iou_preds=iou_preds.flatten(0, 1),
412
+ points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
413
+ fea = pool_fea,
414
+ )
415
+ del masks
416
+
417
+
418
+ # Filter by predicted IoU
419
+ if self.pred_iou_thresh > 0.0:
420
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
421
+ data.filter(keep_mask)
422
+
423
+ # Calculate stability score
424
+ data["stability_score"] = calculate_stability_score(
425
+ data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
426
+ )
427
+ if self.stability_score_thresh > 0.0:
428
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
429
+ data.filter(keep_mask)
430
+
431
+ # Threshold masks and calculate boxes
432
+ data["masks"] = data["masks"] > self.predictor.model.mask_threshold
433
+ data["boxes"] = batched_mask_to_box(data["masks"])
434
+
435
+ # Filter boxes that touch crop boundaries
436
+ keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
437
+ if not torch.all(keep_mask):
438
+ data.filter(keep_mask)
439
+
440
+ # Compress to RLE
441
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
442
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
443
+ del data["masks"]
444
+
445
+ return data
446
+
447
+ @staticmethod
448
+ def postprocess_small_regions(
449
+ mask_data: MaskData, min_area: int, nms_thresh: float
450
+ ) -> MaskData:
451
+ """
452
+ Removes small disconnected regions and holes in masks, then reruns
453
+ box NMS to remove any new duplicates.
454
+
455
+ Edits mask_data in place.
456
+
457
+ Requires open-cv as a dependency.
458
+ """
459
+ if len(mask_data["rles"]) == 0:
460
+ return mask_data
461
+
462
+ # Filter small disconnected regions and holes
463
+ new_masks = []
464
+ scores = []
465
+ for rle in mask_data["rles"]:
466
+ mask = rle_to_mask(rle)
467
+
468
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
469
+ unchanged = not changed
470
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
471
+ unchanged = unchanged and not changed
472
+
473
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
474
+ # Give score=0 to changed masks and score=1 to unchanged masks
475
+ # so NMS will prefer ones that didn't need postprocessing
476
+ scores.append(float(unchanged))
477
+
478
+ # Recalculate boxes and remove any new duplicates
479
+ masks = torch.cat(new_masks, dim=0)
480
+ boxes = batched_mask_to_box(masks)
481
+ keep_by_nms = batched_nms(
482
+ boxes.float(),
483
+ torch.as_tensor(scores),
484
+ torch.zeros(len(boxes)), # categories
485
+ iou_threshold=nms_thresh,
486
+ )
487
+
488
+ # Only recalculate RLEs for masks that have changed
489
+ for i_mask in keep_by_nms:
490
+ if scores[i_mask] == 0.0:
491
+ mask_torch = masks[i_mask].unsqueeze(0)
492
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
493
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
494
+ mask_data.filter(keep_by_nms)
495
+
496
+ return mask_data
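Editor's note: the main change relative to the upstream SAM generator is the prototype filter in `_generate_masks` / `_process_crop`: the reference boxes are mask-pooled into L2-normalized feature prototypes, and candidate masks are kept only if their pooled feature is cosine-similar enough to one of those prototypes. The following is a standalone sketch of just that filtering step; the threshold logic (minimum pairwise prototype similarity, with a 0.7 fallback for a single prototype) mirrors the code above, while the helper name and the random tensors are placeholders for illustration.

```python
import torch
import torch.nn.functional as F

def filter_by_prototype(cand_fea: torch.Tensor, proto_fea: torch.Tensor) -> torch.Tensor:
    """Return a boolean keep-mask over candidate mask features.

    cand_fea:  (N, C) L2-normalized mask-pooled features of candidate masks
    proto_fea: (K, C) L2-normalized features pooled from the reference boxes
    """
    if len(proto_fea) > 1:
        # Threshold = smallest pairwise similarity among the prototypes themselves
        sim_thresh = torch.min(proto_fea @ proto_fea.t())
    else:
        sim_thresh = 0.7  # fallback when only one reference prototype is available
    # Score each candidate by its best-matching prototype
    cos_sim = torch.max(cand_fea @ proto_fea.t(), dim=1)[0]
    return cos_sim >= sim_thresh

# Placeholder tensors standing in for SAM image-embedding features (C = 256)
cand = F.normalize(torch.randn(100, 256), dim=1)
proto = F.normalize(torch.randn(3, 256), dim=1)
keep = filter_by_prototype(cand, proto)
print("kept", int(keep.sum()), "of", len(cand), "candidate masks")
```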
env.yaml ADDED
@@ -0,0 +1,177 @@
1
+ name: ltce
2
+ channels:
3
+ - conda-forge
4
+ - defaults
5
+ dependencies:
6
+ - _libgcc_mutex=0.1=main
7
+ - _openmp_mutex=5.1=1_gnu
8
+ - asttokens=2.2.1=pyhd8ed1ab_0
9
+ - backcall=0.2.0=pyh9f0ad1d_0
10
+ - backports=1.0=pyhd8ed1ab_3
11
+ - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
12
+ - blas=1.0=openblas
13
+ - brotli=1.0.9=h5eee18b_7
14
+ - brotli-bin=1.0.9=h5eee18b_7
15
+ - bzip2=1.0.8=h7b6447c_0
16
+ - ca-certificates=2023.01.10=h06a4308_0
17
+ - cairo=1.16.0=hb05425b_4
18
+ - certifi=2022.12.7=py310h06a4308_0
19
+ - contourpy=1.0.5=py310hdb19cb5_0
20
+ - cycler=0.11.0=pyhd3eb1b0_0
21
+ - dbus=1.13.18=hb2f20db_0
22
+ - debugpy=1.5.1=py310h295c915_0
23
+ - decorator=5.1.1=pyhd8ed1ab_0
24
+ - eigen=3.4.0=h4bd325d_0
25
+ - entrypoints=0.4=pyhd8ed1ab_0
26
+ - executing=1.2.0=pyhd8ed1ab_0
27
+ - expat=2.2.10=h9c3ff4c_0
28
+ - ffmpeg=4.2.2=h20bf706_0
29
+ - fontconfig=2.14.1=hef1e5e3_0
30
+ - fonttools=4.25.0=pyhd3eb1b0_0
31
+ - freetype=2.10.4=h0708190_1
32
+ - giflib=5.2.1=h36c2ea0_2
33
+ - glib=2.69.1=h4ff587b_1
34
+ - gmp=6.2.1=h58526e2_0
35
+ - gnutls=3.6.13=h85f3911_1
36
+ - graphite2=1.3.14=h295c915_1
37
+ - gst-plugins-base=1.14.1=h6a678d5_1
38
+ - gstreamer=1.14.1=h5eee18b_1
39
+ - harfbuzz=4.3.0=hf52aaf7_1
40
+ - hdf5=1.10.6=h3ffc7dd_1
41
+ - icu=58.2=hf484d3e_1000
42
+ - ipykernel=6.15.0=pyh210e3f2_0
43
+ - ipython=8.12.0=pyh41d4057_0
44
+ - jedi=0.18.2=pyhd8ed1ab_0
45
+ - jpeg=9e=h166bdaf_1
46
+ - jupyter_client=7.3.4=pyhd8ed1ab_0
47
+ - jupyter_core=5.3.0=py310hff52083_0
48
+ - keyutils=1.6.1=h166bdaf_0
49
+ - kiwisolver=1.4.4=py310h6a678d5_0
50
+ - krb5=1.19.3=h3790be6_0
51
+ - lame=3.100=h7f98852_1001
52
+ - lcms2=2.12=h3be6417_0
53
+ - ld_impl_linux-64=2.38=h1181459_1
54
+ - lerc=3.0=h295c915_0
55
+ - libblas=3.9.0=15_linux64_openblas
56
+ - libbrotlicommon=1.0.9=h5eee18b_7
57
+ - libbrotlidec=1.0.9=h5eee18b_7
58
+ - libbrotlienc=1.0.9=h5eee18b_7
59
+ - libcblas=3.9.0=15_linux64_openblas
60
+ - libclang=10.0.1=default_hb85057a_2
61
+ - libdeflate=1.17=h5eee18b_0
62
+ - libedit=3.1.20191231=he28a2e2_2
63
+ - libevent=2.1.12=h8f2d780_0
64
+ - libffi=3.3=he6710b0_2
65
+ - libgcc-ng=11.2.0=h1234567_1
66
+ - libgfortran-ng=12.2.0=h69a702a_19
67
+ - libgfortran5=12.2.0=h337968e_19
68
+ - libgomp=11.2.0=h1234567_1
69
+ - liblapack=3.9.0=15_linux64_openblas
70
+ - libllvm10=10.0.1=he513fc3_3
71
+ - libopenblas=0.3.20=pthreads_h78a6416_0
72
+ - libopus=1.3.1=h7f98852_1
73
+ - libpng=1.6.39=h5eee18b_0
74
+ - libpq=12.9=h16c4e8d_3
75
+ - libprotobuf=3.20.3=he621ea3_0
76
+ - libsodium=1.0.18=h36c2ea0_1
77
+ - libstdcxx-ng=11.2.0=h1234567_1
78
+ - libtiff=4.5.0=h6a678d5_2
79
+ - libuuid=1.41.5=h5eee18b_0
80
+ - libvpx=1.7.0=h439df22_0
81
+ - libwebp=1.2.4=h11a3e52_1
82
+ - libwebp-base=1.2.4=h5eee18b_1
83
+ - libxcb=1.15=h7f8727e_0
84
+ - libxkbcommon=1.0.1=hfa300c1_0
85
+ - libxml2=2.9.14=h74e7548_0
86
+ - libxslt=1.1.35=h4e12654_0
87
+ - lz4-c=1.9.3=h9c3ff4c_1
88
+ - matplotlib=3.7.1=py310h06a4308_1
89
+ - matplotlib-base=3.7.1=py310h1128e8f_1
90
+ - matplotlib-inline=0.1.6=pyhd8ed1ab_0
91
+ - munkres=1.1.4=py_0
92
+ - ncurses=6.4=h6a678d5_0
93
+ - nest-asyncio=1.5.6=pyhd8ed1ab_0
94
+ - nettle=3.6=he412f7d_0
95
+ - nspr=4.33=h295c915_0
96
+ - nss=3.74=h0370c37_0
97
+ - opencv=4.6.0=py310h1128e8f_3
98
+ - openh264=2.1.1=h4ff587b_0
99
+ - openjpeg=2.4.0=h3ad879b_0
100
+ - openssl=1.1.1t=h7f8727e_0
101
+ - packaging=23.1=pyhd8ed1ab_0
102
+ - parso=0.8.3=pyhd8ed1ab_0
103
+ - pcre=8.45=h9c3ff4c_0
104
+ - pexpect=4.8.0=pyh1a96a4e_2
105
+ - pickleshare=0.7.5=py_1003
106
+ - pip=23.0.1=py310h06a4308_0
107
+ - pixman=0.40.0=h36c2ea0_0
108
+ - platformdirs=3.2.0=pyhd8ed1ab_0
109
+ - ply=3.11=py310h06a4308_0
110
+ - prompt-toolkit=3.0.38=pyha770c72_0
111
+ - prompt_toolkit=3.0.38=hd8ed1ab_0
112
+ - psutil=5.9.0=py310h5eee18b_0
113
+ - ptyprocess=0.7.0=pyhd3deb0d_0
114
+ - pure_eval=0.2.2=pyhd8ed1ab_0
115
+ - pygments=2.15.0=pyhd8ed1ab_0
116
+ - pyparsing=3.0.9=py310h06a4308_0
117
+ - pyqt=5.15.7=py310h6a678d5_1
118
+ - python=3.10.4=h12debd9_0
119
+ - python-dateutil=2.8.2=pyhd8ed1ab_0
120
+ - python_abi=3.10=2_cp310
121
+ - pyzmq=23.2.0=py310h6a678d5_0
122
+ - qt-main=5.15.2=h327a75a_7
123
+ - qt-webengine=5.15.9=hd2b0992_4
124
+ - qtwebkit=5.212=h4eab89a_4
125
+ - readline=8.2=h5eee18b_0
126
+ - setuptools=65.6.3=py310h06a4308_0
127
+ - sip=6.6.2=py310h6a678d5_0
128
+ - six=1.16.0=pyh6c4a22f_0
129
+ - sqlite=3.41.2=h5eee18b_0
130
+ - stack_data=0.6.2=pyhd8ed1ab_0
131
+ - tk=8.6.12=h1ccaba5_0
132
+ - toml=0.10.2=pyhd3eb1b0_0
133
+ - tornado=6.1=py310h5764c6d_3
134
+ - tqdm=4.65.0=py310h2f386ee_0
135
+ - traitlets=5.9.0=pyhd8ed1ab_0
136
+ - typing-extensions=4.5.0=hd8ed1ab_0
137
+ - typing_extensions=4.5.0=pyha770c72_0
138
+ - tzdata=2023c=h04d1e81_0
139
+ - wcwidth=0.2.6=pyhd8ed1ab_0
140
+ - wheel=0.38.4=py310h06a4308_0
141
+ - x264=1!157.20191217=h7b6447c_0
142
+ - xz=5.2.10=h5eee18b_1
143
+ - zeromq=4.3.4=h9c3ff4c_1
144
+ - zlib=1.2.13=h5eee18b_0
145
+ - zstd=1.5.2=ha4553b6_0
146
+ - pip:
147
+ - charset-normalizer==3.1.0
148
+ - cmake==3.26.3
149
+ - filelock==3.11.0
150
+ - idna==3.4
151
+ - jinja2==3.1.2
152
+ - lit==16.0.1
153
+ - markupsafe==2.1.2
154
+ - mpmath==1.3.0
155
+ - networkx==3.1
156
+ - numpy==1.24.2
157
+ - nvidia-cublas-cu11==11.10.3.66
158
+ - nvidia-cuda-cupti-cu11==11.7.101
159
+ - nvidia-cuda-nvrtc-cu11==11.7.99
160
+ - nvidia-cuda-runtime-cu11==11.7.99
161
+ - nvidia-cudnn-cu11==8.5.0.96
162
+ - nvidia-cufft-cu11==10.9.0.58
163
+ - nvidia-curand-cu11==10.2.10.91
164
+ - nvidia-cusolver-cu11==11.4.0.1
165
+ - nvidia-cusparse-cu11==11.7.4.91
166
+ - nvidia-nccl-cu11==2.14.3
167
+ - nvidia-nvtx-cu11==11.7.91
168
+ - pillow==9.5.0
169
+ - pyqt5-sip==12.11.0
170
+ - requests==2.28.2
171
+ - segment-anything==1.0
172
+ - sympy==1.11.1
173
+ - torch==2.0.0
174
+ - torchaudio==2.0.1
175
+ - torchvision==0.15.1
176
+ - triton==2.0.0
177
+ - urllib3==1.26.15
example.png ADDED
resultFSC.png ADDED
resultcoco.png ADDED
test_FSC.py ADDED
@@ -0,0 +1,120 @@
1
+ import cv2
2
+ import argparse
3
+ import json
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from os.path import exists
7
+ import os
8
+
9
+ from segment_anything import sam_model_registry
10
+ from automatic_mask_generator import SamAutomaticMaskGenerator
11
+ import matplotlib.pyplot as plt
12
+
13
+
14
+
15
+
16
+ parser = argparse.ArgumentParser(description="Few Shot Counting Evaluation code")
17
+ parser.add_argument("-dp", "--data_path", type=str, default='/data/counte/', help="Path to the FSC147 dataset")
18
+ parser.add_argument("-ts", "--test_split", type=str, default='val', choices=["val_PartA","val_PartB","test_PartA","test_PartB","test", "val"], help="what data split to evaluate on")
19
+ parser.add_argument("-mt", "--model_type", type=str, default="vit_h", help="model type")
20
+ parser.add_argument("-mp", "--model_path", type=str, default="/home/teddy/segment-anything/sam_vit_h_4b8939.pth", help="path to trained model")
21
+ parser.add_argument("-v", "--viz", type=bool, default=True, help="wether to visualize")
22
+ parser.add_argument("-d", "--device", default='0', help='assign device')
23
+ args = parser.parse_args()
24
+
25
+ data_path = args.data_path
26
+ anno_file = data_path + 'annotation_FSC147_384.json'
27
+ data_split_file = data_path + 'Train_Test_Val_FSC_147.json'
28
+ im_dir = data_path + 'images_384_VarV2'
29
+
30
+
31
+ if not exists(anno_file) or not exists(im_dir):
32
+ print("Make sure you set up the --data-path correctly.")
33
+ print("Current setting is {}, but the image dir and annotation file do not exist.".format(args.data_path))
34
+ print("Aborting the evaluation")
35
+ exit(-1)
36
+
37
+ def show_anns(anns):
38
+ if len(anns) == 0:
39
+ return
40
+ sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
41
+ ax = plt.gca()
42
+ ax.set_autoscale_on(False)
43
+ for ann in sorted_anns:
44
+ x0, y0, w, h = ann['bbox']
45
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
46
+ ax.scatter([x0+w//2], [y0+h//2], color='green', marker='*', s=10, edgecolor='white', linewidth=1.25)
47
+
48
+
49
+ debug = True
50
+ os.environ['CUDA_VISIBLE_DEVICES'] = args.device.strip()
51
+ device = 'cuda'
52
+ sam = sam_model_registry[args.model_type](checkpoint=args.model_path)
53
+ sam.to(device=device)
54
+
55
+
56
+ mask_generator = SamAutomaticMaskGenerator(
57
+ model=sam,
58
+ min_mask_region_area=25
59
+ )
60
+
61
+ with open(anno_file) as f:
62
+ annotations = json.load(f)
63
+
64
+ with open(data_split_file) as f:
65
+ data_split = json.load(f)
66
+
67
+
68
+ cnt = 0
69
+ SAE = 0 # sum of absolute errors
70
+ SSE = 0 # sum of square errors
71
+
72
+ print("Evaluation on {} data".format(args.test_split))
73
+ im_ids = data_split[args.test_split]
74
+
75
+ # with open("err.json") as f:
76
+ # im_ids = json.load(f)
77
+
78
+
79
+ pbar = tqdm(im_ids)
80
+ # err_list = []
81
+ for im_id in pbar:
82
+ anno = annotations[im_id]
83
+ bboxes = anno['box_examples_coordinates']
84
+ dots = np.array(anno['points'])
85
+
86
+ image = cv2.imread('{}/{}'.format(im_dir, im_id))
87
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
88
+
89
+ input_boxes = list()
90
+ for bbox in bboxes:
91
+ x1, y1 = bbox[0][0], bbox[0][1]
92
+ x2, y2 = bbox[2][0], bbox[2][1]
93
+ input_boxes.append([x1, y1, x2, y2])
94
+
95
+ masks = mask_generator.generate(image, input_boxes)
96
+ if args.viz:
97
+ if not exists('viz'):
98
+ os.mkdir('viz')
99
+ plt.figure(figsize=(10,10))
100
+ plt.imshow(image)
101
+ show_anns(masks)
102
+ plt.axis('off')
103
+ plt.savefig('viz/{}'.format(im_id))
104
+ plt.close()
105
+
106
+ gt_cnt = dots.shape[0]
107
+ pred_cnt = len(masks)
108
+ cnt = cnt + 1
109
+ err = abs(gt_cnt - pred_cnt)
110
+ SAE += err
111
+ SSE += err**2
112
+ # if err / gt_cnt > 0.7:
113
+ # err_list.append(im_id)
114
+
115
+ pbar.set_description('{:<8}: actual-predicted: {:6d}, {:6.1f}, error: {:6.1f}. Current MAE: {:5.2f}, RMSE: {:5.2f}'.\
116
+ format(im_id, gt_cnt, pred_cnt, abs(pred_cnt - gt_cnt), SAE/cnt, (SSE/cnt)**0.5))
117
+
118
+ print('On {} data, MAE: {:6.2f}, RMSE: {:6.2f}'.format(args.test_split, SAE/cnt, (SSE/cnt)**0.5))
119
+ # with open('err.json', "w") as f:
120
+ # json.dump(err_list, f)
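Editor's note: the running metrics printed by this script reduce to the usual definitions, MAE = SAE / N and RMSE = sqrt(SSE / N). A tiny self-contained check of that arithmetic (the counts below are made-up numbers, not dataset results):

```python
# Made-up ground-truth and predicted counts for three images
gt_counts = [12, 40, 7]
pred_counts = [10, 55, 7]

errs = [abs(g - p) for g, p in zip(gt_counts, pred_counts)]
mae = sum(errs) / len(errs)                          # SAE / N
rmse = (sum(e ** 2 for e in errs) / len(errs)) ** 0.5  # sqrt(SSE / N)
print(f"MAE: {mae:.2f}, RMSE: {rmse:.2f}")  # MAE: 5.67, RMSE: 8.74
```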
test_coco.py ADDED
@@ -0,0 +1,159 @@
1
+ import cv2
2
+ import argparse
3
+ import json
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from os.path import exists
7
+ import os
8
+
9
+ from segment_anything import sam_model_registry
10
+ from automatic_mask_generator import SamAutomaticMaskGenerator
11
+ import matplotlib.pyplot as plt
12
+
13
+
14
+
15
+
16
+ parser = argparse.ArgumentParser(description="Few Shot Counting Evaluation code")
17
+ parser.add_argument("-dp", "--data_path", type=str, default='/data/counte/', help="Path to the coco dataset")
18
+ parser.add_argument("-ts", "--test_split", type=str, default='val2017', choices=["val2017"], help="what data split to evaluate on")
19
+ parser.add_argument("-mt", "--model_type", type=str, default="vit_h", help="model type")
20
+ parser.add_argument("-mp", "--model_path", type=str, default="/home/teddy/segment-anything/sam_vit_h_4b8939.pth", help="path to trained model")
21
+ parser.add_argument("-v", "--viz", type=bool, default=True, help="wether to visualize")
22
+ parser.add_argument("-d", "--device", default='0', help='assign device')
23
+ args = parser.parse_args()
24
+
25
+ data_path = args.data_path
26
+ anno_file = data_path + 'annotations_trainval2017/annotations/instances_val2017.json'
27
+ im_dir = data_path + 'val2017'
28
+
29
+
30
+ if not exists(anno_file) or not exists(im_dir):
31
+ print("Make sure you set up the --data-path correctly.")
32
+ print("Current setting is {}, but the image dir and annotation file do not exist.".format(args.data_path))
33
+ print("Aborting the evaluation")
34
+ exit(-1)
35
+
36
+ def show_anns(anns):
37
+ if len(anns) == 0:
38
+ return
39
+ sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
40
+ ax = plt.gca()
41
+ ax.set_autoscale_on(False)
42
+ for ann in sorted_anns:
43
+ x0, y0, w, h = ann['bbox']
44
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
45
+ ax.scatter([x0+w//2], [y0+h//2], color='green', marker='*', s=10, edgecolor='white', linewidth=1.25)
46
+
47
+
48
+ debug = True
49
+ os.environ['CUDA_VISIBLE_DEVICES'] = args.device.strip()
50
+ device = 'cuda'
51
+ sam = sam_model_registry[args.model_type](checkpoint=args.model_path)
52
+ sam.to(device=device)
53
+
54
+
55
+ mask_generator = SamAutomaticMaskGenerator(
56
+ model=sam,
57
+ min_mask_region_area=25
58
+ )
59
+
60
+ with open(anno_file) as f:
61
+ annotations = json.load(f)
62
+
63
+ images = sorted(annotations['images'],key=lambda x:x['file_name'])
64
+
65
+ prepared_json = {}
66
+ for i in images:
67
+ prepared_json[i['file_name']] = {
68
+ "H":i['height'],
69
+ "W":i['width'],
70
+ "boxes":{},
71
+ # "category_ids":[],
72
+ }
73
+ for i in annotations['annotations']:
74
+ im_id = str(i['image_id'])
75
+ prezero = 12 - len(im_id)
76
+ im_id = '0'*prezero + im_id + ".jpg"
77
+ if i["category_id"] in prepared_json[im_id]["boxes"]:
78
+ prepared_json[im_id]["boxes"][i["category_id"]].append(i['bbox'])
79
+ else:
80
+ prepared_json[im_id]["boxes"][i["category_id"]] = []
81
+ prepared_json[im_id]["boxes"][i["category_id"]].append(i['bbox'])
82
+
83
+ im_ids = []
84
+ for i in prepared_json.keys():
85
+ im_ids.append(i)
86
+
87
+
88
+ cnt = 0
89
+ folds = [
90
+ [1,5,9,14,18,22,27,33,37,41,46,50,54,58,62,67,74,78,82,87],
91
+ [2,6,10,15,19,23,28,34,38,42,47,51,55,59,63,70,75,79,84,88],
92
+ [3,7,11,16,20,24,31,35,39,43,48,52,56,60,64,72,76,80,85,89],
93
+ [4,8,13,17,21,25,32,36,40,44,49,53,57,61,65,73,77,81,86,90],
94
+ ]
95
+ SAE = [0,0,0,0] # sum of absolute errors
96
+ SSE = [0,0,0,0] # sum of square errors
97
+
98
+ print("Evaluation on {} data".format(args.test_split))
99
+
100
+ # logs = []
101
+
102
+
103
+ pbar = tqdm(im_ids)
104
+ # err_list = []
105
+ for im_id in pbar:
106
+ category_id = list(prepared_json[im_id]['boxes'].keys())
107
+
108
+ image = cv2.imread('{}/{}'.format(im_dir, im_id))
109
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
110
+ # log = []
111
+ # log.append(im_id)
112
+
113
+ for id in category_id:
114
+ boxes = prepared_json[im_id]['boxes'][id]
115
+
116
+ input_boxes = list()
117
+ x1, y1 = boxes[0][0],boxes[0][1]
118
+ x2, y2 = boxes[0][0] + boxes[0][2],boxes[0][1] + boxes[0][3]
119
+ input_boxes.append([x1, y1, x2, y2])
120
+
121
+ masks = mask_generator.generate(image, input_boxes)
122
+
123
+ if args.viz:
124
+ if not exists('viz'):
125
+ os.mkdir('viz')
126
+ plt.figure(figsize=(10,10))
127
+ plt.imshow(image)
128
+ show_anns(masks)
129
+ plt.axis('off')
130
+ plt.savefig('viz/{}_{}.jpg'.format(im_id[0:-4],id))
131
+ plt.close()
132
+
133
+ gt_cnt = len(boxes)
134
+ pred_cnt = len(masks)
135
+ err = abs(gt_cnt - pred_cnt)
136
+ log.append("\n{},gt_cnt:{},pred_cnt:{}".format(id,gt_cnt,pred_cnt))
137
+ if id in folds[0]:
138
+ SAE[0] += err
139
+ SSE[0] += err**2
140
+ elif id in folds[1]:
141
+ SAE[1] += err
142
+ SSE[1] += err**2
143
+ elif id in folds[2]:
144
+ SAE[2] += err
145
+ SSE[2] += err**2
146
+ elif id in folds[3]:
147
+ SAE[3] += err
148
+ SSE[3] += err**2
149
+
150
+ cnt = cnt + 1
151
+ # logs.append(log)
152
+ pbar.set_description('fold1: {:6.2f}, fold2: {:6.2f}, fold3: {:6.2f}, fold4: {:6.2f},'.\
153
+ format(SAE[0]/cnt,SAE[1]/cnt,SAE[2]/cnt,SAE[3]/cnt))
154
+
155
+ print('On {} data, fold1 MAE: {:6.2f}, RMSE: {:6.2f}\n \
156
+ fold2 MAE: {:6.2f}, RMSE: {:6.2f}\n \
157
+ fold3 MAE: {:6.2f}, RMSE: {:6.2f}\n \
158
+ fold4 MAE: {:6.2f}, RMSE: {:6.2f}\n \
159
+ '.format(args.test_split,SAE[0]/cnt,(SSE[0]/cnt)**0.5,SAE[1]/cnt,(SSE[1]/cnt)**0.5,SAE[2]/cnt,(SSE[2]/cnt)**0.5,SAE[3]/cnt,(SSE[3]/cnt)**0.5))
vis_FSC.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
vis_coco.ipynb ADDED
The diff for this file is too large to render. See raw diff