import numpy as np
import torch
from mmcv.cnn.bricks.transformer import TransformerLayerSequence
from mmengine.registry import MODELS
from torch import nn
from torch.nn.init import normal_

from .cross_view_hybrid_attention import TPVCrossViewHybridAttention
from .image_cross_attention import TPVMSDeformableAttention3D


@MODELS.register_module()
class TPVFormerEncoder(TransformerLayerSequence):
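    """TPVFormer encoder.

    Stacks ``num_layers`` transformer layers that refine the three TPV
    (tri-perspective view) plane queries with image cross-attention
    (``TPVMSDeformableAttention3D``) and cross-view hybrid attention
    (``TPVCrossViewHybridAttention``).
    """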

    def __init__(self,
                 tpv_h=200,
                 tpv_w=200,
                 tpv_z=16,
                 pc_range=[-51.2, -51.2, -5, 51.2, 51.2, 3],
                 num_feature_levels=4,
                 num_cams=6,
                 embed_dims=256,
                 num_points_in_pillar=[4, 32, 32],
                 num_points_in_pillar_cross_view=[32, 32, 32],
                 num_layers=5,
                 transformerlayers=None,
                 positional_encoding=None,
                 return_intermediate=False):
        super().__init__(transformerlayers, num_layers)

        self.tpv_h = tpv_h
        self.tpv_w = tpv_w
        self.tpv_z = tpv_z
        self.pc_range = pc_range
        self.real_w = pc_range[3] - pc_range[0]
        self.real_h = pc_range[4] - pc_range[1]
        self.real_z = pc_range[5] - pc_range[2]

        # learnable embeddings added to the image features (one per feature
        # level and one per camera), plus one query embedding per cell of
        # each TPV plane
        self.level_embeds = nn.Parameter(
            torch.Tensor(num_feature_levels, embed_dims))
        self.cams_embeds = nn.Parameter(torch.Tensor(num_cams, embed_dims))
        self.tpv_embedding_hw = nn.Embedding(tpv_h * tpv_w, embed_dims)
        self.tpv_embedding_zh = nn.Embedding(tpv_z * tpv_h, embed_dims)
        self.tpv_embedding_wz = nn.Embedding(tpv_w * tpv_z, embed_dims)

        # 3D reference points (pillars) used by the image cross-attention;
        # for the zh and wz planes the coordinate channels are permuted so
        # that all three planes share the same (x, y, z) ordering
        ref_3d_hw = self.get_reference_points(tpv_h, tpv_w, self.real_z,
                                              num_points_in_pillar[0])
        ref_3d_zh = self.get_reference_points(tpv_z, tpv_h, self.real_w,
                                              num_points_in_pillar[1])
        ref_3d_zh = ref_3d_zh.permute(3, 0, 1, 2)[[2, 0, 1]]
        ref_3d_zh = ref_3d_zh.permute(1, 2, 3, 0)
        ref_3d_wz = self.get_reference_points(tpv_w, tpv_z, self.real_h,
                                              num_points_in_pillar[2])
        ref_3d_wz = ref_3d_wz.permute(3, 0, 1, 2)[[1, 2, 0]]
        ref_3d_wz = ref_3d_wz.permute(1, 2, 3, 0)
        self.register_buffer('ref_3d_hw', ref_3d_hw)
        self.register_buffer('ref_3d_zh', ref_3d_zh)
        self.register_buffer('ref_3d_wz', ref_3d_wz)

        # 2D reference points used by the cross-view hybrid attention
        cross_view_ref_points = self.get_cross_view_ref_points(
            tpv_h, tpv_w, tpv_z, num_points_in_pillar_cross_view)
        self.register_buffer('cross_view_ref_points', cross_view_ref_points)

        self.positional_encoding = MODELS.build(positional_encoding)
        self.return_intermediate = return_intermediate

    def init_weights(self):
        """Initialize the transformer weights."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, (TPVMSDeformableAttention3D,
                              TPVCrossViewHybridAttention)):
                m.init_weights()
        normal_(self.level_embeds)
        normal_(self.cams_embeds)

    @staticmethod
    def get_cross_view_ref_points(tpv_h, tpv_w, tpv_z, num_points_in_pillar):
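        # Output layout: (tpv_h*tpv_w + tpv_z*tpv_h + tpv_w*tpv_z, 3,
        # num_points, 2), i.e. for every query of every plane, one set of
        # normalized 2D sampling points on each of the three target planes.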
        # points for hw-plane queries, sampled on the hw plane
        h_ranges = torch.linspace(0.5, tpv_h - 0.5, tpv_h) / tpv_h
        w_ranges = torch.linspace(0.5, tpv_w - 0.5, tpv_w) / tpv_w
        h_ranges = h_ranges.unsqueeze(-1).expand(-1, tpv_w).flatten()
        w_ranges = w_ranges.unsqueeze(0).expand(tpv_h, -1).flatten()
        hw_hw = torch.stack([w_ranges, h_ranges], dim=-1)
        hw_hw = hw_hw.unsqueeze(1).expand(-1, num_points_in_pillar[2], -1)

        # points for hw-plane queries, sampled on the zh plane
        z_ranges = torch.linspace(0.5, tpv_z - 0.5,
                                  num_points_in_pillar[2]) / tpv_z
        z_ranges = z_ranges.unsqueeze(0).expand(tpv_h * tpv_w, -1)
        h_ranges = torch.linspace(0.5, tpv_h - 0.5, tpv_h) / tpv_h
        h_ranges = h_ranges.reshape(-1, 1, 1).expand(
            -1, tpv_w, num_points_in_pillar[2]).flatten(0, 1)
        hw_zh = torch.stack([h_ranges, z_ranges], dim=-1)

        # points for hw-plane queries, sampled on the wz plane
        z_ranges = torch.linspace(0.5, tpv_z - 0.5,
                                  num_points_in_pillar[2]) / tpv_z
        z_ranges = z_ranges.unsqueeze(0).expand(tpv_h * tpv_w, -1)
        w_ranges = torch.linspace(0.5, tpv_w - 0.5, tpv_w) / tpv_w
        w_ranges = w_ranges.reshape(1, -1, 1).expand(
            tpv_h, -1, num_points_in_pillar[2]).flatten(0, 1)
        hw_wz = torch.stack([z_ranges, w_ranges], dim=-1)

        # points for zh-plane queries, sampled on the hw plane
        w_ranges = torch.linspace(0.5, tpv_w - 0.5,
                                  num_points_in_pillar[1]) / tpv_w
        w_ranges = w_ranges.unsqueeze(0).expand(tpv_z * tpv_h, -1)
        h_ranges = torch.linspace(0.5, tpv_h - 0.5, tpv_h) / tpv_h
        h_ranges = h_ranges.reshape(1, -1, 1).expand(
            tpv_z, -1, num_points_in_pillar[1]).flatten(0, 1)
        zh_hw = torch.stack([w_ranges, h_ranges], dim=-1)

        # points for zh-plane queries, sampled on the zh plane
        z_ranges = torch.linspace(0.5, tpv_z - 0.5, tpv_z) / tpv_z
        z_ranges = z_ranges.reshape(-1, 1, 1).expand(
            -1, tpv_h, num_points_in_pillar[1]).flatten(0, 1)
        h_ranges = torch.linspace(0.5, tpv_h - 0.5, tpv_h) / tpv_h
        h_ranges = h_ranges.reshape(1, -1, 1).expand(
            tpv_z, -1, num_points_in_pillar[1]).flatten(0, 1)
        zh_zh = torch.stack([h_ranges, z_ranges], dim=-1)

        # points for zh-plane queries, sampled on the wz plane
        w_ranges = torch.linspace(0.5, tpv_w - 0.5,
                                  num_points_in_pillar[1]) / tpv_w
        w_ranges = w_ranges.unsqueeze(0).expand(tpv_z * tpv_h, -1)
        z_ranges = torch.linspace(0.5, tpv_z - 0.5, tpv_z) / tpv_z
        z_ranges = z_ranges.reshape(-1, 1, 1).expand(
            -1, tpv_h, num_points_in_pillar[1]).flatten(0, 1)
        zh_wz = torch.stack([z_ranges, w_ranges], dim=-1)

        # points for wz-plane queries, sampled on the hw plane
        h_ranges = torch.linspace(0.5, tpv_h - 0.5,
                                  num_points_in_pillar[0]) / tpv_h
        h_ranges = h_ranges.unsqueeze(0).expand(tpv_w * tpv_z, -1)
        w_ranges = torch.linspace(0.5, tpv_w - 0.5, tpv_w) / tpv_w
        w_ranges = w_ranges.reshape(-1, 1, 1).expand(
            -1, tpv_z, num_points_in_pillar[0]).flatten(0, 1)
        wz_hw = torch.stack([w_ranges, h_ranges], dim=-1)

        # points for wz-plane queries, sampled on the zh plane
        h_ranges = torch.linspace(0.5, tpv_h - 0.5,
                                  num_points_in_pillar[0]) / tpv_h
        h_ranges = h_ranges.unsqueeze(0).expand(tpv_w * tpv_z, -1)
        z_ranges = torch.linspace(0.5, tpv_z - 0.5, tpv_z) / tpv_z
        z_ranges = z_ranges.reshape(1, -1, 1).expand(
            tpv_w, -1, num_points_in_pillar[0]).flatten(0, 1)
        wz_zh = torch.stack([h_ranges, z_ranges], dim=-1)

        # points for wz-plane queries, sampled on the wz plane
        w_ranges = torch.linspace(0.5, tpv_w - 0.5, tpv_w) / tpv_w
        w_ranges = w_ranges.reshape(-1, 1, 1).expand(
            -1, tpv_z, num_points_in_pillar[0]).flatten(0, 1)
        z_ranges = torch.linspace(0.5, tpv_z - 0.5, tpv_z) / tpv_z
        z_ranges = z_ranges.reshape(1, -1, 1).expand(
            tpv_w, -1, num_points_in_pillar[0]).flatten(0, 1)
        wz_wz = torch.stack([z_ranges, w_ranges], dim=-1)

        reference_points = torch.cat([
            torch.stack([hw_hw, hw_zh, hw_wz], dim=1),
            torch.stack([zh_hw, zh_zh, zh_wz], dim=1),
            torch.stack([wz_hw, wz_zh, wz_wz], dim=1)
        ], dim=0)

        return reference_points

    @staticmethod
    def get_reference_points(H,
                             W,
                             Z=8,
                             num_points_in_pillar=4,
                             dim='3d',
                             bs=1,
                             device='cuda',
                             dtype=torch.float):
        """Get the 3D reference points used in image cross-attention.

        Args:
            H, W: spatial shape of the TPV plane.
            Z: height of the pillar.
            num_points_in_pillar: number of points sampled along each pillar.
            dim: reference point type; only '3d' is implemented here.
            device (obj:`device`): The device where
                reference_points should be.

        Returns:
            Tensor: reference points with normalized (x, y, z) coordinates,
                of shape (bs, num_points_in_pillar, H * W, 3).
        """
        # sample num_points_in_pillar depths per (h, w) cell; all
        # coordinates are normalized to [0, 1]
        zs = torch.linspace(
            0.5, Z - 0.5, num_points_in_pillar,
            dtype=dtype, device=device).view(-1, 1, 1).expand(
                num_points_in_pillar, H, W) / Z
        xs = torch.linspace(
            0.5, W - 0.5, W, dtype=dtype, device=device).view(1, 1, -1).expand(
                num_points_in_pillar, H, W) / W
        ys = torch.linspace(
            0.5, H - 0.5, H, dtype=dtype, device=device).view(1, -1, 1).expand(
                num_points_in_pillar, H, W) / H
        ref_3d = torch.stack((xs, ys, zs), -1)
        # (num_points_in_pillar, H, W, 3) -> (num_points_in_pillar, H*W, 3)
        ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1)
        ref_3d = ref_3d[None].repeat(bs, 1, 1, 1)
        return ref_3d

    def point_sampling(self, reference_points, pc_range, batch_data_samples):
        # collect the lidar-to-image projection matrices from the samples
        lidar2img = []
        for data_sample in batch_data_samples:
            lidar2img.append(data_sample.lidar2img)
        lidar2img = np.asarray(lidar2img)
        lidar2img = reference_points.new_tensor(lidar2img)
        reference_points = reference_points.clone()

        # de-normalize from [0, 1] to metric lidar coordinates
        reference_points[..., 0:1] = reference_points[..., 0:1] * \
            (pc_range[3] - pc_range[0]) + pc_range[0]
        reference_points[..., 1:2] = reference_points[..., 1:2] * \
            (pc_range[4] - pc_range[1]) + pc_range[1]
        reference_points[..., 2:3] = reference_points[..., 2:3] * \
            (pc_range[5] - pc_range[2]) + pc_range[2]

        # homogeneous coordinates for the 4x4 projection
        reference_points = torch.cat(
            (reference_points, torch.ones_like(reference_points[..., :1])), -1)

        reference_points = reference_points.permute(1, 0, 2, 3)
        D, B, num_query = reference_points.size()[:3]
        num_cam = lidar2img.size(1)

        reference_points = reference_points.view(D, B, 1, num_query, 4).repeat(
            1, 1, num_cam, 1, 1).unsqueeze(-1)

        lidar2img = lidar2img.view(1, B, num_cam, 1, 4,
                                   4).repeat(D, 1, 1, num_query, 1, 1)

        # project the reference points onto each camera's image plane
        reference_points_cam = torch.matmul(
            lidar2img.to(torch.float32),
            reference_points.to(torch.float32)).squeeze(-1)
        eps = 1e-5

        # keep only points that lie in front of the camera
        tpv_mask = (reference_points_cam[..., 2:3] > eps)
        reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
            reference_points_cam[..., 2:3],
            torch.ones_like(reference_points_cam[..., 2:3]) * eps)

        # normalize pixel coordinates to [0, 1]; batch_input_shape is the
        # padded input shape shared by all samples in the batch
        reference_points_cam[..., 0] /= data_sample.batch_input_shape[1]
        reference_points_cam[..., 1] /= data_sample.batch_input_shape[0]

        # mask out points that fall outside the image
        tpv_mask = (
            tpv_mask & (reference_points_cam[..., 1:2] > 0.0)
            & (reference_points_cam[..., 1:2] < 1.0)
            & (reference_points_cam[..., 0:1] < 1.0)
            & (reference_points_cam[..., 0:1] > 0.0))

        tpv_mask = torch.nan_to_num(tpv_mask)

        # (num_cam, B, num_query, D, 2) and (num_cam, B, num_query, D)
        reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4)
        tpv_mask = tpv_mask.permute(2, 1, 3, 0, 4).squeeze(-1)

        return reference_points_cam, tpv_mask

    def forward(self, mlvl_feats, batch_data_samples):
        """Forward function.

        Args:
            mlvl_feats (tuple[Tensor]): Features from the upstream
                network, each is a 5D-tensor with shape
                (B, N, C, H, W).
        """
        bs = mlvl_feats[0].shape[0]
        dtype = mlvl_feats[0].dtype
        device = mlvl_feats[0].device

        # broadcast the per-plane query embeddings over the batch
        tpv_queries_hw = self.tpv_embedding_hw.weight.to(dtype)
        tpv_queries_zh = self.tpv_embedding_zh.weight.to(dtype)
        tpv_queries_wz = self.tpv_embedding_wz.weight.to(dtype)
        tpv_queries_hw = tpv_queries_hw.unsqueeze(0).repeat(bs, 1, 1)
        tpv_queries_zh = tpv_queries_zh.unsqueeze(0).repeat(bs, 1, 1)
        tpv_queries_wz = tpv_queries_wz.unsqueeze(0).repeat(bs, 1, 1)
        tpv_query = [tpv_queries_hw, tpv_queries_zh, tpv_queries_wz]

        # one positional encoding per plane, keyed by the axis the plane
        # collapses
        tpv_pos_hw = self.positional_encoding(bs, device, 'z')
        tpv_pos_zh = self.positional_encoding(bs, device, 'w')
        tpv_pos_wz = self.positional_encoding(bs, device, 'h')
        tpv_pos = [tpv_pos_hw, tpv_pos_zh, tpv_pos_wz]

        # flatten the multi-level, multi-camera features and tag them with
        # camera and level embeddings
        feat_flatten = []
        spatial_shapes = []
        for lvl, feat in enumerate(mlvl_feats):
            bs, num_cam, c, h, w = feat.shape
            spatial_shape = (h, w)
            feat = feat.flatten(3).permute(1, 0, 3, 2)  # (num_cam, bs, h*w, c)
            feat = feat + self.cams_embeds[:, None, None, :].to(dtype)
            feat = feat + self.level_embeds[None, None,
                                            lvl:lvl + 1, :].to(dtype)
            spatial_shapes.append(spatial_shape)
            feat_flatten.append(feat)

        feat_flatten = torch.cat(feat_flatten, 2)
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=device)
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
        # (num_cam, total_hw, bs, embed_dims)
        feat_flatten = feat_flatten.permute(0, 2, 1, 3)

        # project each plane's 3D reference points into the cameras
        reference_points_cams, tpv_masks = [], []
        ref_3ds = [self.ref_3d_hw, self.ref_3d_zh, self.ref_3d_wz]
        for ref_3d in ref_3ds:
            reference_points_cam, tpv_mask = self.point_sampling(
                ref_3d, self.pc_range, batch_data_samples)
            reference_points_cams.append(reference_points_cam)
            tpv_masks.append(tpv_mask)

        ref_cross_view = self.cross_view_ref_points.clone().unsqueeze(
            0).expand(bs, -1, -1, -1, -1)

        # run the transformer layers; each layer consumes and produces the
        # list of three plane queries
        intermediate = []
        for layer in self.layers:
            output = layer(
                tpv_query,
                feat_flatten,
                feat_flatten,
                tpv_pos=tpv_pos,
                ref_2d=ref_cross_view,
                tpv_h=self.tpv_h,
                tpv_w=self.tpv_w,
                tpv_z=self.tpv_z,
                spatial_shapes=spatial_shapes,
                level_start_index=level_start_index,
                reference_points_cams=reference_points_cams,
                tpv_masks=tpv_masks)
            tpv_query = output
            if self.return_intermediate:
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)
        return output
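

# A minimal usage sketch (hypothetical: layer_cfg and pos_cfg stand in for
# the transformer-layer and positional-encoding dicts from the TPVFormer
# project configs):
#
#   encoder = MODELS.build(
#       dict(type='TPVFormerEncoder',
#            transformerlayers=layer_cfg,
#            positional_encoding=pos_cfg))
#   tpv_feats = encoder(mlvl_feats, batch_data_samples)
#   # tpv_feats is the list of three refined plane features [hw, zh, wz]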