# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from typing import Sequence, Tuple import torch from mmcv.ops import batched_nms from mmengine.structures import InstanceData from mmdet.structures import DetDataSample, SampleList def shift_rbboxes(bboxes: torch.Tensor, offset: Sequence[int]): """Shift rotated bboxes with offset. Args: bboxes (Tensor): The rotated bboxes need to be translated. With shape (n, 5), which means (x, y, w, h, a). offset (Sequence[int]): The translation offsets with shape of (2, ). Returns: Tensor: Shifted rotated bboxes. """ offset_tensor = bboxes.new_tensor(offset) shifted_bboxes = bboxes.clone() shifted_bboxes[:, 0:2] = shifted_bboxes[:, 0:2] + offset_tensor return shifted_bboxes def shift_predictions(det_data_samples: SampleList, offsets: Sequence[Tuple[int, int]], src_image_shape: Tuple[int, int]) -> SampleList: """Shift predictions to the original image. Args: det_data_samples (List[:obj:`DetDataSample`]): A list of patch results. offsets (Sequence[Tuple[int, int]]): Positions of the left top points of patches. src_image_shape (Tuple[int, int]): A (height, width) tuple of the large image's width and height. Returns: (List[:obj:`DetDataSample`]): shifted results. """ try: from sahi.slicing import shift_bboxes, shift_masks except ImportError: raise ImportError('Please run "pip install -U sahi" ' 'to install sahi first for large image inference.') assert len(det_data_samples) == len( offsets), 'The `results` should has the ' 'same length with `offsets`.' shifted_predictions = [] for det_data_sample, offset in zip(det_data_samples, offsets): pred_inst = det_data_sample.pred_instances.clone() # Check bbox type if pred_inst.bboxes.size(-1) == 4: # Horizontal bboxes shifted_bboxes = shift_bboxes(pred_inst.bboxes, offset) elif pred_inst.bboxes.size(-1) == 5: # Rotated bboxes shifted_bboxes = shift_rbboxes(pred_inst.bboxes, offset) else: raise NotImplementedError # shift bboxes and masks pred_inst.bboxes = shifted_bboxes if 'masks' in det_data_sample: pred_inst.masks = shift_masks(pred_inst.masks, offset, src_image_shape) shifted_predictions.append(pred_inst.clone()) shifted_predictions = InstanceData.cat(shifted_predictions) return shifted_predictions def merge_results_by_nms(results: SampleList, offsets: Sequence[Tuple[int, int]], src_image_shape: Tuple[int, int], nms_cfg: dict) -> DetDataSample: """Merge patch results by nms. Args: results (List[:obj:`DetDataSample`]): A list of patch results. offsets (Sequence[Tuple[int, int]]): Positions of the left top points of patches. src_image_shape (Tuple[int, int]): A (height, width) tuple of the large image's width and height. nms_cfg (dict): it should specify nms type and other parameters like `iou_threshold`. Returns: :obj:`DetDataSample`: merged results. """ shifted_instances = shift_predictions(results, offsets, src_image_shape) _, keeps = batched_nms( boxes=shifted_instances.bboxes, scores=shifted_instances.scores, idxs=shifted_instances.labels, nms_cfg=nms_cfg) merged_instances = shifted_instances[keeps] merged_result = results[0].clone() merged_result.pred_instances = merged_instances return merged_result