Browse files
- .gitattributes +1 -0
- LICENSE +21 -0
- README.md +54 -13
- SAM_counting_anything__ArXiv_.pdf +3 -0
- app.py +88 -0
- automatic_mask_generator.py +496 -0
- env.yaml +177 -0
- example.png +0 -0
- resultFSC.png +0 -0
- resultcoco.png +0 -0
- test_FSC.py +120 -0
- test_coco.py +159 -0
- vis_FSC.ipynb +0 -0
- vis_coco.ipynb +0 -0
.gitattributes
CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+SAM_counting_anything__ArXiv_.pdf filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Vision-Intelligence-and-Robots-Group
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
CHANGED
@@ -1,13 +1,54 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
+# count-anything
+An empirical study on few-shot counting using segment anything (SAM)
+
+Meta AI recently released the Segment Anything Model [[SAM]](https://github.com/facebookresearch/segment-anything), which has garnered attention for its impressive performance in class-agnostic segmentation. In this study, we explore the use of SAM for the challenging task of few-shot object counting: counting objects of an unseen category given only a few bounding-box examples. We compare SAM's performance with other few-shot counting methods and find that it is currently unsatisfactory without further fine-tuning, particularly for small and crowded objects.
+
+![image](example.png)
+## Install
+Install the Python dependencies. We use conda with Python 3.10.4 and PyTorch 1.13.1:
+> conda env create -f env.yaml
+
+## Dataset preparation
+- For FSC-147:
+Images can be downloaded from here: https://drive.google.com/file/d/1ymDYrGs9DSRicfZbSCDiOu0ikGDh5k6S/view?usp=sharing
+
+- For COCO val2017:
+Images can be downloaded from here: https://cocodataset.org/
+## Comparison Results
+
+### FSC
+
+![image](resultFSC.png)
+
+### COCO
+
+![image](resultcoco.png)
+## Test
+Download the [ViT-H SAM model](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth).
+
+- For FSC-147:
+```
+python test_FSC.py --data_path <FSC-147 dataset path> --model_path <path to ViT-H SAM model>
+```
+
+- For COCO val2017:
+```
+python test_coco.py --data_path <coco val2017 dataset path> --model_path <path to ViT-H SAM model>
+```
+
+## Visualize
+You can run [vis_FSC.ipynb](vis_FSC.ipynb) for FSC-147 or [vis_coco.ipynb](vis_coco.ipynb) for COCO.
+
+## Acknowledgement
+We thank facebookresearch for their Segment Anything model [[project]](https://github.com/facebookresearch/segment-anything), cvlab-stonybrook for their Learning To Count Everything [[project]](https://github.com/cvlab-stonybrook/LearningToCountEverything), and the COCO [[dataset]](https://cocodataset.org/).
+
+## Citation
+If you find the code useful, please cite:
+```
+@article{ma2023countanything,
+  title={CAN SAM COUNT ANYTHING? AN EMPIRICAL STUDY ON SAM COUNTING},
+  author={Ma, Zhiheng and Hong, Xiaopeng and Shangguan, Qinnan},
+  journal={arXiv preprint arXiv:2304.xxxxx},
+  year={2023}
+}
+```
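For programmatic use outside the test scripts, the pieces above compose as in the following minimal sketch, distilled from test_FSC.py; the image path and exemplar boxes are placeholders, and a CUDA device is assumed:

```
# Minimal sketch of calling the counter directly (placeholder path and boxes).
import cv2
from segment_anything import sam_model_registry
from automatic_mask_generator import SamAutomaticMaskGenerator

sam = sam_model_registry['vit_h'](checkpoint='./sam_vit_h_4b8939.pth')
sam.to(device='cuda')
mask_generator = SamAutomaticMaskGenerator(model=sam, min_mask_region_area=25)

# Load an image as HWC uint8 RGB
image = cv2.cvtColor(cv2.imread('some_image.jpg'), cv2.COLOR_BGR2RGB)
# A few exemplar boxes of the target category, in XYXY pixel coordinates
exemplar_boxes = [[10, 20, 50, 60], [100, 120, 140, 160]]

masks = mask_generator.generate(image, exemplar_boxes)
print('Predicted count:', len(masks))  # the count is the number of kept masks
```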
SAM_counting_anything__ArXiv_.pdf
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:49bcb63df39d9ad072ae11393355e29757169608fa4540388698cf23a2c7110a
+size 6071349
app.py
ADDED
@@ -0,0 +1,88 @@
+from PIL import Image, ImageDraw
+import cv2
+import gradio as gr
+import torch
+from segment_anything import sam_model_registry
+from automatic_mask_generator import SamAutomaticMaskGenerator
+
+device = 'cuda'
+sam = sam_model_registry['vit_h'](checkpoint='./sam_vit_h_4b8939.pth')
+sam.to(device=device)
+
+
+mask_generator = SamAutomaticMaskGenerator(
+    model=sam,
+    min_mask_region_area=25
+)
+
+def binarize(x):
+    return (x != 0).astype('uint8') * 255
+
+def draw_box(boxes=[], img=None):
+    # Draw the user-provided exemplar boxes, cycling through colors
+    if len(boxes) == 0 and img is None:
+        return None
+
+    if img is None:
+        img = Image.new('RGB', (512, 512), (255, 255, 255))
+    colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
+    draw = ImageDraw.Draw(img)
+    # print(boxes)
+    for bid, box in enumerate(boxes):
+        draw.rectangle([box[0], box[1], box[2], box[3]], outline=colors[bid % len(colors)], width=4)
+    return img
+
+
+def draw_pred_box(boxes=[], img=None):
+    # Draw all predicted boxes in a single color
+    if len(boxes) == 0 and img is None:
+        return None
+
+    if img is None:
+        img = Image.new('RGB', (512, 512), (255, 255, 255))
+    color = "green"
+    draw = ImageDraw.Draw(img)
+    # print(boxes)
+    for bid, box in enumerate(boxes):
+        draw.rectangle([box[0], box[1], box[2], box[3]], outline=color, width=4)
+    return img
+
+
+def debug(input_img):
+    # Turn each connected stroke in the user's sketch mask into one exemplar box
+    mask = input_img["mask"]
+    mask = mask[..., 0]
+    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+    boxes = []
+    for contour in contours:
+        y1, y2 = contour[:, 0, 1].min(), contour[:, 0, 1].max()
+        x1, x2 = contour[:, 0, 0].min(), contour[:, 0, 0].max()
+        boxes.append([x1, y1, x2, y2])
+    draw_image = draw_box(boxes, Image.fromarray(input_img["image"]))
+
+    masks = mask_generator.generate(input_img["image"], boxes)
+    pred_cnt = len(masks)
+    pred_bboxes = []
+    for i in masks:
+        x0, y0, w, h = i['bbox']
+        pred_bboxes.append([x0, y0, x0+w, y0+h])
+    pred_image = draw_pred_box(pred_bboxes, Image.fromarray(input_img["image"]))
+    return [draw_image, pred_image, "Count: {}".format(pred_cnt)]
+
+description = """<p style="text-align: center; font-weight: bold;">
+    <span style="font-size: 28px">Count Anything</span>
+    <br>
+    <span style="font-size: 18px" id="paper-info">
+        [<a href=" " target="_blank">Project Page</a>]
+        [<a href=" " target="_blank">Paper</a>]
+        [<a href="https://github.com/Vision-Intelligence-and-Robots-Group/count-anything" target="_blank">GitHub</a>]
+    </span>
+</p>
+"""
+
+run = gr.Interface(
+    debug,
+    gr.Image(shape=[512, 512], source="upload", tool="sketch").style(height=500, width=500),
+    [gr.Image(), gr.Image(), gr.Text()],
+    description=description
+)
+
+run.launch()
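The only glue between the Gradio UI and the counter in app.py is the mask-to-box step at the top of `debug()`: each connected stroke in the user's sketch becomes one exemplar box. A standalone sketch of that conversion, with two synthetic rectangles standing in for real strokes:

```
# Standalone sketch of the contour-to-box conversion used in debug().
import cv2
import numpy as np

mask = np.zeros((512, 512), dtype=np.uint8)
mask[50:100, 60:120] = 255    # synthetic stroke 1
mask[200:260, 300:380] = 255  # synthetic stroke 2

contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
boxes = []
for contour in contours:
    y1, y2 = contour[:, 0, 1].min(), contour[:, 0, 1].max()
    x1, x2 = contour[:, 0, 0].min(), contour[:, 0, 0].max()
    boxes.append([x1, y1, x2, y2])  # one axis-aligned XYXY box per stroke
print(boxes)
```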
automatic_mask_generator.py
ADDED
@@ -0,0 +1,496 @@
+import numpy as np
+import torch
+from torchvision.ops.boxes import batched_nms, box_area  # type: ignore
+
+from typing import Any, Dict, List, Optional, Tuple
+import torch.nn.functional as F
+from collections import defaultdict
+
+from segment_anything.modeling import Sam
+from segment_anything.predictor import SamPredictor
+from segment_anything.utils.amg import (
+    MaskData,
+    area_from_rle,
+    batch_iterator,
+    batched_mask_to_box,
+    box_xyxy_to_xywh,
+    build_all_layer_point_grids,
+    calculate_stability_score,
+    coco_encode_rle,
+    generate_crop_boxes,
+    is_box_near_crop_edge,
+    mask_to_rle_pytorch,
+    remove_small_regions,
+    rle_to_mask,
+    uncrop_boxes_xyxy,
+    uncrop_masks,
+    uncrop_points,
+)
+
+
+# Shadows the imported uncrop_boxes_xyxy to also handle boxes with a channel dimension
+def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+    x0, y0, _, _ = crop_box
+    offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
+    # Check if boxes has a channel dimension
+    if len(boxes.shape) == 3:
+        offset = offset.unsqueeze(1)
+    return boxes + offset
+
+def pre_process_ref_box(ref_box, crop_box, layer_idx):
+    # Clip reference boxes to the crop; keep a box only if at least 70% of it survives
+    if layer_idx == 0:
+        return ref_box
+    else:
+        new_bbox = []
+        x0, y0, x1, y1 = crop_box
+        for ref in ref_box:
+            x0_r, y0_r, x1_r, y1_r = ref
+            area = (y1_r - y0_r) * (x1_r - x0_r)
+            x_0_new = max(x0, x0_r)
+            y_0_new = max(y0, y0_r)
+            x_1_new = min(x1, x1_r)
+            y_1_new = min(y1, y1_r)
+            crop_area = (y_1_new - y_0_new) * (x_1_new - x_0_new)
+            if crop_area / area > 0.7:
+                new_bbox.append([x_0_new, y_0_new, x_1_new, y_1_new])
+
+        return new_bbox
+
+
+class SamAutomaticMaskGenerator:
+    def __init__(
+        self,
+        model: Sam,
+        points_per_side: Optional[int] = 32,
+        points_per_batch: int = 64,
+        pred_iou_thresh: float = 0.88,
+        stability_score_thresh: float = 0.95,
+        stability_score_offset: float = 1.0,
+        box_nms_thresh: float = 0.7,
+        crop_n_layers: int = 0,
+        crop_nms_thresh: float = 0.7,
+        crop_overlap_ratio: float = 512 / 1500,
+        crop_n_points_downscale_factor: int = 1,
+        point_grids: Optional[List[np.ndarray]] = None,
+        min_mask_region_area: int = 0,
+        output_mode: str = "binary_mask",
+    ) -> None:
+        """
+        Using a SAM model, generates masks for the entire image.
+        Generates a grid of point prompts over the image, then filters
+        low quality and duplicate masks. The default settings are chosen
+        for SAM with a ViT-H backbone.
+
+        Arguments:
+          model (Sam): The SAM model to use for mask prediction.
+          points_per_side (int or None): The number of points to be sampled
+            along one side of the image. The total number of points is
+            points_per_side**2. If None, 'point_grids' must provide explicit
+            point sampling.
+          points_per_batch (int): Sets the number of points run simultaneously
+            by the model. Higher numbers may be faster but use more GPU memory.
+          pred_iou_thresh (float): A filtering threshold in [0,1], using the
+            model's predicted mask quality.
+          stability_score_thresh (float): A filtering threshold in [0,1], using
+            the stability of the mask under changes to the cutoff used to binarize
+            the model's mask predictions.
+          stability_score_offset (float): The amount to shift the cutoff when
+            calculating the stability score.
+          box_nms_thresh (float): The box IoU cutoff used by non-maximal
+            suppression to filter duplicate masks.
+          crop_n_layers (int): If >0, mask prediction will be run again on
+            crops of the image. Sets the number of layers to run, where each
+            layer has 2**i_layer number of image crops.
+          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
+            suppression to filter duplicate masks between different crops.
+          crop_overlap_ratio (float): Sets the degree to which crops overlap.
+            In the first crop layer, crops will overlap by this fraction of
+            the image length. Later layers with more crops scale down this overlap.
+          crop_n_points_downscale_factor (int): The number of points-per-side
+            sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
+          point_grids (list(np.ndarray) or None): A list over explicit grids
+            of points used for sampling, normalized to [0,1]. The nth grid in the
+            list is used in the nth crop layer. Exclusive with points_per_side.
+          min_mask_region_area (int): If >0, postprocessing will be applied
+            to remove disconnected regions and holes in masks with area smaller
+            than min_mask_region_area. Requires opencv.
+          output_mode (str): The form masks are returned in. Can be 'binary_mask',
+            'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
+            For large resolutions, 'binary_mask' may consume large amounts of
+            memory.
+        """
+
+        assert (points_per_side is None) != (
+            point_grids is None
+        ), "Exactly one of points_per_side or point_grid must be provided."
+        if points_per_side is not None:
+            self.point_grids = build_all_layer_point_grids(
+                points_per_side,
+                crop_n_layers,
+                crop_n_points_downscale_factor,
+            )
+        elif point_grids is not None:
+            self.point_grids = point_grids
+        else:
+            raise ValueError("Can't have both points_per_side and point_grid be None.")
+
+        assert output_mode in [
+            "binary_mask",
+            "uncompressed_rle",
+            "coco_rle",
+        ], f"Unknown output_mode {output_mode}."
+        if output_mode == "coco_rle":
+            from pycocotools import mask as mask_utils  # type: ignore # noqa: F401
+
+        if min_mask_region_area > 0:
+            import cv2  # type: ignore # noqa: F401
+
+        self.predictor = SamPredictor(model)
+        self.points_per_batch = points_per_batch
+        self.pred_iou_thresh = pred_iou_thresh
+        self.stability_score_thresh = stability_score_thresh
+        self.stability_score_offset = stability_score_offset
+        self.box_nms_thresh = box_nms_thresh
+        self.crop_n_layers = crop_n_layers
+        self.crop_nms_thresh = crop_nms_thresh
+        self.crop_overlap_ratio = crop_overlap_ratio
+        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
+        self.min_mask_region_area = min_mask_region_area
+        self.output_mode = output_mode
+
+        # Per-layer exemplar feature prototypes, rebuilt on every generate() call
+        self.prototype = defaultdict(list)
+
+    @torch.no_grad()
+    def generate(self, image: np.ndarray, ref_bbox) -> List[Dict[str, Any]]:
+        """
+        Generates masks for the given image.
+
+        Arguments:
+          image (np.ndarray): The image to generate masks for, in HWC uint8 format.
+          ref_bbox (list(list(int))): A few exemplar boxes in XYXY format. They
+            are used both as box prompts and to build the class-prototype
+            features that filter the grid-prompted masks.
+
+        Returns:
+           list(dict(str, any)): A list over records for masks. Each record is
+             a dict containing the following keys:
+               segmentation (dict(str, any) or np.ndarray): The mask. If
+                 output_mode='binary_mask', is an array of shape HW. Otherwise,
+                 is a dictionary containing the RLE.
+               bbox (list(float)): The box around the mask, in XYWH format.
+               area (int): The area in pixels of the mask.
+               predicted_iou (float): The model's own prediction of the mask's
+                 quality. This is filtered by the pred_iou_thresh parameter.
+               point_coords (list(list(float))): The point coordinates input
+                 to the model to generate this mask.
+               stability_score (float): A measure of the mask's quality. This
+                 is filtered on using the stability_score_thresh parameter.
+               crop_box (list(float)): The crop of the image used to generate
+                 the mask, given in XYWH format.
+        """
+
+        # Generate masks
+        mask_data = self._generate_masks(image, ref_bbox)
+
+        # Filter small disconnected regions and holes in masks
+        if self.min_mask_region_area > 0:
+            mask_data = self.postprocess_small_regions(
+                mask_data,
+                self.min_mask_region_area,
+                max(self.box_nms_thresh, self.crop_nms_thresh),
+            )
+
+        # Encode masks
+        if self.output_mode == "coco_rle":
+            mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
+        elif self.output_mode == "binary_mask":
+            mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
+        else:
+            mask_data["segmentations"] = mask_data["rles"]
+
+        # Write mask records
+        curr_anns = []
+        for idx in range(len(mask_data["segmentations"])):
+            ann = {
+                "segmentation": mask_data["segmentations"][idx],
+                "area": area_from_rle(mask_data["rles"][idx]),
+                "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
+                "predicted_iou": mask_data["iou_preds"][idx].item(),
+                "point_coords": [mask_data["points"][idx].tolist()],
+                "stability_score": mask_data["stability_score"][idx].item(),
+                "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
+            }
+            curr_anns.append(ann)
+
+        return curr_anns
+
+    def _generate_masks(self, image: np.ndarray, ref_box) -> MaskData:
+        orig_size = image.shape[:2]
+        crop_boxes, layer_idxs = generate_crop_boxes(
+            orig_size, self.crop_n_layers, self.crop_overlap_ratio
+        )
+
+        # Iterate over image crops
+        # data = MaskData()
+        data_dic = defaultdict(MaskData)
+        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
+            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size, ref_box)
+            data_dic[layer_idx].cat(crop_data)
+
+        # Keep only masks whose pooled features are at least as similar to the
+        # exemplar prototypes as the prototypes are to each other
+        data = MaskData()
+        for layer_idx in data_dic.keys():
+            proto_fea = torch.concat(self.prototype[layer_idx], dim=0)
+            if len(proto_fea) > 1:
+                cos_dis = proto_fea @ proto_fea.t()
+                sim_thresh = torch.min(cos_dis)
+            else:
+                sim_thresh = 0.7
+            sub_data = data_dic[layer_idx]
+            fea = sub_data['fea']
+            cos_dis = torch.max(fea @ proto_fea.t(), dim=1)[0]
+            sub_data.filter(cos_dis >= sim_thresh)
+            data.cat(sub_data)
+
+        self.prototype = defaultdict(list)
+
+        # Remove duplicate masks between crops
+        if len(crop_boxes) > 1:
+            # Prefer masks from smaller crops
+            scores = 1 / box_area(data["crop_boxes"])
+            scores = scores.to(data["boxes"].device)
+            keep_by_nms = batched_nms(
+                data["boxes"].float(),
+                scores,
+                torch.zeros(len(data["boxes"])),  # categories
+                iou_threshold=self.crop_nms_thresh,
+            )
+            data.filter(keep_by_nms)
+
+        data.to_numpy()
+        return data
+
+    def _process_crop(
+        self,
+        image: np.ndarray,
+        crop_box: List[int],
+        crop_layer_idx: int,
+        orig_size: Tuple[int, ...],
+        ref_box
+    ) -> MaskData:
+        # Crop the image and calculate embeddings
+        x0, y0, x1, y1 = crop_box
+        cropped_im = image[y0:y1, x0:x1, :]
+        cropped_im_size = cropped_im.shape[:2]
+        self.predictor.set_image(cropped_im)
+
+        ref_box = pre_process_ref_box(ref_box, crop_box, crop_layer_idx)
+        if len(ref_box) > 0:
+            # Prompt SAM with the exemplar boxes and pool the image embedding
+            # under each predicted mask to get one prototype feature per exemplar
+            ref_box = torch.tensor(ref_box, device=self.predictor.device)
+            transformed_boxes = self.predictor.transform.apply_boxes_torch(ref_box, cropped_im_size)
+            masks, iou_preds, low_res_masks = self.predictor.predict_torch(
+                point_coords=None,
+                point_labels=None,
+                boxes=transformed_boxes,
+                multimask_output=False
+            )
+            feature = self.predictor.get_image_embedding()
+
+            low_res_masks = F.interpolate(low_res_masks, size=feature.shape[-2:], mode='bilinear', align_corners=False)
+
+            feature = feature.flatten(2, 3)
+            low_res_masks = low_res_masks.flatten(2, 3)
+            masks_low_res = (low_res_masks > self.predictor.model.mask_threshold).float()
+            topk_idx = torch.topk(low_res_masks, 1)[1]
+            masks_low_res.scatter_(2, topk_idx, 1.0)
+
+            prototype_fea = (feature * masks_low_res).sum(dim=2) / masks_low_res.sum(dim=2)
+            prototype_fea = F.normalize(prototype_fea, dim=1)
+            self.prototype[crop_layer_idx].append(prototype_fea)
+
+        if crop_layer_idx == 0:  # add reference grounding
+            x = ref_box[:, 0] + (ref_box[:, 2] - ref_box[:, 0]) / 2
+            y = ref_box[:, 1] + (ref_box[:, 3] - ref_box[:, 1]) / 2
+            points = torch.stack([x, y], dim=1)
+            data = MaskData(
+                masks=masks.flatten(0, 1),
+                iou_preds=torch.ones_like(iou_preds.flatten(0, 1)),
+                fea=prototype_fea,
+                points=points.cpu(),
+                stability_score=torch.ones_like(iou_preds.flatten(0, 1)),
+            )
+            data["boxes"] = batched_mask_to_box(data["masks"])
+            data["rles"] = mask_to_rle_pytorch(data["masks"])
+            del data["masks"]
+        else:
+            data = MaskData()
+
+        # Get points for this crop
+        points_scale = np.array(cropped_im_size)[None, ::-1]
+        points_for_image = self.point_grids[crop_layer_idx] * points_scale
+
+        # Generate masks for this crop in batches
+        for (points,) in batch_iterator(self.points_per_batch, points_for_image):
+            batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
+            data.cat(batch_data)
+            del batch_data
+        self.predictor.reset_image()
+
+        # Remove duplicates within this crop.
+        keep_by_nms = batched_nms(
+            data["boxes"].float(),
+            data["iou_preds"],
+            torch.zeros(len(data["boxes"])),  # categories
+            iou_threshold=self.box_nms_thresh,
+        )
+        data.filter(keep_by_nms)
+
+        # Return to the original image frame
+        data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
+        data["points"] = uncrop_points(data["points"], crop_box)
+        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
+
+        return data
+
+    def _process_batch(
+        self,
+        points: np.ndarray,
+        im_size: Tuple[int, ...],
+        crop_box: List[int],
+        orig_size: Tuple[int, ...],
+    ) -> MaskData:
+        orig_h, orig_w = orig_size
+
+        # Run model on this batch
+        transformed_points = self.predictor.transform.apply_coords(points, im_size)
+        in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
+        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
+        masks, iou_preds, low_res_masks = self.predictor.predict_torch(
+            in_points[:, None, :],
+            in_labels[:, None],
+            multimask_output=True,
+            return_logits=True,
+        )
+
+        feature = self.predictor.get_image_embedding()
+        low_res_masks = low_res_masks.flatten(0, 1)
+        low_res_masks = F.interpolate(low_res_masks[:, None, :, :], size=feature.shape[-2:],
+                                      mode='bilinear', align_corners=False)
+        # low_res_masks = low_res_masks > self.predictor.model.mask_threshold
+
+        # fea = feature.flatten(2, 3)
+        # low_res_masks = low_res_masks.flatten(2, 3)
+        # topk_idx = torch.topk(low_res_masks, 4)[1]
+        # fea = fea.expand(topk_idx.shape[0], -1, -1)
+        # topk_idx = topk_idx.expand(-1, fea.shape[1], -1)
+        # fea = fea.gather(2, topk_idx)
+
+        # Pool the image embedding under each candidate mask, mirroring the
+        # exemplar prototype computation in _process_crop
+        feature = feature.flatten(2, 3)
+        low_res_masks = low_res_masks.flatten(2, 3)
+        masks_low_res = (low_res_masks > self.predictor.model.mask_threshold).float()
+        topk_idx = torch.topk(low_res_masks, 1)[1]
+        masks_low_res.scatter_(2, topk_idx, 1.0)
+        pool_fea = (feature * masks_low_res).sum(dim=2) / masks_low_res.sum(dim=2)
+        pool_fea = F.normalize(pool_fea, dim=1)
+
+        # k_val = torch.topk(torch.flatten(low_res_masks, start_dim=2, end_dim=3), k=4, dim=-1)[0][:, :, -1, None]
+        # low_res_masks = (low_res_masks >= k_val).float()
+        # low_res_masks = low_res_masks.float()
+        # pool_fea = (feature * low_res_masks).sum(dim=(2, 3)) / low_res_masks.sum(dim=(2, 3))
+
+        # Serialize predictions and store in MaskData
+        data = MaskData(
+            masks=masks.flatten(0, 1),
+            iou_preds=iou_preds.flatten(0, 1),
+            points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
+            fea=pool_fea,
+        )
+        del masks
+
+        # Filter by predicted IoU
+        if self.pred_iou_thresh > 0.0:
+            keep_mask = data["iou_preds"] > self.pred_iou_thresh
+            data.filter(keep_mask)
+
+        # Calculate stability score
+        data["stability_score"] = calculate_stability_score(
+            data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
+        )
+        if self.stability_score_thresh > 0.0:
+            keep_mask = data["stability_score"] >= self.stability_score_thresh
+            data.filter(keep_mask)
+
+        # Threshold masks and calculate boxes
+        data["masks"] = data["masks"] > self.predictor.model.mask_threshold
+        data["boxes"] = batched_mask_to_box(data["masks"])
+
+        # Filter boxes that touch crop boundaries
+        keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
+        if not torch.all(keep_mask):
+            data.filter(keep_mask)
+
+        # Compress to RLE
+        data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
+        data["rles"] = mask_to_rle_pytorch(data["masks"])
+        del data["masks"]
+
+        return data
+
+    @staticmethod
+    def postprocess_small_regions(
+        mask_data: MaskData, min_area: int, nms_thresh: float
+    ) -> MaskData:
+        """
+        Removes small disconnected regions and holes in masks, then reruns
+        box NMS to remove any new duplicates.
+
+        Edits mask_data in place.
+
+        Requires open-cv as a dependency.
+        """
+        if len(mask_data["rles"]) == 0:
+            return mask_data
+
+        # Filter small disconnected regions and holes
+        new_masks = []
+        scores = []
+        for rle in mask_data["rles"]:
+            mask = rle_to_mask(rle)
+
+            mask, changed = remove_small_regions(mask, min_area, mode="holes")
+            unchanged = not changed
+            mask, changed = remove_small_regions(mask, min_area, mode="islands")
+            unchanged = unchanged and not changed
+
+            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
+            # Give score=0 to changed masks and score=1 to unchanged masks
+            # so NMS will prefer ones that didn't need postprocessing
+            scores.append(float(unchanged))
+
+        # Recalculate boxes and remove any new duplicates
+        masks = torch.cat(new_masks, dim=0)
+        boxes = batched_mask_to_box(masks)
+        keep_by_nms = batched_nms(
+            boxes.float(),
+            torch.as_tensor(scores),
+            torch.zeros(len(boxes)),  # categories
+            iou_threshold=nms_thresh,
+        )
+
+        # Only recalculate RLEs for masks that have changed
+        for i_mask in keep_by_nms:
+            if scores[i_mask] == 0.0:
+                mask_torch = masks[i_mask].unsqueeze(0)
+                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
+                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
+        mask_data.filter(keep_by_nms)
+
+        return mask_data
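The record format documented in `generate()` above is what app.py and the test scripts consume. A minimal sketch of reading those records, assuming `masks` holds the return value of a `generate()` call:

```
# Sketch of consuming SamAutomaticMaskGenerator.generate() output records.
def summarize(masks):
    print('Predicted count:', len(masks))  # the count is the number of surviving masks
    for record in masks:
        x0, y0, w, h = record['bbox']  # box around the mask, XYWH
        print('bbox=({}, {}, {}, {}), area={}, predicted_iou={:.3f}, stability={:.3f}'.format(
            x0, y0, w, h, record['area'],
            record['predicted_iou'], record['stability_score']))
```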
env.yaml
ADDED
@@ -0,0 +1,177 @@
+name: ltce
+channels:
+  - conda-forge
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1=main
+  - _openmp_mutex=5.1=1_gnu
+  - asttokens=2.2.1=pyhd8ed1ab_0
+  - backcall=0.2.0=pyh9f0ad1d_0
+  - backports=1.0=pyhd8ed1ab_3
+  - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
+  - blas=1.0=openblas
+  - brotli=1.0.9=h5eee18b_7
+  - brotli-bin=1.0.9=h5eee18b_7
+  - bzip2=1.0.8=h7b6447c_0
+  - ca-certificates=2023.01.10=h06a4308_0
+  - cairo=1.16.0=hb05425b_4
+  - certifi=2022.12.7=py310h06a4308_0
+  - contourpy=1.0.5=py310hdb19cb5_0
+  - cycler=0.11.0=pyhd3eb1b0_0
+  - dbus=1.13.18=hb2f20db_0
+  - debugpy=1.5.1=py310h295c915_0
+  - decorator=5.1.1=pyhd8ed1ab_0
+  - eigen=3.4.0=h4bd325d_0
+  - entrypoints=0.4=pyhd8ed1ab_0
+  - executing=1.2.0=pyhd8ed1ab_0
+  - expat=2.2.10=h9c3ff4c_0
+  - ffmpeg=4.2.2=h20bf706_0
+  - fontconfig=2.14.1=hef1e5e3_0
+  - fonttools=4.25.0=pyhd3eb1b0_0
+  - freetype=2.10.4=h0708190_1
+  - giflib=5.2.1=h36c2ea0_2
+  - glib=2.69.1=h4ff587b_1
+  - gmp=6.2.1=h58526e2_0
+  - gnutls=3.6.13=h85f3911_1
+  - graphite2=1.3.14=h295c915_1
+  - gst-plugins-base=1.14.1=h6a678d5_1
+  - gstreamer=1.14.1=h5eee18b_1
+  - harfbuzz=4.3.0=hf52aaf7_1
+  - hdf5=1.10.6=h3ffc7dd_1
+  - icu=58.2=hf484d3e_1000
+  - ipykernel=6.15.0=pyh210e3f2_0
+  - ipython=8.12.0=pyh41d4057_0
+  - jedi=0.18.2=pyhd8ed1ab_0
+  - jpeg=9e=h166bdaf_1
+  - jupyter_client=7.3.4=pyhd8ed1ab_0
+  - jupyter_core=5.3.0=py310hff52083_0
+  - keyutils=1.6.1=h166bdaf_0
+  - kiwisolver=1.4.4=py310h6a678d5_0
+  - krb5=1.19.3=h3790be6_0
+  - lame=3.100=h7f98852_1001
+  - lcms2=2.12=h3be6417_0
+  - ld_impl_linux-64=2.38=h1181459_1
+  - lerc=3.0=h295c915_0
+  - libblas=3.9.0=15_linux64_openblas
+  - libbrotlicommon=1.0.9=h5eee18b_7
+  - libbrotlidec=1.0.9=h5eee18b_7
+  - libbrotlienc=1.0.9=h5eee18b_7
+  - libcblas=3.9.0=15_linux64_openblas
+  - libclang=10.0.1=default_hb85057a_2
+  - libdeflate=1.17=h5eee18b_0
+  - libedit=3.1.20191231=he28a2e2_2
+  - libevent=2.1.12=h8f2d780_0
+  - libffi=3.3=he6710b0_2
+  - libgcc-ng=11.2.0=h1234567_1
+  - libgfortran-ng=12.2.0=h69a702a_19
+  - libgfortran5=12.2.0=h337968e_19
+  - libgomp=11.2.0=h1234567_1
+  - liblapack=3.9.0=15_linux64_openblas
+  - libllvm10=10.0.1=he513fc3_3
+  - libopenblas=0.3.20=pthreads_h78a6416_0
+  - libopus=1.3.1=h7f98852_1
+  - libpng=1.6.39=h5eee18b_0
+  - libpq=12.9=h16c4e8d_3
+  - libprotobuf=3.20.3=he621ea3_0
+  - libsodium=1.0.18=h36c2ea0_1
+  - libstdcxx-ng=11.2.0=h1234567_1
+  - libtiff=4.5.0=h6a678d5_2
+  - libuuid=1.41.5=h5eee18b_0
+  - libvpx=1.7.0=h439df22_0
+  - libwebp=1.2.4=h11a3e52_1
+  - libwebp-base=1.2.4=h5eee18b_1
+  - libxcb=1.15=h7f8727e_0
+  - libxkbcommon=1.0.1=hfa300c1_0
+  - libxml2=2.9.14=h74e7548_0
+  - libxslt=1.1.35=h4e12654_0
+  - lz4-c=1.9.3=h9c3ff4c_1
+  - matplotlib=3.7.1=py310h06a4308_1
+  - matplotlib-base=3.7.1=py310h1128e8f_1
+  - matplotlib-inline=0.1.6=pyhd8ed1ab_0
+  - munkres=1.1.4=py_0
+  - ncurses=6.4=h6a678d5_0
+  - nest-asyncio=1.5.6=pyhd8ed1ab_0
+  - nettle=3.6=he412f7d_0
+  - nspr=4.33=h295c915_0
+  - nss=3.74=h0370c37_0
+  - opencv=4.6.0=py310h1128e8f_3
+  - openh264=2.1.1=h4ff587b_0
+  - openjpeg=2.4.0=h3ad879b_0
+  - openssl=1.1.1t=h7f8727e_0
+  - packaging=23.1=pyhd8ed1ab_0
+  - parso=0.8.3=pyhd8ed1ab_0
+  - pcre=8.45=h9c3ff4c_0
+  - pexpect=4.8.0=pyh1a96a4e_2
+  - pickleshare=0.7.5=py_1003
+  - pip=23.0.1=py310h06a4308_0
+  - pixman=0.40.0=h36c2ea0_0
+  - platformdirs=3.2.0=pyhd8ed1ab_0
+  - ply=3.11=py310h06a4308_0
+  - prompt-toolkit=3.0.38=pyha770c72_0
+  - prompt_toolkit=3.0.38=hd8ed1ab_0
+  - psutil=5.9.0=py310h5eee18b_0
+  - ptyprocess=0.7.0=pyhd3deb0d_0
+  - pure_eval=0.2.2=pyhd8ed1ab_0
+  - pygments=2.15.0=pyhd8ed1ab_0
+  - pyparsing=3.0.9=py310h06a4308_0
+  - pyqt=5.15.7=py310h6a678d5_1
+  - python=3.10.4=h12debd9_0
+  - python-dateutil=2.8.2=pyhd8ed1ab_0
+  - python_abi=3.10=2_cp310
+  - pyzmq=23.2.0=py310h6a678d5_0
+  - qt-main=5.15.2=h327a75a_7
+  - qt-webengine=5.15.9=hd2b0992_4
+  - qtwebkit=5.212=h4eab89a_4
+  - readline=8.2=h5eee18b_0
+  - setuptools=65.6.3=py310h06a4308_0
+  - sip=6.6.2=py310h6a678d5_0
+  - six=1.16.0=pyh6c4a22f_0
+  - sqlite=3.41.2=h5eee18b_0
+  - stack_data=0.6.2=pyhd8ed1ab_0
+  - tk=8.6.12=h1ccaba5_0
+  - toml=0.10.2=pyhd3eb1b0_0
+  - tornado=6.1=py310h5764c6d_3
+  - tqdm=4.65.0=py310h2f386ee_0
+  - traitlets=5.9.0=pyhd8ed1ab_0
+  - typing-extensions=4.5.0=hd8ed1ab_0
+  - typing_extensions=4.5.0=pyha770c72_0
+  - tzdata=2023c=h04d1e81_0
+  - wcwidth=0.2.6=pyhd8ed1ab_0
+  - wheel=0.38.4=py310h06a4308_0
+  - x264=1!157.20191217=h7b6447c_0
+  - xz=5.2.10=h5eee18b_1
+  - zeromq=4.3.4=h9c3ff4c_1
+  - zlib=1.2.13=h5eee18b_0
+  - zstd=1.5.2=ha4553b6_0
+  - pip:
+    - charset-normalizer==3.1.0
+    - cmake==3.26.3
+    - filelock==3.11.0
+    - idna==3.4
+    - jinja2==3.1.2
+    - lit==16.0.1
+    - markupsafe==2.1.2
+    - mpmath==1.3.0
+    - networkx==3.1
+    - numpy==1.24.2
+    - nvidia-cublas-cu11==11.10.3.66
+    - nvidia-cuda-cupti-cu11==11.7.101
+    - nvidia-cuda-nvrtc-cu11==11.7.99
+    - nvidia-cuda-runtime-cu11==11.7.99
+    - nvidia-cudnn-cu11==8.5.0.96
+    - nvidia-cufft-cu11==10.9.0.58
+    - nvidia-curand-cu11==10.2.10.91
+    - nvidia-cusolver-cu11==11.4.0.1
+    - nvidia-cusparse-cu11==11.7.4.91
+    - nvidia-nccl-cu11==2.14.3
+    - nvidia-nvtx-cu11==11.7.91
+    - pillow==9.5.0
+    - pyqt5-sip==12.11.0
+    - requests==2.28.2
+    - segment-anything==1.0
+    - sympy==1.11.1
+    - torch==2.0.0
+    - torchaudio==2.0.1
+    - torchvision==0.15.1
+    - triton==2.0.0
+    - urllib3==1.26.15
example.png
ADDED
resultFSC.png
ADDED
resultcoco.png
ADDED
test_FSC.py
ADDED
@@ -0,0 +1,120 @@
+import cv2
+import argparse
+import json
+import numpy as np
+from tqdm import tqdm
+from os.path import exists
+import os
+
+from segment_anything import sam_model_registry
+from automatic_mask_generator import SamAutomaticMaskGenerator
+import matplotlib.pyplot as plt
+
+
+parser = argparse.ArgumentParser(description="Few Shot Counting Evaluation code")
+parser.add_argument("-dp", "--data_path", type=str, default='/data/counte/', help="Path to the FSC147 dataset")
+parser.add_argument("-ts", "--test_split", type=str, default='val', choices=["val_PartA","val_PartB","test_PartA","test_PartB","test", "val"], help="what data split to evaluate on")
+parser.add_argument("-mt", "--model_type", type=str, default="vit_h", help="model type")
+parser.add_argument("-mp", "--model_path", type=str, default="/home/teddy/segment-anything/sam_vit_h_4b8939.pth", help="path to trained model")
+parser.add_argument("-v", "--viz", type=bool, default=True, help="whether to visualize")
+parser.add_argument("-d", "--device", default='0', help='assign device')
+args = parser.parse_args()
+
+data_path = args.data_path
+anno_file = data_path + 'annotation_FSC147_384.json'
+data_split_file = data_path + 'Train_Test_Val_FSC_147.json'
+im_dir = data_path + 'images_384_VarV2'
+
+
+if not exists(anno_file) or not exists(im_dir):
+    print("Make sure you set up the --data-path correctly.")
+    print("Current setting is {}, but the image dir and annotation file do not exist.".format(args.data_path))
+    print("Aborting the evaluation")
+    exit(-1)
+
+def show_anns(anns):
+    # Overlay each predicted mask's box and center star on the current axes
+    if len(anns) == 0:
+        return
+    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
+    ax = plt.gca()
+    ax.set_autoscale_on(False)
+    for ann in sorted_anns:
+        x0, y0, w, h = ann['bbox']
+        ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
+        ax.scatter([x0+w//2], [y0+h//2], color='green', marker='*', s=10, edgecolor='white', linewidth=1.25)
+
+
+debug = True
+os.environ['CUDA_VISIBLE_DEVICES'] = args.device.strip()
+device = 'cuda'
+sam = sam_model_registry[args.model_type](checkpoint=args.model_path)
+sam.to(device=device)
+
+
+mask_generator = SamAutomaticMaskGenerator(
+    model=sam,
+    min_mask_region_area=25
+)
+
+with open(anno_file) as f:
+    annotations = json.load(f)
+
+with open(data_split_file) as f:
+    data_split = json.load(f)
+
+
+cnt = 0
+SAE = 0  # sum of absolute errors
+SSE = 0  # sum of square errors
+
+print("Evaluation on {} data".format(args.test_split))
+im_ids = data_split[args.test_split]
+
+# with open("err.json") as f:
+#     im_ids = json.load(f)
+
+
+pbar = tqdm(im_ids)
+# err_list = []
+for im_id in pbar:
+    anno = annotations[im_id]
+    bboxes = anno['box_examples_coordinates']
+    dots = np.array(anno['points'])
+
+    image = cv2.imread('{}/{}'.format(im_dir, im_id))
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+    # Each exemplar is four corner points; keep two opposite corners as an XYXY box
+    input_boxes = list()
+    for bbox in bboxes:
+        x1, y1 = bbox[0][0], bbox[0][1]
+        x2, y2 = bbox[2][0], bbox[2][1]
+        input_boxes.append([x1, y1, x2, y2])
+
+    masks = mask_generator.generate(image, input_boxes)
+    if args.viz:
+        if not exists('viz'):
+            os.mkdir('viz')
+        plt.figure(figsize=(10,10))
+        plt.imshow(image)
+        show_anns(masks)
+        plt.axis('off')
+        plt.savefig('viz/{}'.format(im_id))
+        plt.close()
+
+    gt_cnt = dots.shape[0]
+    pred_cnt = len(masks)
+    cnt = cnt + 1
+    err = abs(gt_cnt - pred_cnt)
+    SAE += err
+    SSE += err**2
+    # if err / gt_cnt > 0.7:
+    #     err_list.append(im_id)
+
+    pbar.set_description('{:<8}: actual-predicted: {:6d}, {:6.1f}, error: {:6.1f}. Current MAE: {:5.2f}, RMSE: {:5.2f}'.\
+        format(im_id, gt_cnt, pred_cnt, abs(pred_cnt - gt_cnt), SAE/cnt, (SSE/cnt)**0.5))
+
+print('On {} data, MAE: {:6.2f}, RMSE: {:6.2f}'.format(args.test_split, SAE/cnt, (SSE/cnt)**0.5))
+# with open('err.json', "w") as f:
+#     json.dump(err_list, f)
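test_FSC.py reads exemplar boxes from FSC-147's `box_examples_coordinates`, where each exemplar is stored as four (x, y) corner points; the script keeps the first and third points as opposite corners of an axis-aligned XYXY box. A small illustration with made-up coordinates:

```
# Illustration of the FSC-147 exemplar format handled above (made-up numbers).
bboxes = [
    [[10, 20], [50, 20], [50, 60], [10, 60]],  # one exemplar: four corner points
]
input_boxes = []
for bbox in bboxes:
    x1, y1 = bbox[0][0], bbox[0][1]  # first corner
    x2, y2 = bbox[2][0], bbox[2][1]  # opposite corner
    input_boxes.append([x1, y1, x2, y2])
print(input_boxes)  # [[10, 20, 50, 60]]
```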
test_coco.py
ADDED
@@ -0,0 +1,159 @@
+import cv2
+import argparse
+import json
+import numpy as np
+from tqdm import tqdm
+from os.path import exists
+import os
+
+from segment_anything import sam_model_registry
+from automatic_mask_generator import SamAutomaticMaskGenerator
+import matplotlib.pyplot as plt
+
+
+parser = argparse.ArgumentParser(description="Few Shot Counting Evaluation code")
+parser.add_argument("-dp", "--data_path", type=str, default='/data/counte/', help="Path to the coco dataset")
+parser.add_argument("-ts", "--test_split", type=str, default='val2017', choices=["val2017"], help="what data split to evaluate on")
+parser.add_argument("-mt", "--model_type", type=str, default="vit_h", help="model type")
+parser.add_argument("-mp", "--model_path", type=str, default="/home/teddy/segment-anything/sam_vit_h_4b8939.pth", help="path to trained model")
+parser.add_argument("-v", "--viz", type=bool, default=True, help="whether to visualize")
+parser.add_argument("-d", "--device", default='0', help='assign device')
+args = parser.parse_args()
+
+data_path = args.data_path
+anno_file = data_path + 'annotations_trainval2017/annotations/instances_val2017.json'
+im_dir = data_path + 'val2017'
+
+
+if not exists(anno_file) or not exists(im_dir):
+    print("Make sure you set up the --data-path correctly.")
+    print("Current setting is {}, but the image dir and annotation file do not exist.".format(args.data_path))
+    print("Aborting the evaluation")
+    exit(-1)
+
+def show_anns(anns):
+    if len(anns) == 0:
+        return
+    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
+    ax = plt.gca()
+    ax.set_autoscale_on(False)
+    for ann in sorted_anns:
+        x0, y0, w, h = ann['bbox']
+        ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
+        ax.scatter([x0+w//2], [y0+h//2], color='green', marker='*', s=10, edgecolor='white', linewidth=1.25)
+
+
+debug = True
+os.environ['CUDA_VISIBLE_DEVICES'] = args.device.strip()
+device = 'cuda'
+sam = sam_model_registry[args.model_type](checkpoint=args.model_path)
+sam.to(device=device)
+
+
+mask_generator = SamAutomaticMaskGenerator(
+    model=sam,
+    min_mask_region_area=25
+)
+
+with open(anno_file) as f:
+    annotations = json.load(f)
+
+images = sorted(annotations['images'], key=lambda x: x['file_name'])
+
+# Group the instance annotations by image file name and category id
+prepared_json = {}
+for i in images:
+    prepared_json[i['file_name']] = {
+        "H": i['height'],
+        "W": i['width'],
+        "boxes": {},
+        # "category_ids": [],
+    }
+for i in annotations['annotations']:
+    # Zero-pad the numeric image id to the 12-character COCO file name
+    im_id = str(i['image_id'])
+    prezero = 12 - len(im_id)
+    im_id = '0'*prezero + im_id + ".jpg"
+    if i["category_id"] in prepared_json[im_id]["boxes"]:
+        prepared_json[im_id]["boxes"][i["category_id"]].append(i['bbox'])
+    else:
+        prepared_json[im_id]["boxes"][i["category_id"]] = []
+        prepared_json[im_id]["boxes"][i["category_id"]].append(i['bbox'])
+
+im_ids = []
+for i in prepared_json.keys():
+    im_ids.append(i)
+
+
+cnt = 0
+# The 80 COCO categories split into four folds of 20, scored separately
+folds = [
+    [1,5,9,14,18,22,27,33,37,41,46,50,54,58,62,67,74,78,82,87],
+    [2,6,10,15,19,23,28,34,38,42,47,51,55,59,63,70,75,79,84,88],
+    [3,7,11,16,20,24,31,35,39,43,48,52,56,60,64,72,76,80,85,89],
+    [4,8,13,17,21,25,32,36,40,44,49,53,57,61,65,73,77,81,86,90],
+]
+SAE = [0,0,0,0]  # sum of absolute errors
+SSE = [0,0,0,0]  # sum of square errors
+
+print("Evaluation on {} data".format(args.test_split))
+
+# logs = []
+
+
+pbar = tqdm(im_ids)
+# err_list = []
+for im_id in pbar:
+    category_id = list(prepared_json[im_id]['boxes'].keys())
+
+    image = cv2.imread('{}/{}'.format(im_dir, im_id))
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    # log = []
+    # log.append(im_id)
+
+    for id in category_id:
+        boxes = prepared_json[im_id]['boxes'][id]
+
+        # Use the first annotated instance of this category as the exemplar box (XYWH -> XYXY)
+        input_boxes = list()
+        x1, y1 = boxes[0][0], boxes[0][1]
+        x2, y2 = boxes[0][0] + boxes[0][2], boxes[0][1] + boxes[0][3]
+        input_boxes.append([x1, y1, x2, y2])
+
+        masks = mask_generator.generate(image, input_boxes)
+
+        if args.viz:
+            if not exists('viz'):
+                os.mkdir('viz')
+            plt.figure(figsize=(10,10))
+            plt.imshow(image)
+            show_anns(masks)
+            plt.axis('off')
+            plt.savefig('viz/{}_{}.jpg'.format(im_id[0:-4], id))
+            plt.close()
+
+        gt_cnt = len(boxes)
+        pred_cnt = len(masks)
+        err = abs(gt_cnt - pred_cnt)
+        # log.append("\n{},gt_cnt:{},pred_cnt:{}".format(id, gt_cnt, pred_cnt))  # disabled with the other logging above; `log` is not defined when left active
+        if id in folds[0]:
+            SAE[0] += err
+            SSE[0] += err**2
+        elif id in folds[1]:
+            SAE[1] += err
+            SSE[1] += err**2
+        elif id in folds[2]:
+            SAE[2] += err
+            SSE[2] += err**2
+        elif id in folds[3]:
+            SAE[3] += err
+            SSE[3] += err**2
+
+    cnt = cnt + 1
+    # logs.append(log)
+    pbar.set_description('fold1: {:6.2f}, fold2: {:6.2f}, fold3: {:6.2f}, fold4: {:6.2f},'.\
+        format(SAE[0]/cnt, SAE[1]/cnt, SAE[2]/cnt, SAE[3]/cnt))
+
+print('On {} data, fold1 MAE: {:6.2f}, RMSE: {:6.2f}\n \
+    fold2 MAE: {:6.2f}, RMSE: {:6.2f}\n \
+    fold3 MAE: {:6.2f}, RMSE: {:6.2f}\n \
+    fold4 MAE: {:6.2f}, RMSE: {:6.2f}\n \
+    '.format(args.test_split, SAE[0]/cnt, (SSE[0]/cnt)**0.5, SAE[1]/cnt, (SSE[1]/cnt)**0.5, SAE[2]/cnt, (SSE[2]/cnt)**0.5, SAE[3]/cnt, (SSE[3]/cnt)**0.5))
vis_FSC.ipynb
ADDED
The diff for this file is too large to render. See raw diff.

vis_coco.ipynb
ADDED
The diff for this file is too large to render. See raw diff.