# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
from torch import Tensor
from torch import nn as nn

from mmdet3d.registry import MODELS
from mmdet3d.structures import points_cam2img
from . import apply_3d_transformation, bbox_2d_transform, coord_2d_transform

EPS = 1e-6


@MODELS.register_module()
class VoteFusion(nn.Module):
    """Fuse 2D image cues with 3D seed points.

    Args:
        num_classes (int): Number of classes.
        max_imvote_per_pixel (int): Max number of imvotes.
    """

    def __init__(self,
                 num_classes: int = 10,
                 max_imvote_per_pixel: int = 3) -> None:
        super(VoteFusion, self).__init__()
        self.num_classes = num_classes
        self.max_imvote_per_pixel = max_imvote_per_pixel

    def forward(self, imgs: List[Tensor], bboxes_2d_rescaled: List[Tensor],
                seeds_3d_depth: List[Tensor],
                img_metas: List[dict]) -> Tuple[Tensor, Tensor]:
        """Forward function.

        Args:
            imgs (List[Tensor]): Images of each sample, shape (3, H, W).
            bboxes_2d_rescaled (List[Tensor]): 2D bboxes of each sample,
                shape (N, 6) as (x1, y1, x2, y2, conf, cls).
            seeds_3d_depth (List[Tensor]): 3D seed points of each sample
                in depth coordinates, shape (M, 3).
            img_metas (List[dict]): Meta information of images.

        Returns:
            Tuple[Tensor, Tensor]:

                - img_features: Concatenated cues of each point.
                - masks: Validity mask of each feature.
        """
        img_features = []
        masks = []
        for img, bbox_2d_rescaled, seed_3d_depth, img_meta in zip(
                imgs, bboxes_2d_rescaled, seeds_3d_depth, img_metas):
            bbox_num = bbox_2d_rescaled.shape[0]
            seed_num = seed_3d_depth.shape[0]

            img_shape = img_meta['img_shape']
            # first reverse the data transformations
            xyz_depth = apply_3d_transformation(
                seed_3d_depth, 'DEPTH', img_meta, reverse=True)

            # project points from depth to image
            depth2img = xyz_depth.new_tensor(img_meta['depth2img'])
            uvz_origin = points_cam2img(xyz_depth, depth2img, True)
            z_cam = uvz_origin[..., 2]
            # shift the 1-based projected coordinates to 0-based pixels
            uv_origin = (uvz_origin[..., :2] - 1).round()

            # rescale 2d coordinates and bboxes
            uv_rescaled = coord_2d_transform(img_meta, uv_origin, True)
            bbox_2d_origin = bbox_2d_transform(img_meta, bbox_2d_rescaled,
                                               False)

            if bbox_num == 0:
                imvote_num = seed_num * self.max_imvote_per_pixel

                # use zero features: 5 geometric + num_classes semantic dims
                two_cues = torch.zeros((5 + self.num_classes, imvote_num),
                                       device=seed_3d_depth.device)
                mask_zero = torch.zeros(
                    imvote_num - seed_num, device=seed_3d_depth.device).bool()
                mask_one = torch.ones(
                    seed_num, device=seed_3d_depth.device).bool()
                mask = torch.cat([mask_one, mask_zero], dim=0)
            else:
                # expand bboxes and seeds to (seed_num, bbox_num, ...) pairs
                bbox_expanded = bbox_2d_origin.view(1, bbox_num, -1).expand(
                    seed_num, -1, -1)
                seed_2d_expanded = uv_origin.view(seed_num, 1,
                                                  -1).expand(-1, bbox_num, -1)
                seed_2d_expanded_x, seed_2d_expanded_y = \
                    seed_2d_expanded.split(1, dim=-1)

                bbox_expanded_l, bbox_expanded_t, bbox_expanded_r, \
                    bbox_expanded_b, bbox_expanded_conf, bbox_expanded_cls = \
                    bbox_expanded.split(1, dim=-1)
                bbox_expanded_midx = (bbox_expanded_l + bbox_expanded_r) / 2
                bbox_expanded_midy = (bbox_expanded_t + bbox_expanded_b) / 2

                seed_2d_in_bbox_x = (seed_2d_expanded_x > bbox_expanded_l) * \
                    (seed_2d_expanded_x < bbox_expanded_r)
                seed_2d_in_bbox_y = (seed_2d_expanded_y > bbox_expanded_t) * \
                    (seed_2d_expanded_y < bbox_expanded_b)
                seed_2d_in_bbox = seed_2d_in_bbox_x * seed_2d_in_bbox_y

                # semantic cues, dim=num_classes: scatter each bbox
                # confidence into its class slot
                sem_cue = torch.zeros_like(bbox_expanded_conf).expand(
                    -1, -1, self.num_classes)
                sem_cue = sem_cue.scatter(-1, bbox_expanded_cls.long(),
                                          bbox_expanded_conf)

                # 2d vote: from projected seed to bbox center
                delta_u = bbox_expanded_midx - seed_2d_expanded_x
                delta_v = bbox_expanded_midy - seed_2d_expanded_y

                seed_3d_expanded = seed_3d_depth.view(seed_num, 1, -1).expand(
                    -1, bbox_num, -1)

                z_cam = z_cam.view(seed_num, 1, 1).expand(-1, bbox_num, -1)
                imvote = torch.cat(
                    [delta_u, delta_v,
                     torch.zeros_like(delta_v)], dim=-1).view(-1, 3)
                imvote = imvote * z_cam.reshape(-1, 1)
                imvote = imvote @ torch.inverse(depth2img.t())
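                # The three lines above lift the 2D vote back to 3D:
                # points_cam2img maps xyz to uvz = xyz @ depth2img^T (up to
                # the perspective divide by z), so a pure image-plane offset
                # (du, dv, 0) at fixed depth z corresponds to the camera-frame
                # offset (du * z, dv * z, 0) @ inv(depth2img^T), which leaves
                # the depth of the seed unchanged.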

                # apply transformation to lifted imvotes
                imvote = apply_3d_transformation(
                    imvote, 'DEPTH', img_meta, reverse=False)

                seed_3d_expanded = seed_3d_expanded.reshape(imvote.shape)

                # ray angle: unit direction of the vote-corrected ray
                ray_angle = seed_3d_expanded + imvote
                ray_angle /= torch.sqrt(torch.sum(ray_angle**2, -1) +
                                        EPS).unsqueeze(-1)

                # imvote lifted to 3d
                xz = ray_angle[:, [0, 2]] / (ray_angle[:, [1]] + EPS) * \
                    seed_3d_expanded[:, [1]] - seed_3d_expanded[:, [0, 2]]

                # geometric cues, dim=5
                geo_cue = torch.cat([xz, ray_angle],
                                    dim=-1).view(seed_num, -1, 5)

                two_cues = torch.cat([geo_cue, sem_cue], dim=-1)
                # mask to 0 if seed not in bbox
                two_cues = two_cues * seed_2d_in_bbox.float()

                feature_size = two_cues.shape[-1]
                # if bbox number is too small, append zeros
                if bbox_num < self.max_imvote_per_pixel:
                    append_num = self.max_imvote_per_pixel - bbox_num
                    append_zeros = torch.zeros(
                        (seed_num, append_num, 1),
                        device=seed_2d_in_bbox.device).bool()
                    seed_2d_in_bbox = torch.cat(
                        [seed_2d_in_bbox, append_zeros], dim=1)
                    append_zeros = torch.zeros(
                        (seed_num, append_num, feature_size),
                        device=two_cues.device)
                    two_cues = torch.cat([two_cues, append_zeros], dim=1)
                    append_zeros = torch.zeros((seed_num, append_num, 1),
                                               device=two_cues.device)
                    bbox_expanded_conf = torch.cat(
                        [bbox_expanded_conf, append_zeros], dim=1)

                # sort the valid seed-bbox pairs by confidence
                pair_score = seed_2d_in_bbox.float() + bbox_expanded_conf
                # and keep the largest max_imvote_per_pixel ones
                mask, indices = pair_score.topk(
                    self.max_imvote_per_pixel,
                    dim=1,
                    largest=True,
                    sorted=True)

                indices_img = indices.expand(-1, -1, feature_size)
                two_cues = two_cues.gather(dim=1, index=indices_img)
                two_cues = two_cues.transpose(1, 0)
                two_cues = two_cues.reshape(-1, feature_size).transpose(
                    1, 0).contiguous()

                # since conf is in (0, 1), the floor of pair_score is 1 for
                # valid pairs (seed inside bbox) and 0 otherwise
                mask = mask.floor().int()
                mask = mask.transpose(1, 0).reshape(-1).bool()

            # clear the padding
            img = img[:, :img_shape[0], :img_shape[1]]
            img_flatten = img.reshape(3, -1).float()
            img_flatten /= 255.

            # take the normalized pixel value as texture cue
            uv_rescaled[:, 0] = torch.clamp(uv_rescaled[:, 0].round(), 0,
                                            img_shape[1] - 1)
            uv_rescaled[:, 1] = torch.clamp(uv_rescaled[:, 1].round(), 0,
                                            img_shape[0] - 1)
            uv_flatten = uv_rescaled[:, 1].round() * \
                img_shape[1] + uv_rescaled[:, 0].round()
            uv_expanded = uv_flatten.unsqueeze(0).expand(3, -1).long()
            txt_cue = torch.gather(img_flatten, dim=-1, index=uv_expanded)
            txt_cue = txt_cue.unsqueeze(1).expand(-1,
                                                  self.max_imvote_per_pixel,
                                                  -1).reshape(3, -1)

            # append texture cue
            img_feature = torch.cat([two_cues, txt_cue], dim=0)
            img_features.append(img_feature)
            masks.append(mask)

        return torch.stack(img_features, 0), torch.stack(masks, 0)
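

# ---------------------------------------------------------------------------
# Minimal standalone sketch (illustration only, not part of the registered
# module): it reproduces the vote-lifting step of `forward` for a single
# seed-bbox pair, assuming a synthetic 3x3 `depth2img` intrinsic matrix.
# The focal lengths, principal point, and vote values below are made-up
# numbers for the demo and do not come from any dataset or calibration file.
if __name__ == '__main__':
    # assumed pinhole-style projection: uvz = xyz @ depth2img^T
    depth2img = torch.tensor([[500., 0., 320.],
                              [0., 500., 240.],
                              [0., 0., 1.]])
    z_cam = torch.tensor([[2.5]])  # depth of the projected seed
    # 2D vote: bbox center minus projected seed location, padded to 3D
    delta_uv = torch.tensor([[12., -7.]])
    imvote = torch.cat([delta_uv, torch.zeros(1, 1)], dim=-1)
    # lift to the camera frame exactly as in forward()
    imvote = imvote * z_cam.reshape(-1, 1)
    imvote = imvote @ torch.inverse(depth2img.t())
    # prints tensor([[0.0600, -0.0350, 0.0000]]); the depth component
    # stays zero, as the vote only shifts the seed in the image plane
    print('lifted imvote:', imvote)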