Spaces:
Sleeping
Sleeping
File size: 11,976 Bytes
b34d1d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 |
# Copyright (c) OpenMMLab. All rights reserved.
import os
import torch
from scipy.optimize import linear_sum_assignment
from torch import Tensor
import torch.nn.functional as F
from mmdet.registry import MODELS
from mmdet.structures import SampleList, TrackDataSample
from seg.models.detectors import Mask2formerVideo
from seg.models.utils import mask_pool
# Frame-count threshold for ``Mask2formerVideoMinVIS.predict``: videos with
# more frames than this run through the backbone in chunks of at most this
# many frames, and each chunk's features are moved to CPU before being
# concatenated.
BACKBONE_BATCH = 50
def video_split(total, tube_size, overlap=0):
    """Split ``total`` frames into tubes and return each tube's end index.

    Consecutive tubes share ``overlap`` frames: the first tube covers
    ``[0, ends[0])`` and tube ``k`` covers ``[ends[k - 1] - overlap, ends[k])``.
    With ``stride = tube_size - overlap`` new frames per tube, the strides are
    evened out by shortening each of the last ``diff`` strides by one frame,
    where ``diff = num_splits * stride - (total - overlap)``. When ``diff``
    is not smaller than the number of tubes, only the last tube is shortened
    and a warning is printed.

    Args:
        total (int): Number of frames in the video.
        tube_size (int): Nominal tube length, overlap included.
        overlap (int): Frames shared by consecutive tubes; must be smaller
            than ``tube_size``.

    Returns:
        list[int]: Exclusive end index of every tube.
    """
    assert tube_size > overlap
    span = total - overlap          # frames left after the leading overlap
    stride = tube_size - overlap    # fresh frames contributed by each tube
    num_splits = -(-span // stride)  # ceil(span / stride)

    ends = [stride * (k + 1) for k in range(num_splits)]
    diff = ends[-1] - span

    # currently only supports diff < num_splits
    if diff >= num_splits:
        ends[-1] -= diff
        assert ends[-1] > 0
        print("Warning: {} / {}".format(span, stride))
    else:
        # Walk backwards from the last end, shifting it by diff, the one
        # before by diff - 1, ... so each of the last diff strides loses
        # exactly one frame.
        for pos, shift in zip(range(num_splits - 1, -1, -1), range(diff, 0, -1)):
            ends[pos] -= shift

    return [end + overlap for end in ends]
def match_from_embeds(tgt_embds, cur_embds):
    """Align current query embeddings to target ones by Hungarian matching.

    For every batch element the cosine distance ``1 - cos_sim`` between all
    current/target pairs is minimised with
    :func:`scipy.optimize.linear_sum_assignment`.

    Args:
        tgt_embds (Tensor): Reference embeddings, shape (B, Nt, D).
        cur_embds (Tensor): Embeddings to reorder, shape (B, Nc, D).

    Returns:
        list[np.ndarray]: One index array per batch element. For
        ``Nt <= Nc``, ``cur_embds[b][perm]`` lines up with ``tgt_embds[b]``.
    """
    # Work in at least float32: numpy/scipy cannot consume bfloat16 and
    # half precision makes near-tied costs unreliable. Using a common dtype
    # also keeps bmm from failing on mixed-precision inputs.
    dtype = torch.promote_types(
        torch.promote_types(cur_embds.dtype, tgt_embds.dtype), torch.float32)
    # detach(): the cost matrix is handed to scipy, which cannot take
    # tensors that require grad. F.normalize clamps the norm, so an all-zero
    # embedding becomes a zero vector instead of NaNs that
    # linear_sum_assignment would reject.
    cur = F.normalize(cur_embds.detach().to(dtype), dim=-1)
    tgt = F.normalize(tgt_embds.detach().to(dtype), dim=-1)
    cost = (1 - torch.bmm(cur, tgt.transpose(1, 2))).cpu().numpy()  # (B, Nc, Nt)

    indices = []
    for cost_b in cost:
        # rows = target, cols = current; the column indices are the
        # permutation that makes current aligned to target.
        _, perm = linear_sum_assignment(cost_b.T)
        indices.append(perm)
    return indices
@MODELS.register_module()
class Mask2formerVideoMinVIS(Mask2formerVideo):
    r"""Video Mask2Former with MinVIS-style clip-wise inference.

    Only inference is overridden here. A video longer than
    ``whole_clip_thr`` frames is split by :func:`video_split` into
    (optionally overlapping) tubes. Each tube goes through the panoptic head
    on its own, and the query order of every tube is aligned to the
    previous tube by Hungarian matching of the query embeddings
    (:func:`match_from_embeds`). After alignment, query ``k`` refers to the
    same object throughout the video.

    Args:
        clip_size (int): Tube length used when the video has more than
            ``small_clip_thr`` frames. Defaults to 6.
        clip_size_small (int): Tube length used otherwise. Defaults to 3.
        whole_clip_thr (int): Videos with at most this many frames skip the
            tube split and use the parent ``predict``. Defaults to 0.
        small_clip_thr (int): Frame count selecting between ``clip_size``
            and ``clip_size_small``. Defaults to 12.
        overlap (int): Frames shared by consecutive tubes; must be smaller
            than the tube length. Defaults to 0.
    """
    # Optional per-class weights (length ``num_classes``) read by
    # ``open_voc_inference``: weight 1 selects the ``alpha`` ensemble and
    # weight 0 the ``beta`` ensemble (presumably 0/1 "seen category" flags
    # set externally -- confirm). ``None`` disables open-vocabulary
    # re-scoring.
    OVERLAPPING = None

    def __init__(self,
                 *args,
                 clip_size=6,
                 clip_size_small=3,
                 whole_clip_thr=0,
                 small_clip_thr=12,
                 overlap=0,
                 **kwargs,
                 ):
        super().__init__(*args, **kwargs)
        self.clip_size = clip_size
        self.clip_size_small = clip_size_small
        self.overlap = overlap
        self.whole_clip_thr = whole_clip_thr
        self.small_clip_thr = small_clip_thr

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict a video tube by tube and link the tubes' queries.

        Args:
            batch_inputs (Tensor): Video frames with shape
                (bs, num_frames, 3, H, W).
            batch_data_samples (List[:obj:`TrackDataSample`]): One track
                data sample per video; ``sample[frame_id]`` is the data
                sample of a single frame.
            rescale (bool): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`TrackDataSample`]: ``batch_data_samples`` with the
            per-frame predictions attached by
            ``add_track_pred_to_datasample``.

        Note:
            Features are flattened to ``bs * num_frames``. Tubes, however,
            are sliced by frame index only, and only the first batch element
            of the matching result is used. This path therefore effectively
            assumes ``bs == 1``.
        """
        assert isinstance(batch_data_samples[0], TrackDataSample)
        bs, num_frames, three, h, w = batch_inputs.shape
        assert three == 3, "Only supporting images with 3 channels."
        if num_frames <= self.whole_clip_thr:
            return super().predict(batch_inputs, batch_data_samples, rescale)

        device = batch_inputs.device
        if num_frames > self.small_clip_thr:
            clip_size = self.clip_size
        else:
            clip_size = self.clip_size_small
        tube_inds = video_split(num_frames, clip_size, self.overlap)

        feats = self._extract_video_feats(batch_inputs)
        assert len(feats[0]) == bs * num_frames
        del batch_inputs

        # 1) Run the panoptic head on every tube independently.
        cls_list, mask_list, query_list, iou_list = [], [], [], []
        ind_pre = 0
        for ind in tube_inds:
            tube_feats = [itm[ind_pre:ind].to(device=device) for itm in feats]
            tube_data_samples = [
                TrackDataSample(video_data_samples=itm[ind_pre:ind])
                for itm in batch_data_samples
            ]
            _mask_cls_results, _mask_pred_results, _query_feat, _iou_results = \
                self.panoptic_head.predict(tube_feats, tube_data_samples, return_query=True)
            cls_list.append(_mask_cls_results)
            # Masks are (bs, queries, frames, h, w). The leading ``overlap``
            # frames of every tube but the first were already produced by
            # the previous tube, so drop them along the *time* axis (dim 2).
            first_frame = 0 if not mask_list else self.overlap
            mask_list.append(_mask_pred_results[:, :, first_frame:].cpu())
            query_list.append(_query_feat.cpu())
            iou_list.append(_iou_results)
            ind_pre = ind - self.overlap

        # 2) Reorder each tube's queries to match the previous (already
        #    aligned) tube, then average class/IoU scores over tubes.
        num_tubes = len(tube_inds)
        out_cls = [cls_list[0]]
        out_mask = [mask_list[0]]
        out_embed = [query_list[0]]
        ious = [iou_list[0]]
        for i in range(1, num_tubes):
            indices = match_from_embeds(out_embed[-1], query_list[i])
            indices = indices[0]  # since bs == 1
            out_cls.append(cls_list[i][:, indices])
            out_mask.append(mask_list[i][:, indices])
            out_embed.append(query_list[i][:, indices])
            ious.append(iou_list[i][:, indices])
        del mask_list
        del out_embed

        mask_cls_results = sum(out_cls) / num_tubes
        mask_pred_results = torch.cat(out_mask, dim=2)
        iou_results = sum(ious) / num_tubes

        if self.OVERLAPPING is not None:
            assert len(self.OVERLAPPING) == self.num_classes
            mask_cls_results = self.open_voc_inference(feats, mask_cls_results, mask_pred_results)
        del feats
        mask_cls_results = mask_cls_results.to(device='cpu')
        iou_results = iou_results.to(device='cpu')

        # 3) Per-frame post-processing. The proposal -> gt-id map is built on
        #    the first frame only and reused for every later frame.
        id_assigner = [{} for _ in range(bs)]
        for frame_id in range(num_frames):
            frame_samples = [batch_data_samples[idx][frame_id] for idx in range(bs)]
            results_list_img = self.panoptic_fusion_head.predict(
                mask_cls_results,
                mask_pred_results[:, :, frame_id],
                frame_samples,
                iou_results=iou_results,
                rescale=rescale
            )
            if 'pro_results' in results_list_img[0]:
                # Loop over every video explicitly (previously a ``batch_id``
                # leaked from the frame-0 loop, so only the last one was
                # converted).
                for batch_id in range(bs):
                    pro_masks = results_list_img[batch_id]['pro_results']
                    if frame_id == 0:
                        id_assigner[batch_id] = self._match_pro_results_to_gt(
                            pro_masks, frame_samples[batch_id])
                    results_list_img[batch_id]['pro_results'] = \
                        self._pro_results_to_id_map(pro_masks, id_assigner[batch_id])
            self.add_track_pred_to_datasample(frame_samples, results_list_img)
        return batch_data_samples

    def _extract_video_feats(self, batch_inputs):
        """Run the backbone on all frames; long videos are done in chunks.

        Args:
            batch_inputs (Tensor): (bs, num_frames, 3, H, W) frames.

        Returns:
            The four backbone feature levels, each with ``bs * num_frames``
            entries along dim 0. When chunked (more than ``BACKBONE_BATCH``
            frames), every level is moved to CPU.
        """
        bs, num_frames, three, h, w = batch_inputs.shape
        if num_frames <= BACKBONE_BATCH:
            return self.extract_feat(batch_inputs.reshape((bs * num_frames, three, h, w)))

        num_clip = num_frames // BACKBONE_BATCH + 1
        step_size = num_frames // num_clip + 1
        feat_bins = [[], [], [], []]
        # Iterate over real chunk starts so no empty trailing chunk is ever
        # fed to the backbone. NOTE(review): for bs > 1 the concatenated
        # order is chunk-major rather than video-major.
        for start in range(0, num_frames, step_size):
            end = min(num_frames, start + step_size)
            inputs = batch_inputs[:, start:end].reshape(
                (bs * (end - start), three, h, w))
            _feats = self.extract_feat(inputs)
            assert len(_feats) == 4
            for idx, item in enumerate(_feats):
                feat_bins[idx].append(item.to('cpu'))
        feats = []
        for item in feat_bins:
            feat = torch.cat(item, dim=0)
            assert feat.size(0) == bs * num_frames, "{} vs {}".format(feat.size(0), bs * num_frames)
            feats.append(feat)
        return feats

    @staticmethod
    def _match_pro_results_to_gt(pro_masks, frame_sample):
        """Map proposal index -> gt instance id via max-IoU Hungarian matching.

        Args:
            pro_masks (Tensor): Proposal masks, (num_proposals, H, W), on CPU.
            frame_sample: Per-frame data sample whose ``gt_instances``
                provide ``masks.masks`` and ``instances_ids``.

        Returns:
            dict: Matched proposal index -> gt instance id (Python int).
        """
        gt_instances = frame_sample.gt_instances
        # Binarise both sides (the previous union already clamped to 0/1)
        # so intersection and union are plain pixel counts.
        pred = pro_masks.to(dtype=torch.int32).clamp(min=0, max=1).flatten(1)
        gt = torch.tensor(gt_instances.masks.masks, dtype=torch.int32)
        gt = gt.clamp(min=0, max=1).flatten(1)
        intersection = torch.einsum('nc,mc->nm', pred, gt)
        # |A u B| = |A| + |B| - |A n B|: avoids an (n, m, H*W) temporary.
        union = pred.sum(-1)[:, None] + gt.sum(-1)[None] - intersection
        # Two empty masks have union 0; score them as IoU 0 rather than NaN,
        # which linear_sum_assignment would reject.
        iou = intersection / union.clamp(min=1)
        pred_inds, gt_inds = linear_sum_assignment(-iou.numpy())
        return {
            pred_ind: gt_instances.instances_ids[gt_ind].item()
            for pred_ind, gt_ind in zip(pred_inds, gt_inds)
        }

    @staticmethod
    def _pro_results_to_id_map(pro_masks, id_map):
        """Paint matched proposal masks with their gt instance ids.

        Args:
            pro_masks (Tensor): Boolean proposal masks, (num_proposals, H, W).
            id_map (dict): Proposal index -> instance id.

        Returns:
            np.ndarray: int32 (H, W) map, 0 where no matched proposal is
            present; proposals painted later overwrite earlier ones.
        """
        h, w = pro_masks.shape[-2:]
        seg_map = torch.full((h, w), 0, dtype=torch.int32, device='cpu')
        for pro_ind, instance_id in id_map.items():
            seg_map[pro_masks[pro_ind]] = instance_id
        return seg_map.numpy()

    def open_voc_inference(self, feats, mask_cls_results, mask_pred_results):
        """Re-score query classes with CLIP features pooled under each mask.

        Query class probabilities ``q`` and CLIP-based probabilities ``c``
        are combined by the geometric ensemble ``q ** (1 - a) * c ** a``.
        The log of the ``a = self.alpha`` ensemble is weighted by
        ``OVERLAPPING`` and the log of the ``a = self.beta`` ensemble by
        ``1 - OVERLAPPING``. Queries whose mask is empty at the CLIP feature
        resolution use ``a = 0``, i.e. their query probabilities only.

        Args:
            feats (list[Tensor]): Backbone features; only ``feats[-1]`` is
                used, via ``self.backbone.get_clip_feature``.
            mask_cls_results (Tensor): Class logits (bs, queries,
                num_classes + 1); the last entry is the void class.
            mask_pred_results (Tensor): Mask logits, (bs, q, t, h, w) for
                video or (bs, q, h, w) otherwise.

        Returns:
            Tensor: Log-probabilities shaped like ``mask_cls_results``; the
            void probability is taken from the query logits.
        """
        batch_size = mask_cls_results.shape[0]
        if mask_pred_results.dim() == 5:
            num_frames = mask_pred_results.shape[2]
            # (bs, q, t, h, w) -> (bs * t, q, h, w) for 2-D interpolation.
            mask_pred_results = mask_pred_results.permute(0, 2, 1, 3, 4).flatten(0, 1)
        else:
            num_frames = 0
        device = mask_cls_results.device

        clip_feat = self.backbone.get_clip_feature(feats[-1]).to(device=device)
        clip_feat_mask = F.interpolate(
            mask_pred_results,
            size=clip_feat.shape[-2:],
            mode='bilinear',
            align_corners=False
        ).to(device=device)
        if num_frames > 0:
            # Stack the frames along the height axis so each query is pooled
            # over the whole clip: (bs, q|C, t * h, w).
            clip_feat_mask = clip_feat_mask.unflatten(0, (batch_size, num_frames)).permute(0, 2, 1, 3, 4).flatten(2, 3)
            clip_feat = clip_feat.unflatten(0, (batch_size, num_frames)).permute(0, 2, 1, 3, 4).flatten(2, 3)

        instance_feat = mask_pool(clip_feat, clip_feat_mask)
        instance_feat = self.backbone.forward_feat(instance_feat)
        clip_logit = self.panoptic_head.forward_logit(instance_feat)
        clip_logit = clip_logit[..., :-1].softmax(-1)
        query_logit = mask_cls_results[..., :-1].softmax(-1)

        overlapping_mask = torch.tensor(self.OVERLAPPING, dtype=torch.float32, device=clip_logit.device)
        valid_masking = ((clip_feat_mask > 0).to(dtype=torch.float32).flatten(-2).sum(-1) > 0).to(
            torch.float32)[..., None]
        alpha = torch.ones_like(clip_logit) * self.alpha * valid_masking
        beta = torch.ones_like(clip_logit) * self.beta * valid_masking

        seen_mean = query_logit ** (1 - alpha) * clip_logit ** alpha
        unseen_mean = query_logit ** (1 - beta) * clip_logit ** beta
        # Clamp before log: an underflowed mean gives -inf, and -inf * 0 in
        # the weighted blend below would turn into NaN.
        tiny = torch.finfo(seen_mean.dtype).tiny
        cls_logits_seen = seen_mean.clamp(min=tiny).log() * overlapping_mask
        cls_logits_unseen = unseen_mean.clamp(min=tiny).log() * (1 - overlapping_mask)
        cls_results = cls_logits_seen + cls_logits_unseen

        is_void_prob = F.softmax(mask_cls_results, dim=-1)[..., -1:]
        mask_cls_results = torch.cat([
            cls_results.softmax(-1) * (1.0 - is_void_prob), is_void_prob], dim=-1)
        mask_cls_results = torch.log(mask_cls_results + 1e-8)
        return mask_cls_results
|