# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
import torch.nn.functional as F
from mmengine.structures import InstanceData, PixelData
from torch import Tensor

from mmdet.evaluation.functional import INSTANCE_OFFSET
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.structures.mask import mask2bbox
from mmdet.utils import OptConfigType, OptMultiConfig
from mmdet.models.seg_heads.panoptic_fusion_heads.base_panoptic_fusion_head import \
    BasePanopticFusionHead


# register the head so it can be built from configs
@MODELS.register_module()
class OMGFusionHead(BasePanopticFusionHead):
    """Fusion head that turns mask classification outputs into panoptic,
    instance and proposal results, following MaskFormerFusionHead."""

    def __init__(self,
                 num_things_classes: int = 80,
                 num_stuff_classes: int = 53,
                 test_cfg: OptConfigType = None,
                 loss_panoptic: OptConfigType = None,
                 init_cfg: OptMultiConfig = None,
                 **kwargs):
        super().__init__(
            num_things_classes=num_things_classes,
            num_stuff_classes=num_stuff_classes,
            test_cfg=test_cfg,
            loss_panoptic=loss_panoptic,
            init_cfg=init_cfg,
            **kwargs)

    def loss(self, **kwargs):
        """OMGFusionHead has no training loss."""
        return dict()

    def panoptic_postprocess(self, mask_cls: Tensor,
                             mask_pred: Tensor) -> PixelData:
        """Panoptic segmentation inference.

        Args:
            mask_cls (Tensor): Classification outputs of shape
                (num_queries, cls_out_channels) for an image.
                Note `cls_out_channels` should include background.
            mask_pred (Tensor): Mask outputs of shape
                (num_queries, h, w) for an image.

        Returns:
            :obj:`PixelData`: Panoptic segment result of shape \
                (h, w), each element in Tensor means: \
                ``segment_id = _cls + instance_id * INSTANCE_OFFSET``.
        """
        object_mask_thr = self.test_cfg.get('object_mask_thr', 0.8)
        iou_thr = self.test_cfg.get('iou_thr', 0.8)
        filter_low_score = self.test_cfg.get('filter_low_score', False)

        scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
        mask_pred = mask_pred.sigmoid()

        keep = labels.ne(self.num_classes) & (scores > object_mask_thr)
        cur_scores = scores[keep]
        cur_classes = labels[keep]
        cur_masks = mask_pred[keep]

        cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks

        h, w = cur_masks.shape[-2:]
        panoptic_seg = torch.full((h, w),
                                  self.num_classes,
                                  dtype=torch.int32,
                                  device=cur_masks.device)
        if cur_masks.shape[0] == 0:
            # We didn't detect any mask :(
            pass
        else:
            cur_mask_ids = cur_prob_masks.argmax(0)
            instance_id = 1
            for k in range(cur_classes.shape[0]):
                pred_class = int(cur_classes[k].item())
                isthing = pred_class < self.num_things_classes
                mask = cur_mask_ids == k
                mask_area = mask.sum().item()
                original_area = (cur_masks[k] >= 0.5).sum().item()

                if filter_low_score:
                    mask = mask & (cur_masks[k] >= 0.5)

                if mask_area > 0 and original_area > 0:
                    if mask_area / original_area < iou_thr:
                        continue

                    if not isthing:
                        # different stuff regions of same class will be
                        # merged here, and stuff share the instance_id 0.
                        panoptic_seg[mask] = pred_class
                    else:
                        panoptic_seg[mask] = (
                            pred_class + instance_id * INSTANCE_OFFSET)
                        instance_id += 1

        return PixelData(sem_seg=panoptic_seg[None])
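
    # An illustrative decoding sketch (an addition, not part of the original
    # code): downstream consumers can split the encoded panoptic map back
    # into class ids and instance ids, since
    # ``segment_id = label + instance_id * INSTANCE_OFFSET``:
    #
    #   sem_seg = pan_results.sem_seg[0]           # (h, w) int32 map
    #   labels = sem_seg % INSTANCE_OFFSET         # per-pixel class ids
    #   instance_ids = sem_seg // INSTANCE_OFFSET  # 0 for stuff regions
    #
    # Pixels whose label equals ``self.num_classes`` were never claimed by
    # any query and act as the void label.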

    def semantic_postprocess(self, mask_cls: Tensor,
                             mask_pred: Tensor) -> PixelData:
        """Semantic segmentation postprocess.

        Args:
            mask_cls (Tensor): Classification outputs of shape
                (num_queries, cls_out_channels) for an image.
                Note `cls_out_channels` should include background.
            mask_pred (Tensor): Mask outputs of shape
                (num_queries, h, w) for an image.

        Returns:
            :obj:`PixelData`: Semantic segment result.
        """
        # TODO add semantic segmentation result
        raise NotImplementedError

    def instance_postprocess(self, mask_cls: Tensor,
                             mask_pred: Tensor) -> InstanceData:
        """Instance segmentation postprocess.

        Args:
            mask_cls (Tensor): Classification outputs of shape
                (num_queries, cls_out_channels) for an image.
                Note `cls_out_channels` should include background.
            mask_pred (Tensor): Mask outputs of shape
                (num_queries, h, w) for an image.

        Returns:
            :obj:`InstanceData`: Instance segmentation results.

            - scores (Tensor): Classification scores, has a shape
              (num_instances, ).
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arranged as (x1, y1, x2, y2).
            - masks (Tensor): Has a shape (num_instances, H, W).
        """
        max_per_image = self.test_cfg.get('max_per_image', 100)
        num_queries = mask_cls.shape[0]
        # shape (num_queries, num_class)
        scores = F.softmax(mask_cls, dim=-1)[:, :-1]
        # shape (num_queries * num_class, )
        labels = torch.arange(self.num_classes, device=mask_cls.device). \
            unsqueeze(0).repeat(num_queries, 1).flatten(0, 1)
        scores_per_image, top_indices = scores.flatten(0, 1).topk(
            max_per_image, sorted=False)
        labels_per_image = labels[top_indices]

        query_indices = top_indices // self.num_classes
        mask_pred = mask_pred[query_indices]

        # extract things
        is_thing = labels_per_image < self.num_things_classes
        scores_per_image = scores_per_image[is_thing]
        labels_per_image = labels_per_image[is_thing]
        mask_pred = mask_pred[is_thing]

        mask_pred_binary = (mask_pred > 0).float()
        mask_scores_per_image = (mask_pred.sigmoid() *
                                 mask_pred_binary).flatten(1).sum(1) / (
                                     mask_pred_binary.flatten(1).sum(1) + 1e-6)
        det_scores = scores_per_image * mask_scores_per_image
        mask_pred_binary = mask_pred_binary.bool()
        bboxes = mask2bbox(mask_pred_binary)

        results = InstanceData()
        results.bboxes = bboxes
        results.labels = labels_per_image
        results.scores = det_scores
        results.masks = mask_pred_binary
        return results
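
    # A small worked example (an addition for clarity): the flattened top-k
    # above scores every (query, class) pair jointly, so the query and class
    # indices are recovered with integer division / modulo:
    #
    #   scores = torch.tensor([[0.1, 0.9],
    #                          [0.8, 0.2]])        # 2 queries, 2 classes
    #   vals, idx = scores.flatten(0, 1).topk(2)   # idx -> tensor([1, 2])
    #   idx // 2   # query indices -> tensor([0, 1])
    #   idx % 2    # class indices -> tensor([1, 0])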

    def proposal_postprocess(self, mask_score: Tensor,
                             mask_pred: Tensor) -> Tensor:
        """Proposal postprocess: keep the top-scoring masks as proposals.

        Args:
            mask_score (Tensor): Mask scores of shape (num_queries, 1)
                for an image.
            mask_pred (Tensor): Mask logits of shape (num_queries, h, w)
                for an image.

        Returns:
            Tensor: Binary proposal masks of shape (num_proposals, h, w),
            ordered from lowest to highest score.
        """
        max_per_image = self.test_cfg.get('num_proposals', 10)
        # shape (num_queries, )
        scores = mask_score.sigmoid().squeeze(-1)
        scores_per_image, top_indices = scores.topk(max_per_image, sorted=True)
        mask_selected = mask_pred[top_indices]
        # binarize the selected masks; flip so the lowest-scoring proposal
        # comes first, matching the original reversed iteration
        seg_map = (mask_selected.sigmoid() > 0.5).flip(0)
        return seg_map
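
    # Illustrative usage (hypothetical names, not from the original code):
    # the boolean proposal masks can be converted to boxes with the same
    # helper used for instances:
    #
    #   pro_masks = head.proposal_postprocess(iou_score, mask_logits)
    #   pro_bboxes = mask2bbox(pro_masks)  # (num_proposals, 4), xyxy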

    def predict(self,
                mask_cls_results: Tensor,
                mask_pred_results: Tensor,
                batch_data_samples: SampleList,
                iou_results=None,
                rescale: bool = False,
                **kwargs) -> List[dict]:
        """Test segmentation without test-time augmentation.

        Only the output of the last decoder layer is used.

        Args:
            mask_cls_results (Tensor): Mask classification logits,
                shape (batch_size, num_queries, cls_out_channels).
                Note `cls_out_channels` should include background.
            mask_pred_results (Tensor): Mask logits, shape
                (batch_size, num_queries, h, w).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            iou_results (Tensor, optional): Per-query mask scores used for
                proposal postprocessing. Defaults to None.
            rescale (bool): If True, return boxes in
                original image space. Default False.

        Returns:
            list[dict]: Instance segmentation \
                results and panoptic segmentation results for each \
                image.

            .. code-block:: none

                [
                    {
                        'pan_results': PixelData,
                        'ins_results': InstanceData,
                        # semantic segmentation results are not supported yet
                        'sem_results': PixelData
                    },
                    ...
                ]
        """
        batch_img_metas = [
            data_sample.metainfo for data_sample in batch_data_samples
        ]
        panoptic_on = self.test_cfg.get('panoptic_on', True)
        semantic_on = self.test_cfg.get('semantic_on', False)
        instance_on = self.test_cfg.get('instance_on', False)
        proposal_on = self.test_cfg.get('proposal_on', False)
        assert not semantic_on, 'semantic segmentation ' \
                                'results are not supported yet.'

        results = []
        for idx, (mask_cls_result, mask_pred_result, meta) in enumerate(
                zip(mask_cls_results, mask_pred_results, batch_img_metas)):
            # remove padding
            img_height, img_width = meta['img_shape'][:2]
            mask_pred_result = mask_pred_result.to(mask_cls_results.device)
            mask_pred_result = mask_pred_result[:, :img_height, :img_width]

            if rescale:
                # return result in original resolution
                ori_height, ori_width = meta['ori_shape'][:2]
                mask_pred_result = F.interpolate(
                    mask_pred_result[:, None],
                    size=(ori_height, ori_width),
                    mode='bilinear',
                    align_corners=False)[:, 0]

            result = dict()
            if panoptic_on:
                pan_results = self.panoptic_postprocess(
                    mask_cls_result, mask_pred_result)
                result['pan_results'] = pan_results

            if instance_on:
                ins_results = self.instance_postprocess(
                    mask_cls_result, mask_pred_result)
                result['ins_results'] = ins_results

            if semantic_on:
                sem_results = self.semantic_postprocess(
                    mask_cls_result, mask_pred_result)
                result['sem_results'] = sem_results

            if proposal_on:
                pro_results = self.proposal_postprocess(
                    iou_results[idx], mask_pred_result)
                result['pro_results'] = pro_results

            results.append(result)

        return results
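

# A minimal smoke-test sketch, added for illustration and not part of the
# original module. It assumes mmdet/mmengine are installed; the 80 + 53 COCO
# panoptic class split, tensor sizes and random logits are illustrative only.
if __name__ == '__main__':
    from mmdet.structures import DetDataSample

    head = OMGFusionHead(
        num_things_classes=80,
        num_stuff_classes=53,
        test_cfg=dict(panoptic_on=True, instance_on=True))

    num_queries, h, w = 100, 64, 64
    # cls_out_channels = num_classes + 1 (the extra channel is background)
    mask_cls = torch.randn(1, num_queries, head.num_classes + 1)
    mask_pred = torch.randn(1, num_queries, h, w)

    sample = DetDataSample()
    sample.set_metainfo(dict(img_shape=(h, w), ori_shape=(h, w)))

    results = head.predict(mask_cls, mask_pred, [sample])
    print(results[0]['pan_results'].sem_seg.shape)  # torch.Size([1, 64, 64])
    print(results[0]['ins_results'].bboxes.shape)   # (num_things_kept, 4)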