|
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks.transformer import (TransformerLayerSequence,
                                         build_transformer_layer_sequence)
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention
from mmengine.model import BaseModule, constant_init, xavier_init

from mmdet3d.registry import MODELS
|
|
def inverse_sigmoid(x, eps=1e-5):
    """Inverse function of sigmoid.

    Args:
        x (Tensor): The tensor to apply the inverse to.
        eps (float): Small value to avoid numerical overflow.
            Defaults to 1e-5.

    Returns:
        Tensor: The inverse sigmoid of the input, with the same
            shape as the input.
    """
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)
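
# Quick numerical check of the pair (a minimal sketch; any shape works):
#     p = torch.rand(4, 3)
#     torch.allclose(inverse_sigmoid(p).sigmoid(), p, atol=1e-4)  # -> True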
|
|
|
|
|
@MODELS.register_module()
class Detr3DTransformer(BaseModule):
    """Implements the DETR3D transformer.

    Args:
        num_feature_levels (int): Number of feature maps from FPN.
            Defaults to 4.
        num_cams (int): Number of cameras in the dataset.
            Defaults to 6 (as in nuScenes).
        two_stage_num_proposals (int): Number of proposals when
            `as_two_stage` is True. Defaults to 300.
        decoder (dict): Config of the transformer decoder.
            Defaults to None.
    """

    def __init__(self,
                 num_feature_levels=4,
                 num_cams=6,
                 two_stage_num_proposals=300,
                 decoder=None,
                 **kwargs):
        super(Detr3DTransformer, self).__init__(**kwargs)
        self.decoder = build_transformer_layer_sequence(decoder)
        self.embed_dims = self.decoder.embed_dims
        self.num_feature_levels = num_feature_levels
        self.num_cams = num_cams
        self.two_stage_num_proposals = two_stage_num_proposals
        self.init_layers()
|
|
|
    def init_layers(self):
        """Initialize layers of the Detr3DTransformer."""
        self.reference_points = nn.Linear(self.embed_dims, 3)
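        # After a sigmoid in `forward`, this layer turns each query
        # embedding into a normalized (x, y, z) reference point.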
|
|
|
    def init_weights(self):
        """Initialize the transformer weights."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MultiScaleDeformableAttention) or isinstance(
                    m, Detr3DCrossAtten):
                m.init_weight()
        xavier_init(self.reference_points, distribution='uniform', bias=0.)
|
|
|
    def forward(self, mlvl_feats, query_embed, reg_branches=None, **kwargs):
        """Forward function for `Detr3DTransformer`.

        Args:
            mlvl_feats (list(Tensor)): Input features from different
                levels. Each element has shape (B, N, C, H_lvl, W_lvl).
            query_embed (Tensor): The query positional and semantic
                embedding for the decoder, with shape
                (num_query, 2 * embed_dims).
            mlvl_pos_embeds (list(Tensor)): The positional encodings of
                the features from different levels, each with shape
                (bs, N, embed_dims, h, w). Unused here.
            reg_branches (obj:`nn.ModuleList`): Regression heads for
                feature maps from each decoder layer. Only passed when
                `with_box_refine` is True. Defaults to None.

        Returns:
            tuple[Tensor]: Results of the decoder containing the
                following tensors.

                - inter_states: Outputs from the decoder. If
                  `return_intermediate_dec` is True, it has shape
                  (num_dec_layers, bs, num_query, embed_dims), else
                  (1, bs, num_query, embed_dims).
                - init_reference_out: The initial reference points,
                  with shape (bs, num_query, 3).
                - inter_references_out: The intermediate reference
                  points from the decoder, with shape
                  (num_dec_layers, bs, num_query, 3).
        """
        assert query_embed is not None
        bs = mlvl_feats[0].size(0)
        query_pos, query = torch.split(query_embed, self.embed_dims, dim=1)
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
        query = query.unsqueeze(0).expand(bs, -1, -1)
        reference_points = self.reference_points(query_pos)
        reference_points = reference_points.sigmoid()
        init_reference_out = reference_points

        query = query.permute(1, 0, 2)
        query_pos = query_pos.permute(1, 0, 2)
        inter_states, inter_references = self.decoder(
            query=query,
            key=None,
            value=mlvl_feats,
            query_pos=query_pos,
            reference_points=reference_points,
            reg_branches=reg_branches,
            **kwargs)

        inter_references_out = inter_references
        return inter_states, init_reference_out, inter_references_out
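
# Illustrative construction (a sketch, not a verified config from this
# project; the inner `transformerlayers` dict is deliberately elided):
#     transformer = Detr3DTransformer(
#         num_cams=6,
#         decoder=dict(
#             type='Detr3DTransformerDecoder',
#             num_layers=6,
#             return_intermediate=True,
#             transformerlayers=dict(...)))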
|
|
|
|
|
@MODELS.register_module()
class Detr3DTransformerDecoder(TransformerLayerSequence):
    """Implements the decoder in DETR3D transformer.

    Args:
        return_intermediate (bool): Whether to return intermediate
            outputs. Defaults to False.
    """

    def __init__(self, *args, return_intermediate=False, **kwargs):
        super(Detr3DTransformerDecoder, self).__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate
|
|
|
    def forward(self,
                query,
                *args,
                reference_points=None,
                reg_branches=None,
                **kwargs):
        """Forward function for `Detr3DTransformerDecoder`.

        Args:
            query (Tensor): Input query with shape
                (num_query, bs, embed_dims).
            reference_points (Tensor): The reference points to refine,
                with shape (bs, num_query, 4) when `as_two_stage` is
                True, otherwise (bs, num_query, 3).
            reg_branches (obj:`nn.ModuleList`): Used for refining the
                regression results. Only passed when `with_box_refine`
                is True, otherwise `None`.

        Returns:
            Tensor: Results with shape (num_query, bs, embed_dims) when
                `return_intermediate` is False, otherwise
                (num_layers, num_query, bs, embed_dims).
        """
        output = query
        intermediate = []
        intermediate_reference_points = []
        for lid, layer in enumerate(self.layers):
            reference_points_input = reference_points
            output = layer(
                output,
                *args,
                reference_points=reference_points_input,
                **kwargs)
            output = output.permute(1, 0, 2)
            if reg_branches is not None:
                tmp = reg_branches[lid](output)

                assert reference_points.shape[-1] == 3

                new_reference_points = torch.zeros_like(reference_points)
                new_reference_points[..., :2] = tmp[
                    ..., :2] + inverse_sigmoid(reference_points[..., :2])
                new_reference_points[..., 2:3] = tmp[
                    ..., 4:5] + inverse_sigmoid(reference_points[..., 2:3])
                new_reference_points = new_reference_points.sigmoid()

                reference_points = new_reference_points.detach()

            output = output.permute(1, 0, 2)
            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(
                intermediate_reference_points)

        return output, reference_points
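
# Note on the refinement step in the loop above (restating the code, no new
# behavior): with `ref` the incoming reference point and `tmp` the raw
# regression output, each layer computes
#     new_ref = sigmoid(tmp_offset + inverse_sigmoid(ref))
# i.e. the offset is applied in unbounded logit space and squashed back into
# [0, 1], which keeps the refined point inside the normalized range. The z
# offset is read from tmp[..., 4:5], matching a regression head that places
# cz at channel 4 of its output.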
|
|
|
|
|
@MODELS.register_module()
class Detr3DCrossAtten(BaseModule):
    """An attention module used in DETR3D.

    Args:
        embed_dims (int): The embedding dimension of attention.
            Defaults to 256.
        num_heads (int): Parallel attention heads. Defaults to 8.
        num_levels (int): The number of feature maps used in
            attention. Defaults to 4.
        num_points (int): The number of sampling points for
            each query in each head. Defaults to 5.
        num_cams (int): The number of cameras. Defaults to 6.
        im2col_step (int): The step used in image_to_column.
            Defaults to 64.
        pc_range (list): Perception range of the detector.
            Defaults to None.
        dropout (float): A Dropout layer on `inp_residual`.
            Defaults to 0.1.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Defaults to None.
    """
|
|
|
    def __init__(
        self,
        embed_dims=256,
        num_heads=8,
        num_levels=4,
        num_points=5,
        num_cams=6,
        im2col_step=64,
        pc_range=None,
        dropout=0.1,
        norm_cfg=None,
        init_cfg=None,
        batch_first=False,
    ):
        super(Detr3DCrossAtten, self).__init__(init_cfg)
        if embed_dims % num_heads != 0:
            raise ValueError(f'embed_dims must be divisible by num_heads, '
                             f'but got {embed_dims} and {num_heads}')
        dim_per_head = embed_dims // num_heads
        self.norm_cfg = norm_cfg
        self.init_cfg = init_cfg
        self.dropout = nn.Dropout(dropout)
        self.pc_range = pc_range

        def _is_power_of_2(n):
            if (not isinstance(n, int)) or (n < 0):
                raise ValueError(
                    'invalid input for _is_power_of_2: {} (type: {})'.format(
                        n, type(n)))
            return (n & (n - 1) == 0) and n != 0

        if not _is_power_of_2(dim_per_head):
            warnings.warn(
                "You'd better set embed_dims in "
                'Detr3DCrossAtten to make '
                'the dimension of each attention head a power of 2, '
                'which is more efficient in our CUDA implementation.')

        self.im2col_step = im2col_step
        self.embed_dims = embed_dims
        self.num_levels = num_levels
        self.num_heads = num_heads
        self.num_points = num_points
        self.num_cams = num_cams
        self.attention_weights = nn.Linear(embed_dims,
                                           num_cams * num_levels * num_points)
        self.output_proj = nn.Linear(embed_dims, embed_dims)

        self.position_encoder = nn.Sequential(
            nn.Linear(3, self.embed_dims),
            nn.LayerNorm(self.embed_dims),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims, self.embed_dims),
            nn.LayerNorm(self.embed_dims),
            nn.ReLU(inplace=True),
        )
        self.batch_first = batch_first
        self.init_weight()
|
|
|
    def init_weight(self):
        """Default initialization for Parameters of Module."""
        constant_init(self.attention_weights, val=0., bias=0.)
        xavier_init(self.output_proj, distribution='uniform', bias=0.)
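        # Zero-init of `attention_weights` makes every sampled feature
        # start with equal (sigmoid(0) = 0.5) weight before training.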
|
|
|
    def forward(self,
                query,
                key,
                value,
                residual=None,
                query_pos=None,
                reference_points=None,
                **kwargs):
        """Forward function of Detr3DCrossAtten.

        Args:
            query (Tensor): Query of Transformer with shape
                (num_query, bs, embed_dims).
            key (Tensor): The key tensor with shape
                (num_key, bs, embed_dims).
            value (List[Tensor]): Image features from
                different levels. Each element has shape
                (B, N, C, H_lvl, W_lvl).
            residual (Tensor): The tensor used for addition, with the
                same shape as `query`. Defaults to None. If None,
                `query` will be used.
            query_pos (Tensor): The positional encoding for `query`.
                Defaults to None.
            reference_points (Tensor): The normalized 3D reference
                points with shape (bs, num_query, 3).

        Returns:
            Tensor: Forwarded results with shape
                (num_query, bs, embed_dims).
        """
        if key is None:
            key = query
        if value is None:
            value = key

        if residual is None:
            inp_residual = query
        else:
            inp_residual = residual
        if query_pos is not None:
            query = query + query_pos

        query = query.permute(1, 0, 2)

        bs, num_query, _ = query.size()

        attention_weights = self.attention_weights(query).view(
            bs, 1, num_query, self.num_cams, self.num_points, self.num_levels)
        reference_points_3d, output, mask = feature_sampling(
            value, reference_points, self.pc_range, kwargs['img_metas'])
        output = torch.nan_to_num(output)
        mask = torch.nan_to_num(mask)
        attention_weights = attention_weights.sigmoid() * mask
        output = output * attention_weights
        output = output.sum(-1).sum(-1).sum(-1)
        output = output.permute(2, 0, 1)

        output = self.output_proj(output)
        pos_feat = self.position_encoder(
            inverse_sigmoid(reference_points_3d)).permute(1, 0, 2)
        return self.dropout(output) + inp_residual + pos_feat
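
# Shape walkthrough for Detr3DCrossAtten.forward (B=batch, Q=num_query,
# N=num_cams, P=num_points, L=num_levels, C=embed_dims):
#     attention_weights: (B, 1, Q, N, P, L)
#     sampled features:  (B, C, Q, N, 1, L)   from feature_sampling
#     weighted product:  (B, C, Q, N, P, L) -> sum over L, P, N -> (B, C, Q)
#     final permute:     (Q, B, C), matching the residual query layout.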
|
|
|
|
|
def feature_sampling(mlvl_feats,
                     ref_pt,
                     pc_range,
                     img_metas,
                     no_sampling=False):
    """Sample multi-level features by projecting 3D reference points
    onto 2D images.

    Args:
        mlvl_feats (List[Tensor]): Image features from
            different levels. Each element has shape
            (B, N, C, H_lvl, W_lvl).
        ref_pt (Tensor): The normalized 3D reference
            points with shape (bs, num_query, 3).
        pc_range (list): Perception range of the detector.
        img_metas (list[dict]): Meta information of multiple inputs
            in a batch, containing `lidar2img`.
        no_sampling (bool): If set to True, the function returns the
            2D projected points and mask only.

    Returns:
        ref_pt_3d (Tensor): A copy of the original ref_pt.
        sampled_feats (Tensor): Sampled features with shape
            (B, C, num_q, N, 1, fpn_lvl).
        mask (Tensor): Marks whether each reference point projects
            outside of the images, with shape (B, 1, num_q, N, 1, 1).
    """
|
    lidar2img = [meta['lidar2img'] for meta in img_metas]
    lidar2img = np.asarray(lidar2img)
    lidar2img = ref_pt.new_tensor(lidar2img)
    ref_pt = ref_pt.clone()
    ref_pt_3d = ref_pt.clone()

    B, num_query = ref_pt.size()[:2]
    num_cam = lidar2img.size(1)
    eps = 1e-5

    ref_pt[..., 0:1] = \
        ref_pt[..., 0:1] * (pc_range[3] - pc_range[0]) + pc_range[0]
    ref_pt[..., 1:2] = \
        ref_pt[..., 1:2] * (pc_range[4] - pc_range[1]) + pc_range[1]
    ref_pt[..., 2:3] = \
        ref_pt[..., 2:3] * (pc_range[5] - pc_range[2]) + pc_range[2]

    ref_pt = torch.cat((ref_pt, torch.ones_like(ref_pt[..., :1])), -1)
    ref_pt = ref_pt.view(B, 1, num_query, 4)
    ref_pt = ref_pt.repeat(1, num_cam, 1, 1).unsqueeze(-1)

    lidar2img = lidar2img.view(B, num_cam, 1, 4, 4)\
        .repeat(1, 1, num_query, 1, 1)

    pt_cam = torch.matmul(lidar2img, ref_pt).squeeze(-1)

    z = pt_cam[..., 2:3]
    eps = eps * torch.ones_like(z)
    mask = (z > eps)
    pt_cam = pt_cam[..., 0:2] / torch.maximum(z, eps)

    (h, w) = img_metas[0]['pad_shape']
    pt_cam[..., 0] /= w
    pt_cam[..., 1] /= h

    mask = (
        mask & (pt_cam[..., 0:1] > 0.0)
        & (pt_cam[..., 0:1] < 1.0)
        & (pt_cam[..., 1:2] > 0.0)
        & (pt_cam[..., 1:2] < 1.0))
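    # At this point `pt_cam` holds image-plane coordinates normalized to
    # [0, 1] by the padded image size, and `mask` keeps only points that
    # project inside an image and in front of a camera (z > eps).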
|
|
|
    if no_sampling:
        return pt_cam, mask

    mask = mask.view(B, num_cam, 1, num_query, 1, 1).permute(0, 2, 3, 1, 4, 5)
    mask = torch.nan_to_num(mask)

    pt_cam = (pt_cam - 0.5) * 2
    sampled_feats = []
    for lvl, feat in enumerate(mlvl_feats):
        B, N, C, H, W = feat.size()
        feat = feat.view(B * N, C, H, W)
        pt_cam_lvl = pt_cam.view(B * N, num_query, 1, 2)
        sampled_feat = F.grid_sample(feat, pt_cam_lvl)

        sampled_feat = sampled_feat.view(B, N, C, num_query, 1)
        sampled_feat = sampled_feat.permute(0, 2, 3, 1, 4)
        sampled_feats.append(sampled_feat)

    sampled_feats = torch.stack(sampled_feats, -1)
    sampled_feats = \
        sampled_feats.view(B, C, num_query, num_cam, 1, len(mlvl_feats))
    return ref_pt_3d, sampled_feats, mask
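
# Illustrative call (hypothetical shapes; `img_metas` must carry per-sample
# 'lidar2img' matrices and a 'pad_shape' of (h, w)):
#     feats = [torch.rand(1, 6, 256, 232 // 2**lvl, 400 // 2**lvl)
#              for lvl in range(4)]
#     ref = torch.rand(1, 900, 3)
#     pc_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
#     ref3d, sampled, mask = feature_sampling(feats, ref, pc_range, img_metas)
#     # sampled: (1, 256, 900, 6, 1, 4); mask: (1, 1, 900, 6, 1, 1)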
|
|