import numpy as np
import torch
from mmcv.cnn.bricks.transformer import TransformerLayerSequence
from mmengine.registry import MODELS
from torch import nn
from torch.nn.init import normal_

from .cross_view_hybrid_attention import TPVCrossViewHybridAttention
from .image_cross_attention import TPVMSDeformableAttention3D


@MODELS.register_module()
class TPVFormerEncoder(TransformerLayerSequence):

    def __init__(self,
                 tpv_h=200,
                 tpv_w=200,
                 tpv_z=16,
                 pc_range=[-51.2, -51.2, -5, 51.2, 51.2, 3],
                 num_feature_levels=4,
                 num_cams=6,
                 embed_dims=256,
                 num_points_in_pillar=[4, 32, 32],
                 num_points_in_pillar_cross_view=[32, 32, 32],
                 num_layers=5,
                 transformerlayers=None,
                 positional_encoding=None,
                 return_intermediate=False):
        super().__init__(transformerlayers, num_layers)
        self.tpv_h = tpv_h
        self.tpv_w = tpv_w
        self.tpv_z = tpv_z
        self.pc_range = pc_range
        self.real_w = pc_range[3] - pc_range[0]
        self.real_h = pc_range[4] - pc_range[1]
        self.real_z = pc_range[5] - pc_range[2]

        self.level_embeds = nn.Parameter(
            torch.Tensor(num_feature_levels, embed_dims))
        self.cams_embeds = nn.Parameter(torch.Tensor(num_cams, embed_dims))
        self.tpv_embedding_hw = nn.Embedding(tpv_h * tpv_w, embed_dims)
        self.tpv_embedding_zh = nn.Embedding(tpv_z * tpv_h, embed_dims)
        self.tpv_embedding_wz = nn.Embedding(tpv_w * tpv_z, embed_dims)

        ref_3d_hw = self.get_reference_points(tpv_h, tpv_w, self.real_z,
                                              num_points_in_pillar[0])
        ref_3d_zh = self.get_reference_points(tpv_z, tpv_h, self.real_w,
                                              num_points_in_pillar[1])
        ref_3d_zh = ref_3d_zh.permute(3, 0, 1, 2)[[2, 0, 1]]  # change to x,y,z
        ref_3d_zh = ref_3d_zh.permute(1, 2, 3, 0)
        ref_3d_wz = self.get_reference_points(tpv_w, tpv_z, self.real_h,
                                              num_points_in_pillar[2])
        ref_3d_wz = ref_3d_wz.permute(3, 0, 1, 2)[[1, 2, 0]]  # change to x,y,z
        ref_3d_wz = ref_3d_wz.permute(1, 2, 3, 0)
        self.register_buffer('ref_3d_hw', ref_3d_hw)
        self.register_buffer('ref_3d_zh', ref_3d_zh)
        self.register_buffer('ref_3d_wz', ref_3d_wz)

        cross_view_ref_points = self.get_cross_view_ref_points(
            tpv_h, tpv_w, tpv_z, num_points_in_pillar_cross_view)
        self.register_buffer('cross_view_ref_points', cross_view_ref_points)

        # positional encoding
        self.positional_encoding = MODELS.build(positional_encoding)
        self.return_intermediate = return_intermediate

    def init_weights(self):
        """Initialize the transformer weights."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, TPVMSDeformableAttention3D) or isinstance(
                    m, TPVCrossViewHybridAttention):
                m.init_weights()
        normal_(self.level_embeds)
        normal_(self.cams_embeds)
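
    # The cross-view reference points below are shared by every encoder layer
    # for the hybrid attention: dim 0 enumerates the hw, zh and wz queries in
    # that order, dim 1 indexes the plane being sampled (hw, zh, wz), dim 2
    # the points along the pillar, and dim 3 holds normalized 2D coordinates
    # on the sampled plane.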
    @staticmethod
    def get_cross_view_ref_points(tpv_h, tpv_w, tpv_z, num_points_in_pillar):
        # ref points generating target: (#query)hw+zh+wz, (#level)3, #p, 2
        # generate points for hw and level 1
        h_ranges = torch.linspace(0.5, tpv_h - 0.5, tpv_h) / tpv_h
        w_ranges = torch.linspace(0.5, tpv_w - 0.5, tpv_w) / tpv_w
        h_ranges = h_ranges.unsqueeze(-1).expand(-1, tpv_w).flatten()
        w_ranges = w_ranges.unsqueeze(0).expand(tpv_h, -1).flatten()
        hw_hw = torch.stack([w_ranges, h_ranges], dim=-1)  # hw, 2
        hw_hw = hw_hw.unsqueeze(1).expand(-1, num_points_in_pillar[2],
                                          -1)  # hw, #p, 2
        # generate points for hw and level 2
        z_ranges = torch.linspace(0.5, tpv_z - 0.5,
                                  num_points_in_pillar[2]) / tpv_z  # #p
        z_ranges = z_ranges.unsqueeze(0).expand(tpv_h * tpv_w, -1)  # hw, #p
        h_ranges = torch.linspace(0.5, tpv_h - 0.5, tpv_h) / tpv_h
        h_ranges = h_ranges.reshape(-1, 1, 1).expand(
            -1, tpv_w, num_points_in_pillar[2]).flatten(0, 1)
        hw_zh = torch.stack([h_ranges, z_ranges], dim=-1)  # hw, #p, 2
        # generate points for hw and level 3
        z_ranges = torch.linspace(0.5, tpv_z - 0.5,
                                  num_points_in_pillar[2]) / tpv_z  # #p
        z_ranges = z_ranges.unsqueeze(0).expand(tpv_h * tpv_w, -1)  # hw, #p
        w_ranges = torch.linspace(0.5, tpv_w - 0.5, tpv_w) / tpv_w
        w_ranges = w_ranges.reshape(1, -1, 1).expand(
            tpv_h, -1, num_points_in_pillar[2]).flatten(0, 1)
        hw_wz = torch.stack([z_ranges, w_ranges], dim=-1)  # hw, #p, 2

        # generate points for zh and level 1
        w_ranges = torch.linspace(0.5, tpv_w - 0.5,
                                  num_points_in_pillar[1]) / tpv_w
        w_ranges = w_ranges.unsqueeze(0).expand(tpv_z * tpv_h, -1)
        h_ranges = torch.linspace(0.5, tpv_h - 0.5, tpv_h) / tpv_h
        h_ranges = h_ranges.reshape(1, -1, 1).expand(
            tpv_z, -1, num_points_in_pillar[1]).flatten(0, 1)
        zh_hw = torch.stack([w_ranges, h_ranges], dim=-1)
        # generate points for zh and level 2
        z_ranges = torch.linspace(0.5, tpv_z - 0.5, tpv_z) / tpv_z
        z_ranges = z_ranges.reshape(-1, 1, 1).expand(
            -1, tpv_h, num_points_in_pillar[1]).flatten(0, 1)
        h_ranges = torch.linspace(0.5, tpv_h - 0.5, tpv_h) / tpv_h
        h_ranges = h_ranges.reshape(1, -1, 1).expand(
            tpv_z, -1, num_points_in_pillar[1]).flatten(0, 1)
        zh_zh = torch.stack([h_ranges, z_ranges], dim=-1)  # zh, #p, 2
        # generate points for zh and level 3
        w_ranges = torch.linspace(0.5, tpv_w - 0.5,
                                  num_points_in_pillar[1]) / tpv_w
        w_ranges = w_ranges.unsqueeze(0).expand(tpv_z * tpv_h, -1)
        z_ranges = torch.linspace(0.5, tpv_z - 0.5, tpv_z) / tpv_z
        z_ranges = z_ranges.reshape(-1, 1, 1).expand(
            -1, tpv_h, num_points_in_pillar[1]).flatten(0, 1)
        zh_wz = torch.stack([z_ranges, w_ranges], dim=-1)

        # generate points for wz and level 1
        h_ranges = torch.linspace(0.5, tpv_h - 0.5,
                                  num_points_in_pillar[0]) / tpv_h
        h_ranges = h_ranges.unsqueeze(0).expand(tpv_w * tpv_z, -1)
        w_ranges = torch.linspace(0.5, tpv_w - 0.5, tpv_w) / tpv_w
        w_ranges = w_ranges.reshape(-1, 1, 1).expand(
            -1, tpv_z, num_points_in_pillar[0]).flatten(0, 1)
        wz_hw = torch.stack([w_ranges, h_ranges], dim=-1)
        # generate points for wz and level 2
        h_ranges = torch.linspace(0.5, tpv_h - 0.5,
                                  num_points_in_pillar[0]) / tpv_h
        h_ranges = h_ranges.unsqueeze(0).expand(tpv_w * tpv_z, -1)
        z_ranges = torch.linspace(0.5, tpv_z - 0.5, tpv_z) / tpv_z
        z_ranges = z_ranges.reshape(1, -1, 1).expand(
            tpv_w, -1, num_points_in_pillar[0]).flatten(0, 1)
        wz_zh = torch.stack([h_ranges, z_ranges], dim=-1)
        # generate points for wz and level 3
        w_ranges = torch.linspace(0.5, tpv_w - 0.5, tpv_w) / tpv_w
        w_ranges = w_ranges.reshape(-1, 1, 1).expand(
            -1, tpv_z, num_points_in_pillar[0]).flatten(0, 1)
        z_ranges = torch.linspace(0.5, tpv_z - 0.5, tpv_z) / tpv_z
        z_ranges = z_ranges.reshape(1, -1, 1).expand(
            tpv_w, -1, num_points_in_pillar[0]).flatten(0, 1)
        wz_wz = torch.stack([z_ranges, w_ranges], dim=-1)

        reference_points = torch.cat([
            torch.stack([hw_hw, hw_zh, hw_wz], dim=1),
            torch.stack([zh_hw, zh_zh, zh_wz], dim=1),
            torch.stack([wz_hw, wz_zh, wz_wz], dim=1)
        ], dim=0)  # hw+zh+wz, 3, #p, 2

        return reference_points

    @staticmethod
    def get_reference_points(H,
                             W,
                             Z=8,
                             num_points_in_pillar=4,
                             dim='3d',
                             bs=1,
                             device='cuda',
                             dtype=torch.float):
        """Get the reference points used in image cross-attention (SCA).

        Args:
            H, W: Spatial shape of the TPV plane.
            Z: Height range covered by the pillar.
            num_points_in_pillar: Number of points sampled along each pillar.
            device (obj:`device`): The device where reference_points
                should be.

        Returns:
            Tensor: Normalized reference points used in the encoder, with
            shape (bs, num_points_in_pillar, H*W, 3).
        """
        # reference points in 3D space, used in spatial cross-attention (SCA)
        zs = torch.linspace(
            0.5, Z - 0.5, num_points_in_pillar, dtype=dtype,
            device=device).view(-1, 1, 1).expand(num_points_in_pillar, H,
                                                 W) / Z
        xs = torch.linspace(
            0.5, W - 0.5, W, dtype=dtype,
            device=device).view(1, 1, -1).expand(num_points_in_pillar, H,
                                                 W) / W
        ys = torch.linspace(
            0.5, H - 0.5, H, dtype=dtype,
            device=device).view(1, -1, 1).expand(num_points_in_pillar, H,
                                                 W) / H
        ref_3d = torch.stack((xs, ys, zs), -1)
        ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1)
        ref_3d = ref_3d[None].repeat(bs, 1, 1, 1)
        return ref_3d
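
    # `point_sampling` projects the normalized 3D reference points into every
    # camera with the usual homogeneous pipeline: de-normalize with pc_range,
    # append a homogeneous 1, multiply by the per-camera lidar2img matrix,
    # divide by depth, then rescale by the padded input image size.
    # `tpv_mask` keeps only points that lie in front of the camera and fall
    # inside the image.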
""" # reference points in 3D space, used in spatial cross-attention (SCA) zs = torch.linspace( 0.5, Z - 0.5, num_points_in_pillar, dtype=dtype, device=device).view(-1, 1, 1).expand( num_points_in_pillar, H, W) / Z xs = torch.linspace( 0.5, W - 0.5, W, dtype=dtype, device=device).view(1, 1, -1).expand( num_points_in_pillar, H, W) / W ys = torch.linspace( 0.5, H - 0.5, H, dtype=dtype, device=device).view(1, -1, 1).expand( num_points_in_pillar, H, W) / H ref_3d = torch.stack((xs, ys, zs), -1) ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1) ref_3d = ref_3d[None].repeat(bs, 1, 1, 1) return ref_3d def point_sampling(self, reference_points, pc_range, batch_data_smaples): lidar2img = [] for data_sample in batch_data_smaples: lidar2img.append(data_sample.lidar2img) lidar2img = np.asarray(lidar2img) lidar2img = reference_points.new_tensor(lidar2img) # (B, N, 4, 4) reference_points = reference_points.clone() reference_points[..., 0:1] = reference_points[..., 0:1] * \ (pc_range[3] - pc_range[0]) + pc_range[0] reference_points[..., 1:2] = reference_points[..., 1:2] * \ (pc_range[4] - pc_range[1]) + pc_range[1] reference_points[..., 2:3] = reference_points[..., 2:3] * \ (pc_range[5] - pc_range[2]) + pc_range[2] reference_points = torch.cat( (reference_points, torch.ones_like(reference_points[..., :1])), -1) reference_points = reference_points.permute(1, 0, 2, 3) D, B, num_query = reference_points.size()[:3] num_cam = lidar2img.size(1) reference_points = reference_points.view(D, B, 1, num_query, 4).repeat( 1, 1, num_cam, 1, 1).unsqueeze(-1) lidar2img = lidar2img.view(1, B, num_cam, 1, 4, 4).repeat(D, 1, 1, num_query, 1, 1) reference_points_cam = torch.matmul( lidar2img.to(torch.float32), reference_points.to(torch.float32)).squeeze(-1) eps = 1e-5 tpv_mask = (reference_points_cam[..., 2:3] > eps) reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum( reference_points_cam[..., 2:3], torch.ones_like(reference_points_cam[..., 2:3]) * eps) reference_points_cam[..., 0] /= data_sample.batch_input_shape[1] reference_points_cam[..., 1] /= data_sample.batch_input_shape[0] tpv_mask = ( tpv_mask & (reference_points_cam[..., 1:2] > 0.0) & (reference_points_cam[..., 1:2] < 1.0) & (reference_points_cam[..., 0:1] < 1.0) & (reference_points_cam[..., 0:1] > 0.0)) tpv_mask = torch.nan_to_num(tpv_mask) reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4) tpv_mask = tpv_mask.permute(2, 1, 3, 0, 4).squeeze(-1) return reference_points_cam, tpv_mask def forward(self, mlvl_feats, batch_data_samples): """Forward function. Args: mlvl_feats (tuple[Tensor]): Features from the upstream network, each is a 5D-tensor with shape (B, N, C, H, W). 
""" bs = mlvl_feats[0].shape[0] dtype = mlvl_feats[0].dtype device = mlvl_feats[0].device # tpv queries and pos embeds tpv_queries_hw = self.tpv_embedding_hw.weight.to(dtype) tpv_queries_zh = self.tpv_embedding_zh.weight.to(dtype) tpv_queries_wz = self.tpv_embedding_wz.weight.to(dtype) tpv_queries_hw = tpv_queries_hw.unsqueeze(0).repeat(bs, 1, 1) tpv_queries_zh = tpv_queries_zh.unsqueeze(0).repeat(bs, 1, 1) tpv_queries_wz = tpv_queries_wz.unsqueeze(0).repeat(bs, 1, 1) tpv_query = [tpv_queries_hw, tpv_queries_zh, tpv_queries_wz] tpv_pos_hw = self.positional_encoding(bs, device, 'z') tpv_pos_zh = self.positional_encoding(bs, device, 'w') tpv_pos_wz = self.positional_encoding(bs, device, 'h') tpv_pos = [tpv_pos_hw, tpv_pos_zh, tpv_pos_wz] # flatten image features of different scales feat_flatten = [] spatial_shapes = [] for lvl, feat in enumerate(mlvl_feats): bs, num_cam, c, h, w = feat.shape spatial_shape = (h, w) feat = feat.flatten(3).permute(1, 0, 3, 2) # num_cam, bs, hw, c feat = feat + self.cams_embeds[:, None, None, :].to(dtype) feat = feat + self.level_embeds[None, None, lvl:lvl + 1, :].to(dtype) spatial_shapes.append(spatial_shape) feat_flatten.append(feat) feat_flatten = torch.cat(feat_flatten, 2) # num_cam, bs, hw++, c spatial_shapes = torch.as_tensor( spatial_shapes, dtype=torch.long, device=device) level_start_index = torch.cat((spatial_shapes.new_zeros( (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) feat_flatten = feat_flatten.permute( 0, 2, 1, 3) # (num_cam, H*W, bs, embed_dims) reference_points_cams, tpv_masks = [], [] ref_3ds = [self.ref_3d_hw, self.ref_3d_zh, self.ref_3d_wz] for ref_3d in ref_3ds: reference_points_cam, tpv_mask = self.point_sampling( ref_3d, self.pc_range, batch_data_samples) # num_cam, bs, hw++, #p, 2 reference_points_cams.append(reference_points_cam) tpv_masks.append(tpv_mask) ref_cross_view = self.cross_view_ref_points.clone().unsqueeze( 0).expand(bs, -1, -1, -1, -1) intermediate = [] for layer in self.layers: output = layer( tpv_query, feat_flatten, feat_flatten, tpv_pos=tpv_pos, ref_2d=ref_cross_view, tpv_h=self.tpv_h, tpv_w=self.tpv_w, tpv_z=self.tpv_z, spatial_shapes=spatial_shapes, level_start_index=level_start_index, reference_points_cams=reference_points_cams, tpv_masks=tpv_masks) tpv_query = output if self.return_intermediate: intermediate.append(output) if self.return_intermediate: return torch.stack(intermediate) return output