from typing import Union

import torch
from torch import Tensor

from mmdet3d.registry import TASK_UTILS


@TASK_UTILS.register_module()
class BBox3DL1Cost:
    """Weighted L1 matching cost between predicted and ground-truth 3D boxes.

    For every (prediction, ground truth) pair the cost is the sum of the
    absolute differences over all box components, multiplied by ``weight``.

    Args:
        weight (Union[float, int]): Scale factor applied to the whole cost
            matrix. Defaults to 1.
    """

    def __init__(self, weight: Union[float, int] = 1.) -> None:
        self.weight = weight

    def __call__(self, bbox_pred: Tensor, gt_bboxes: Tensor) -> Tensor:
        """Build the pairwise L1 cost matrix.

        Args:
            bbox_pred (Tensor): Predicted boxes, shape [num_query, 10], laid
                out as (cx, cy, l, w, cz, h, sin(φ), cos(φ), v_x, v_y) in the
                normalized encoding.
            gt_bboxes (Tensor): Ground-truth boxes, shape [num_gt, 10], in the
                same normalized encoding as ``bbox_pred``. Both inputs must use
                the same encoding, because the cost compares them component by
                component.

        Returns:
            Tensor: Cost matrix of shape (num_query, num_gt), where entry
            (i, j) is ``weight * sum(|bbox_pred[i] - gt_bboxes[j]|)``.
        """
        # cdist with p=1 gives the Manhattan distance between every
        # prediction row and every ground-truth row in a single call.
        pairwise_l1 = torch.cdist(bbox_pred, gt_bboxes, p=1)
        return self.weight * pairwise_l1