giantmonkeyTC
2344
34d1f8b
from typing import Union
import torch
from torch import Tensor
from mmdet3d.registry import TASK_UTILS
@TASK_UTILS.register_module()
class BBox3DL1Cost(object):
    """Pairwise L1 matching cost between predicted and ground-truth 3D boxes.

    The cost of pairing prediction ``i`` with ground truth ``j`` is the sum of
    absolute differences across their box encodings, multiplied by
    ``weight``.

    Args:
        weight (Union[float, int]): Multiplier applied to the raw L1 cost.
            Defaults to 1.
    """

    def __init__(self, weight: Union[float, int] = 1.):
        self.weight = weight

    def __call__(self, bbox_pred: Tensor, gt_bboxes: Tensor) -> Tensor:
        """Build the weighted L1 cost matrix.

        Args:
            bbox_pred (Tensor): Predicted boxes encoded as
                (cx, cy, l, w, cz, h, sin(φ), cos(φ), v_x, v_y) with
                normalized coordinates in range [0, 1], shape
                [num_query, 10].
            gt_bboxes (Tensor): Ground-truth boxes in the same normalized
                encoding, shape [num_gt, 10].

        Returns:
            Tensor: Cost matrix of shape (num_preds, num_gts).
        """
        # cdist with p=1 computes the Manhattan distance between every row of
        # bbox_pred and every row of gt_bboxes in a single call.
        pairwise_l1 = torch.cdist(bbox_pred, gt_bboxes, p=1)
        weighted = pairwise_l1 * self.weight
        return weighted