HarborYuan committed a1202ed (parent: a78077d)

add vis

Files changed:
- app/models/detectors/rapsam.py (+2, -2)
- main.py (+23, -9)
app/models/detectors/rapsam.py CHANGED

@@ -60,7 +60,7 @@ class YOSOVideoSam(Mask2formerVideo):
             iou_results=iou_results,
             rescale=False
         )
-        mask_pred_results = results_list[0]['pan_results'].sem_seg[None]
-
+        # mask_pred_results = results_list[0]['pan_results'].sem_seg[None]
+        return results_list, None

        return mask_pred_results.cpu().numpy(), mask_cls_results.cpu().numpy()
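With this change, `YOSOVideoSam.predict_with_point` returns the raw `results_list` (with `None` in place of the class scores) instead of numpy arrays, so mask post-processing moves to the caller. Below is a minimal consumer sketch; the `pan_results.sem_seg` layout is taken from the commented-out line above, while the variable names are illustrative:

```python
# Sketch: consuming the new (results_list, None) return value.
# Assumes results_list[0]['pan_results'].sem_seg holds the panoptic
# prediction tensor, as implied by the line commented out in the diff.
results_list, cls_pred = model.predict_with_point(img_tensor, batch_data_samples)
assert cls_pred is None                       # classes are no longer returned here

pan_results = results_list[0]['pan_results']  # per-image panoptic prediction
pan_seg = pan_results.sem_seg[None]           # add a batch dim, as the old code did
pan_seg_np = pan_seg.cpu().numpy()            # numpy copy, e.g. for visualization
```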
main.py CHANGED

@@ -9,9 +9,12 @@ from PIL import Image
 # mm libs
 from mmdet.registry import MODELS
 from mmdet.structures import DetDataSample
+from mmdet.visualization import DetLocalVisualizer
 from mmengine import Config, print_log
 from mmengine.structures import InstanceData

+from mmdet.datasets.coco_panoptic import CocoPanopticDataset
+
 from PIL import ImageDraw

 IMG_SIZE = 1024

@@ -30,6 +33,7 @@ model.init_weights()
 mean = torch.tensor([123.675, 116.28, 103.53], device=device)[:, None, None]
 std = torch.tensor([58.395, 57.12, 57.375], device=device)[:, None, None]

+visualizer = DetLocalVisualizer()

 examples = [
     ["assets/000000000139.jpg"],

@@ -126,22 +130,32 @@ def segment_point(image, img_state):
         batch_data_samples[0].gt_instances_collected = gt_instances
         batch_data_samples[0].set_metainfo(dict(batch_input_shape=(im_h, im_w)))
         batch_data_samples[0].set_metainfo(dict(img_shape=(h, w)))
+        is_prompt = True
     else:
         batch_data_samples = [DetDataSample()]
         batch_data_samples[0].set_metainfo(dict(batch_input_shape=(im_h, im_w)))
         batch_data_samples[0].set_metainfo(dict(img_shape=(h, w)))
+        is_prompt = False
     with torch.no_grad():
         masks, cls_pred = model.predict_with_point(img_tensor, batch_data_samples)

-
-    masks = masks
-
-
-
-
-
-
-
+    assert len(masks) == 1
+    masks = masks[0]
+    if is_prompt:
+        masks = masks[0, :h, :w]
+        masks = masks > 0.  # no sigmoid
+        rgb_shape = tuple(list(masks.shape) + [3])
+        color = np.zeros(rgb_shape, dtype=np.uint8)
+        color[masks] = np.array([97, 217, 54])
+        output_img = (output_img * 0.7 + color * 0.3).astype(np.uint8)
+        output_img = Image.fromarray(output_img)
+    else:
+        output_img = visualizer._draw_panoptic_seg(
+            output_img,
+            masks['pan_results'].to('cpu').numpy(),
+            classes=CocoPanopticDataset.METAINFO['classes'],
+            palette=CocoPanopticDataset.METAINFO['palette']
+        )
     return image, output_img
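In the prompt branch, the predicted mask logits are thresholded at 0 (the output is not passed through a sigmoid, so 0 is the decision boundary) and drawn as a translucent green overlay via a 70/30 alpha blend. Here is a self-contained sketch of that blend, mirroring the arithmetic in the diff; the image and logits are hypothetical stand-ins:

```python
import numpy as np
from PIL import Image

# Hypothetical stand-ins for the real image and predicted mask logits.
h, w = 64, 64
output_img = np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)  # RGB image
mask_logits = np.random.randn(h, w)                                # raw logits

masks = mask_logits > 0.                     # threshold at 0; no sigmoid applied
color = np.zeros((h, w, 3), dtype=np.uint8)  # blank RGB canvas
color[masks] = np.array([97, 217, 54])       # green fill, as in the commit
blended = (output_img * 0.7 + color * 0.3).astype(np.uint8)  # 70/30 alpha blend
result = Image.fromarray(blended)
```

Note that the panoptic branch relies on `DetLocalVisualizer._draw_panoptic_seg`, an underscore-prefixed internal mmdet method, so its behavior may change between mmdet releases.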