HarborYuan committed on
Commit
a1202ed
1 Parent(s): a78077d
Files changed (2) hide show
  1. app/models/detectors/rapsam.py +2 -2
  2. main.py +23 -9
app/models/detectors/rapsam.py CHANGED
@@ -60,7 +60,7 @@ class YOSOVideoSam(Mask2formerVideo):
60
  iou_results=iou_results,
61
  rescale=False
62
  )
63
- mask_pred_results = results_list[0]['pan_results'].sem_seg[None]
64
- mask_cls_results = mask_cls_results
65
 
66
  return mask_pred_results.cpu().numpy(), mask_cls_results.cpu().numpy()
 
60
  iou_results=iou_results,
61
  rescale=False
62
  )
63
+ # mask_pred_results = results_list[0]['pan_results'].sem_seg[None]
64
+ return results_list, None
65
 
66
  return mask_pred_results.cpu().numpy(), mask_cls_results.cpu().numpy()
main.py CHANGED
@@ -9,9 +9,12 @@ from PIL import Image
9
  # mm libs
10
  from mmdet.registry import MODELS
11
  from mmdet.structures import DetDataSample
 
12
  from mmengine import Config, print_log
13
  from mmengine.structures import InstanceData
14
 
 
 
15
  from PIL import ImageDraw
16
 
17
  IMG_SIZE = 1024
@@ -30,6 +33,7 @@ model.init_weights()
30
  mean = torch.tensor([123.675, 116.28, 103.53], device=device)[:, None, None]
31
  std = torch.tensor([58.395, 57.12, 57.375], device=device)[:, None, None]
32
 
 
33
 
34
  examples = [
35
  ["assets/000000000139.jpg"],
@@ -126,22 +130,32 @@ def segment_point(image, img_state):
126
  batch_data_samples[0].gt_instances_collected = gt_instances
127
  batch_data_samples[0].set_metainfo(dict(batch_input_shape=(im_h, im_w)))
128
  batch_data_samples[0].set_metainfo(dict(img_shape=(h, w)))
 
129
  else:
130
  batch_data_samples = [DetDataSample()]
131
  batch_data_samples[0].set_metainfo(dict(batch_input_shape=(im_h, im_w)))
132
  batch_data_samples[0].set_metainfo(dict(img_shape=(h, w)))
 
133
  with torch.no_grad():
134
  masks, cls_pred = model.predict_with_point(img_tensor, batch_data_samples)
135
 
136
- masks = masks[0, 0, :h, :w]
137
- masks = masks > 0.
138
- rgb_shape = tuple(list(masks.shape) + [3])
139
- color = np.zeros(rgb_shape, dtype=np.uint8)
140
- color[masks] = np.array([97, 217, 54])
141
- # color[masks] = np.array([217, 90, 54])
142
- output_img = (output_img * 0.7 + color * 0.3).astype(np.uint8)
143
-
144
- output_img = Image.fromarray(output_img)
 
 
 
 
 
 
 
 
145
  return image, output_img
146
 
147
 
 
9
  # mm libs
10
  from mmdet.registry import MODELS
11
  from mmdet.structures import DetDataSample
12
+ from mmdet.visualization import DetLocalVisualizer
13
  from mmengine import Config, print_log
14
  from mmengine.structures import InstanceData
15
 
16
+ from mmdet.datasets.coco_panoptic import CocoPanopticDataset
17
+
18
  from PIL import ImageDraw
19
 
20
  IMG_SIZE = 1024
 
33
  mean = torch.tensor([123.675, 116.28, 103.53], device=device)[:, None, None]
34
  std = torch.tensor([58.395, 57.12, 57.375], device=device)[:, None, None]
35
 
36
+ visualizer = DetLocalVisualizer()
37
 
38
  examples = [
39
  ["assets/000000000139.jpg"],
 
130
  batch_data_samples[0].gt_instances_collected = gt_instances
131
  batch_data_samples[0].set_metainfo(dict(batch_input_shape=(im_h, im_w)))
132
  batch_data_samples[0].set_metainfo(dict(img_shape=(h, w)))
133
+ is_prompt = True
134
  else:
135
  batch_data_samples = [DetDataSample()]
136
  batch_data_samples[0].set_metainfo(dict(batch_input_shape=(im_h, im_w)))
137
  batch_data_samples[0].set_metainfo(dict(img_shape=(h, w)))
138
+ is_prompt = False
139
  with torch.no_grad():
140
  masks, cls_pred = model.predict_with_point(img_tensor, batch_data_samples)
141
 
142
+ assert len(masks) == 1
143
+ masks = masks[0]
144
+ if is_prompt:
145
+ masks = masks[0, :h, :w]
146
+ masks = masks > 0. # no sigmoid
147
+ rgb_shape = tuple(list(masks.shape) + [3])
148
+ color = np.zeros(rgb_shape, dtype=np.uint8)
149
+ color[masks] = np.array([97, 217, 54])
150
+ output_img = (output_img * 0.7 + color * 0.3).astype(np.uint8)
151
+ output_img = Image.fromarray(output_img)
152
+ else:
153
+ output_img = visualizer._draw_panoptic_seg(
154
+ output_img,
155
+ masks['pan_results'].to('cpu').numpy(),
156
+ classes=CocoPanopticDataset.METAINFO['classes'],
157
+ palette=CocoPanopticDataset.METAINFO['palette']
158
+ )
159
  return image, output_img
160
 
161