# Modified from https://github.com/TuSimple/centerformer/blob/master/det3d/models/necks/rpn_transformer.py # noqa
from typing import List, Tuple
import numpy as np
import torch
from mmcv.cnn import build_norm_layer
from mmdet.models.utils import multi_apply
from mmengine.logging import print_log
from mmengine.structures import InstanceData
from torch import Tensor, nn
from mmdet3d.models.utils import draw_heatmap_gaussian, gaussian_radius
from mmdet3d.registry import MODELS
from mmdet3d.structures import center_to_corner_box2d
from .transformer import DeformableTransformerDecoder
class ChannelAttention(nn.Module):
def __init__(self, in_planes, ratio=16):
super(ChannelAttention, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False),
)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = self.fc(self.avg_pool(x))
max_out = self.fc(self.max_pool(x))
out = avg_out + max_out
return self.sigmoid(out) * x
class SpatialAttention(nn.Module):
def __init__(self, kernel_size=7):
super(SpatialAttention, self).__init__()
self.conv1 = nn.Conv2d(
2, 1, kernel_size, padding=kernel_size // 2, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
y = torch.cat([avg_out, max_out], dim=1)
y = self.conv1(y)
return self.sigmoid(y) * x
class MultiFrameSpatialAttention(nn.Module):
def __init__(self, kernel_size=7):
super(MultiFrameSpatialAttention, self).__init__()
self.conv1 = nn.Conv2d(
2, 1, kernel_size, padding=kernel_size // 2, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, curr, prev):
avg_out = torch.mean(curr, dim=1, keepdim=True)
max_out, _ = torch.max(curr, dim=1, keepdim=True)
y = torch.cat([avg_out, max_out], dim=1)
y = self.conv1(y)
return self.sigmoid(y) * prev
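# A minimal, self-contained usage sketch (not part of the upstream file)
# showing how the attention modules above are applied to BEV feature maps.
# All shapes here are illustrative assumptions, not values from a real config.
def _demo_attention_modules():
    bev = torch.rand(2, 64, 128, 128)  # [B, C, H, W] current-frame feature
    prev = torch.rand(2, 64, 128, 128)  # previous-frame feature, same shape
    # Channel attention reweights channels, spatial attention reweights pixels.
    out = SpatialAttention()(ChannelAttention(64)(bev))
    # MultiFrameSpatialAttention scores pixels on the current frame and uses
    # those scores to gate the previous-frame feature.
    prev_gated = MultiFrameSpatialAttention()(bev, prev)
    assert out.shape == prev_gated.shape == bev.shape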
class BaseDecoderRPN(nn.Module):
def __init__(
self,
layer_nums, # [2,2,2]
ds_num_filters, # [128,256,64]
num_input_features, # 256
transformer_config=None,
hm_head_layer=2,
corner_head_layer=2,
corner=False,
assign_label_window_size=1,
classes=3,
use_gt_training=False,
norm_cfg=None,
logger=None,
init_bias=-2.19,
score_threshold=0.1,
obj_num=500,
**kwargs):
super(BaseDecoderRPN, self).__init__()
self._layer_strides = [1, 2, -4]
self._num_filters = ds_num_filters
self._layer_nums = layer_nums
self._num_input_features = num_input_features
self.score_threshold = score_threshold
self.transformer_config = transformer_config
self.corner = corner
self.obj_num = obj_num
self.use_gt_training = use_gt_training
self.window_size = assign_label_window_size**2
self.cross_attention_kernel_size = [3, 3, 3]
self.batch_id = None
if norm_cfg is None:
norm_cfg = dict(type='BN', eps=1e-3, momentum=0.01)
self._norm_cfg = norm_cfg
assert len(self._layer_strides) == len(self._layer_nums)
assert len(self._num_filters) == len(self._layer_nums)
assert self.transformer_config is not None
in_filters = [
self._num_input_features,
self._num_filters[0],
self._num_filters[1],
]
blocks = []
for i, layer_num in enumerate(self._layer_nums):
block, num_out_filters = self._make_layer(
in_filters[i],
self._num_filters[i],
layer_num,
stride=self._layer_strides[i],
)
blocks.append(block)
self.blocks = nn.ModuleList(blocks)
self.up = nn.Sequential(
nn.ConvTranspose2d(
self._num_filters[0],
self._num_filters[2],
2,
stride=2,
bias=False),
build_norm_layer(self._norm_cfg, self._num_filters[2])[1],
nn.ReLU())
# heatmap prediction
hm_head = []
for i in range(hm_head_layer - 1):
hm_head.append(
nn.Conv2d(
self._num_filters[-1] * 2,
64,
kernel_size=3,
stride=1,
padding=1,
bias=True,
))
hm_head.append(build_norm_layer(self._norm_cfg, 64)[1])
hm_head.append(nn.ReLU())
hm_head.append(
nn.Conv2d(
64, classes, kernel_size=3, stride=1, padding=1, bias=True))
hm_head[-1].bias.data.fill_(init_bias)
self.hm_head = nn.Sequential(*hm_head)
if self.corner:
self.corner_head = []
for i in range(corner_head_layer - 1):
self.corner_head.append(
nn.Conv2d(
self._num_filters[-1] * 2,
64,
kernel_size=3,
stride=1,
padding=1,
bias=True,
))
self.corner_head.append(
build_norm_layer(self._norm_cfg, 64)[1])
self.corner_head.append(nn.ReLU())
self.corner_head.append(
nn.Conv2d(
64, 1, kernel_size=3, stride=1, padding=1, bias=True))
self.corner_head[-1].bias.data.fill_(init_bias)
self.corner_head = nn.Sequential(*self.corner_head)
def _make_layer(self, inplanes, planes, num_blocks, stride=1):
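        # A positive stride downsamples with a stride-`stride` 3x3 conv, while
        # a negative stride -s upsamples with a stride-s transposed conv,
        # matching the [1, 2, -4] strides hard-coded in `__init__`.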
if stride > 0:
block = [
nn.ZeroPad2d(1),
nn.Conv2d(inplanes, planes, 3, stride=stride, bias=False),
build_norm_layer(self._norm_cfg, planes)[1],
nn.ReLU(),
]
else:
block = [
nn.ConvTranspose2d(
inplanes, planes, -stride, stride=-stride, bias=False),
build_norm_layer(self._norm_cfg, planes)[1],
nn.ReLU(),
]
for j in range(num_blocks):
block.append(nn.Conv2d(planes, planes, 3, padding=1, bias=False))
            block.append(build_norm_layer(self._norm_cfg, planes)[1])
block.append(nn.ReLU())
block.append(ChannelAttention(planes))
block.append(SpatialAttention())
block = nn.Sequential(*block)
return block, planes
def forward(self, x, example=None):
pass
def get_multi_scale_feature(self, center_pos, feats):
"""
Args:
center_pos: center coor at the lowest scale feature map [B 500 2]
feats: multi scale BEV feature 3*[B C H W]
Returns:
neighbor_feat: [B 500 K C]
neighbor_pos: [B 500 K 2]
"""
kernel_size = self.cross_attention_kernel_size
batch, num_cls, H, W = feats[0].size()
center_num = center_pos.shape[1]
relative_pos_list = []
neighbor_feat_list = []
for i, k in enumerate(kernel_size):
neighbor_coords = torch.arange(-(k // 2), (k // 2) + 1)
neighbor_coords = torch.flatten(
torch.stack(
torch.meshgrid([neighbor_coords, neighbor_coords]), dim=0),
1,
) # [2, k]
neighbor_coords = (neighbor_coords.permute(
1,
0).contiguous().to(center_pos)) # relative coordinate [k, 2]
neighbor_coords = (center_pos[:, :, None, :] // (2**i) +
neighbor_coords[None, None, :, :]
) # coordinates [B, 500, k, 2]
neighbor_coords = torch.clamp(
neighbor_coords, min=0,
max=H // (2**i) - 1) # prevent out of bound
feat_id = (neighbor_coords[:, :, :, 1] * (W // (2**i)) +
neighbor_coords[:, :, :, 0]) # pixel id [B, 500, k]
feat_id = feat_id.reshape(batch, -1) # pixel id [B, 500*k]
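            # `self.batch_id` (filled in `forward`) holds the batch index of
            # every query; repeating it k**2 times pairs each flattened pixel
            # id with its batch so advanced indexing gathers all neighbors of
            # all centers in one shot.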
selected_feat = (
feats[i].reshape(batch, num_cls, (H * W) // (4**i)).permute(
0, 2, 1).contiguous()[self.batch_id.repeat(1, k**2),
feat_id]) # B, 500*k, C
neighbor_feat_list.append(
selected_feat.reshape(batch, center_num, -1,
num_cls)) # B, 500, k, C
relative_pos_list.append(neighbor_coords * (2**i)) # B, 500, k, 2
neighbor_pos = torch.cat(relative_pos_list, dim=2) # B, 500, K, 2/3
neighbor_feats = torch.cat(neighbor_feat_list, dim=2) # B, 500, K, C
return neighbor_feats, neighbor_pos
def get_multi_scale_feature_multiframe(self, center_pos, feats, timeframe):
"""
Args:
center_pos: center coor at the lowest scale feature map [B 500 2]
feats: multi scale BEV feature (3+k)*[B C H W]
timeframe: timeframe [B,k]
Returns:
neighbor_feat: [B 500 K C]
neighbor_pos: [B 500 K 2]
neighbor_time: [B 500 K 1]
"""
kernel_size = self.cross_attention_kernel_size
batch, num_cls, H, W = feats[0].size()
center_num = center_pos.shape[1]
relative_pos_list = []
neighbor_feat_list = []
timeframe_list = []
for i, k in enumerate(kernel_size):
neighbor_coords = torch.arange(-(k // 2), (k // 2) + 1)
neighbor_coords = torch.flatten(
torch.stack(
torch.meshgrid([neighbor_coords, neighbor_coords]), dim=0),
1,
) # [2, k]
neighbor_coords = (neighbor_coords.permute(
1,
0).contiguous().to(center_pos)) # relative coordinate [k, 2]
neighbor_coords = (center_pos[:, :, None, :] // (2**i) +
neighbor_coords[None, None, :, :]
) # coordinates [B, 500, k, 2]
neighbor_coords = torch.clamp(
neighbor_coords, min=0,
max=H // (2**i) - 1) # prevent out of bound
feat_id = (neighbor_coords[:, :, :, 1] * (W // (2**i)) +
neighbor_coords[:, :, :, 0]) # pixel id [B, 500, k]
feat_id = feat_id.reshape(batch, -1) # pixel id [B, 500*k]
selected_feat = (
feats[i].reshape(batch, num_cls, (H * W) // (4**i)).permute(
0, 2, 1).contiguous()[self.batch_id.repeat(1, k**2),
feat_id]) # B, 500*k, C
neighbor_feat_list.append(
selected_feat.reshape(batch, center_num, -1,
num_cls)) # B, 500, k, C
relative_pos_list.append(neighbor_coords * (2**i)) # B, 500, k, 2
            timeframe_list.append(
                torch.full_like(neighbor_coords[:, :, :, 0:1],
                                0))  # B, 500, k, 1
if i == 0:
# add previous frame feature
for frame_num in range(feats[-1].shape[1]):
selected_feat = (feats[-1][:, frame_num, :, :, :].reshape(
batch, num_cls, (H * W) // (4**i)).permute(
0, 2,
1).contiguous()[self.batch_id.repeat(1, k**2),
feat_id]) # B, 500*k, C
neighbor_feat_list.append(
selected_feat.reshape(batch, center_num, -1, num_cls))
relative_pos_list.append(neighbor_coords * (2**i))
time = timeframe[:, frame_num + 1].to(selected_feat) # B
                    timeframe_list.append(
                        time[:, None, None, None] * torch.full_like(
                            neighbor_coords[:, :, :, 0:1], 1))  # B, 500, k, 1
neighbor_pos = torch.cat(relative_pos_list, dim=2) # B, 500, K, 2/3
neighbor_feats = torch.cat(neighbor_feat_list, dim=2) # B, 500, K, C
neighbor_time = torch.cat(timeframe_list, dim=2) # B, 500, K, 1
return neighbor_feats, neighbor_pos, neighbor_time
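# A small, hypothetical sketch (not in the upstream file) of the gather trick
# used by the two methods above: flatten (x, y) pixel coordinates into
# row-major ids y * W + x, then index a flattened [B, H*W, C] feature with a
# (batch_id, pixel_id) pair per query.
def _demo_flat_index_gather():
    B, C, H, W = 2, 8, 4, 4
    feats = torch.rand(B, C, H, W)
    # two centers per sample, stored as (x, y)
    coords = torch.tensor([[[1, 2], [3, 0]]]).repeat(B, 1, 1)  # [B, 2, 2]
    feat_id = coords[..., 1] * W + coords[..., 0]  # [B, 2] row-major ids
    batch_id = torch.arange(B)[:, None].repeat(1, coords.shape[1])  # [B, 2]
    flat = feats.reshape(B, C, H * W).permute(0, 2, 1)  # [B, H*W, C]
    gathered = flat[batch_id, feat_id]  # [B, 2, C]
    assert torch.equal(gathered[0, 0], feats[0, :, 2, 1])  # (x=1, y=2)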
@MODELS.register_module()
class DeformableDecoderRPN(BaseDecoderRPN):
"""The original implement of CenterFormer modules.
It fuse the backbone, neck and heatmap head into one module. The backbone
is `SECOND` with attention and the neck is `SECONDFPN` with attention.
TODO: split this module into backbone、neck and head.
"""
def __init__(self,
layer_nums,
ds_num_filters,
num_input_features,
tasks=dict(),
transformer_config=None,
hm_head_layer=2,
corner_head_layer=2,
corner=False,
parametric_embedding=False,
assign_label_window_size=1,
classes=3,
use_gt_training=False,
norm_cfg=None,
logger=None,
init_bias=-2.19,
score_threshold=0.1,
obj_num=500,
train_cfg=None,
test_cfg=None,
**kwargs):
super(DeformableDecoderRPN, self).__init__(
layer_nums,
ds_num_filters,
num_input_features,
transformer_config,
hm_head_layer,
corner_head_layer,
corner,
assign_label_window_size,
classes,
use_gt_training,
norm_cfg,
logger,
init_bias,
score_threshold,
obj_num,
)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.tasks = tasks
self.class_names = [t['class_names'] for t in tasks]
self.transformer_decoder = DeformableTransformerDecoder(
self._num_filters[-1] * 2,
depth=transformer_config.depth,
n_heads=transformer_config.n_heads,
dim_single_head=transformer_config.dim_single_head,
dim_ffn=transformer_config.dim_ffn,
dropout=transformer_config.dropout,
out_attention=transformer_config.out_attn,
n_points=transformer_config.get('n_points', 9),
)
self.pos_embedding_type = transformer_config.get(
'pos_embedding_type', 'linear')
if self.pos_embedding_type == 'linear':
self.pos_embedding = nn.Linear(2, self._num_filters[-1] * 2)
else:
raise NotImplementedError()
self.parametric_embedding = parametric_embedding
if self.parametric_embedding:
self.query_embed = nn.Embedding(self.obj_num,
self._num_filters[-1] * 2)
nn.init.uniform_(self.query_embed.weight, -1.0, 1.0)
print_log('Finish RPN_transformer_deformable Initialization',
'current')
def _sigmoid(self, x):
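        # Clamp away from exactly 0 and 1 so downstream focal-loss style
        # logarithms stay finite.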
y = torch.clamp(x.sigmoid_(), min=1e-4, max=1 - 1e-4)
return y
def forward(self, x, batch_data_samples):
batch_gt_instance_3d = []
for data_sample in batch_data_samples:
batch_gt_instance_3d.append(data_sample.gt_instances_3d)
# FPN
x = self.blocks[0](x)
x_down = self.blocks[1](x)
x_up = torch.cat([self.blocks[2](x_down), self.up(x)], dim=1)
# heatmap head
hm = self.hm_head(x_up)
if self.corner and self.corner_head.training:
corner_hm = self.corner_head(x_up)
corner_hm = self._sigmoid(corner_hm)
# find top K center location
hm = self._sigmoid(hm)
batch, num_cls, H, W = hm.size()
scores, labels = torch.max(
hm.reshape(batch, num_cls, H * W), dim=1) # b,H*W
self.batch_id = torch.from_numpy(np.indices(
(batch, self.obj_num))[0]).to(labels)
if self.training:
heatmaps, anno_boxes, gt_inds, gt_masks, corner_heatmaps, cat_labels = self.get_targets( # noqa: E501
batch_gt_instance_3d)
batch_targets = dict(
ind=gt_inds,
mask=gt_masks,
hm=heatmaps,
anno_box=anno_boxes,
corners=corner_heatmaps,
cat=cat_labels)
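            # Temporarily boost the scores at ground-truth center indices by
            # their (0/1) masks so every valid GT center survives the
            # top-`obj_num` selection during training, then undo the boost.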
inds = gt_inds[0][:, (self.window_size // 2)::self.window_size]
masks = gt_masks[0][:, (self.window_size // 2)::self.window_size]
batch_id_gt = torch.from_numpy(
np.indices((batch, inds.shape[1]))[0]).to(labels)
scores[batch_id_gt, inds] = scores[batch_id_gt, inds] + masks
order = scores.sort(1, descending=True)[1]
order = order[:, :self.obj_num]
scores[batch_id_gt, inds] = scores[batch_id_gt, inds] - masks
else:
order = scores.sort(1, descending=True)[1]
order = order[:, :self.obj_num]
batch_targets = None
scores = torch.gather(scores, 1, order)
labels = torch.gather(labels, 1, order)
mask = scores > self.score_threshold
ct_feat = x_up.reshape(batch, -1, H * W).transpose(2, 1).contiguous()
ct_feat = ct_feat[self.batch_id, order] # B, 500, C
# create position embedding for each center
y_coor = order // W
x_coor = order - y_coor * W
y_coor, x_coor = y_coor.to(ct_feat), x_coor.to(ct_feat)
y_coor, x_coor = y_coor / H, x_coor / W
pos_features = torch.stack([x_coor, y_coor], dim=2)
if self.parametric_embedding:
ct_feat = self.query_embed.weight
ct_feat = ct_feat.unsqueeze(0).expand(batch, -1, -1)
# run transformer
src = torch.cat(
(
x_up.reshape(batch, -1, H * W).transpose(2, 1).contiguous(),
x.reshape(batch, -1,
(H * W) // 4).transpose(2, 1).contiguous(),
x_down.reshape(batch, -1,
(H * W) // 16).transpose(2, 1).contiguous(),
),
dim=1,
) # B ,sum(H*W), C
spatial_shapes = torch.as_tensor(
[(H, W), (H // 2, W // 2), (H // 4, W // 4)],
dtype=torch.long,
device=ct_feat.device,
)
level_start_index = torch.cat((
spatial_shapes.new_zeros((1, )),
spatial_shapes.prod(1).cumsum(0)[:-1],
))
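        # `spatial_shapes` stores the (H, W) of each pyramid level and
        # `level_start_index` the offset of each level inside the flattened
        # `src` token sequence, as deformable attention expects.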
transformer_out = self.transformer_decoder(
ct_feat,
self.pos_embedding,
src,
spatial_shapes,
level_start_index,
center_pos=pos_features,
) # (B,N,C)
ct_feat = (transformer_out['ct_feat'].transpose(2, 1).contiguous()
) # B, C, 500
out_dict = {
'hm': hm,
'scores': scores,
'labels': labels,
'order': order,
'ct_feat': ct_feat,
'mask': mask,
}
if 'out_attention' in transformer_out:
out_dict.update(
{'out_attention': transformer_out['out_attention']})
if self.corner and self.corner_head.training:
out_dict.update({'corner_hm': corner_hm})
return out_dict, batch_targets
def get_targets(
self,
batch_gt_instances_3d: List[InstanceData],
) -> Tuple[List[Tensor]]:
"""Generate targets. How each output is transformed: Each nested list
is transposed so that all same-index elements in each sub-list (1, ...,
N) become the new sub-lists.
[ [a0, a1, a2, ... ], [b0, b1, b2, ... ], ... ]
==> [ [a0, b0, ... ], [a1, b1, ... ], [a2, b2, ... ] ]
The new transposed nested list is converted into a list of N
tensors generated by concatenating tensors in the new sub-lists.
[ tensor0, tensor1, tensor2, ... ]
Args:
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances. It usually includes ``bboxes_3d`` and
``labels_3d`` attributes.
Returns:
Returns:
tuple[list[torch.Tensor]]: Tuple of target including
the following results in order.
- list[torch.Tensor]: Heatmap scores.
- list[torch.Tensor]: Ground truth boxes.
- list[torch.Tensor]: Indexes indicating the
position of the valid boxes.
- list[torch.Tensor]: Masks indicating which
boxes are valid.
- list[torch.Tensor]: catagrate labels.
"""
heatmaps, anno_boxes, inds, masks, corner_heatmaps, cat_labels = multi_apply( # noqa: E501
self.get_targets_single, batch_gt_instances_3d)
# Transpose heatmaps
heatmaps = list(map(list, zip(*heatmaps)))
heatmaps = [torch.stack(hms_) for hms_ in heatmaps]
        # Transpose corner_heatmaps
corner_heatmaps = list(map(list, zip(*corner_heatmaps)))
corner_heatmaps = [torch.stack(hms_) for hms_ in corner_heatmaps]
# Transpose anno_boxes
anno_boxes = list(map(list, zip(*anno_boxes)))
anno_boxes = [torch.stack(anno_boxes_) for anno_boxes_ in anno_boxes]
# Transpose inds
inds = list(map(list, zip(*inds)))
inds = [torch.stack(inds_) for inds_ in inds]
        # Transpose masks
masks = list(map(list, zip(*masks)))
masks = [torch.stack(masks_) for masks_ in masks]
# Transpose cat_labels
cat_labels = list(map(list, zip(*cat_labels)))
cat_labels = [torch.stack(labels_) for labels_ in cat_labels]
return heatmaps, anno_boxes, inds, masks, corner_heatmaps, cat_labels
def get_targets_single(self,
gt_instances_3d: InstanceData) -> Tuple[Tensor]:
"""Generate training targets for a single sample.
Args:
gt_instances_3d (:obj:`InstanceData`): Gt_instances of
single data sample. It usually includes
``bboxes_3d`` and ``labels_3d`` attributes.
        Returns:
            tuple[list[torch.Tensor]]: Tuple of targets including
                the following results in order.

                - list[torch.Tensor]: Heatmap scores.
                - list[torch.Tensor]: Ground truth boxes.
                - list[torch.Tensor]: Indexes indicating the position
                  of the valid boxes.
                - list[torch.Tensor]: Masks indicating which boxes
                  are valid.
                - list[torch.Tensor]: Corner heatmap scores.
                - list[torch.Tensor]: Category labels.
"""
gt_labels_3d = gt_instances_3d.labels_3d
gt_bboxes_3d = gt_instances_3d.bboxes_3d
device = gt_labels_3d.device
gt_bboxes_3d = torch.cat(
(gt_bboxes_3d.gravity_center, gt_bboxes_3d.tensor[:, 3:]),
dim=1).to(device)
max_objs = self.train_cfg['max_objs'] * self.train_cfg['dense_reg']
grid_size = torch.tensor(self.train_cfg['grid_size'])
pc_range = torch.tensor(self.train_cfg['point_cloud_range'])
voxel_size = torch.tensor(self.train_cfg['voxel_size'])
feature_map_size = grid_size[:2] // self.train_cfg['out_size_factor']
# reorganize the gt_dict by tasks
task_masks = []
flag = 0
for class_name in self.class_names:
task_masks.append([
torch.where(gt_labels_3d == class_name.index(i) + flag)
for i in class_name
])
flag += len(class_name)
task_boxes = []
task_classes = []
flag2 = 0
for idx, mask in enumerate(task_masks):
task_box = []
task_class = []
for m in mask:
task_box.append(gt_bboxes_3d[m])
# 0 is background for each task, so we need to add 1 here.
task_class.append(gt_labels_3d[m] + 1 - flag2)
task_boxes.append(torch.cat(task_box, axis=0).to(device))
task_classes.append(torch.cat(task_class).long().to(device))
flag2 += len(mask)
draw_gaussian = draw_heatmap_gaussian
heatmaps, anno_boxes, inds, masks, corner_heatmaps, cat_labels = [], [], [], [], [], [] # noqa: E501
for idx in range(len(self.tasks)):
heatmap = gt_bboxes_3d.new_zeros(
(len(self.class_names[idx]), feature_map_size[1],
feature_map_size[0]))
corner_heatmap = torch.zeros(
(1, feature_map_size[1], feature_map_size[0]),
dtype=torch.float32,
device=device)
anno_box = gt_bboxes_3d.new_zeros((max_objs, 8),
dtype=torch.float32)
ind = gt_labels_3d.new_zeros((max_objs), dtype=torch.int64)
mask = gt_bboxes_3d.new_zeros((max_objs), dtype=torch.uint8)
cat_label = gt_bboxes_3d.new_zeros((max_objs), dtype=torch.int64)
num_objs = min(task_boxes[idx].shape[0], max_objs)
for k in range(num_objs):
cls_id = task_classes[idx][k] - 1
# gt boxes [xyzlwhr]
length = task_boxes[idx][k][3]
width = task_boxes[idx][k][4]
length = length / voxel_size[0] / self.train_cfg[
'out_size_factor']
width = width / voxel_size[1] / self.train_cfg[
'out_size_factor']
if width > 0 and length > 0:
radius = gaussian_radius(
(width, length),
min_overlap=self.train_cfg['gaussian_overlap'])
radius = max(self.train_cfg['min_radius'], int(radius))
# be really careful for the coordinate system of
# your box annotation.
x, y, z = task_boxes[idx][k][0], task_boxes[idx][k][
1], task_boxes[idx][k][2]
coor_x = (
x - pc_range[0]
) / voxel_size[0] / self.train_cfg['out_size_factor']
coor_y = (
y - pc_range[1]
) / voxel_size[1] / self.train_cfg['out_size_factor']
center = torch.tensor([coor_x, coor_y],
dtype=torch.float32,
device=device)
center_int = center.to(torch.int32)
                # skip objects outside the range to avoid indexing outside
                # the heatmap array
if not (0 <= center_int[0] < feature_map_size[0]
and 0 <= center_int[1] < feature_map_size[1]):
continue
draw_gaussian(heatmap[cls_id], center_int, radius)
radius = radius // 2
                # draw the center and the midpoints of the four box edges
                # TODO: use torch instead of numpy for the corner computation
rot = task_boxes[idx][k][6]
corner_keypoints = center_to_corner_box2d(
center.unsqueeze(0).cpu().numpy(),
torch.tensor([[length, width]],
dtype=torch.float32).numpy(),
angles=rot,
origin=0.5)
corner_keypoints = torch.from_numpy(corner_keypoints).to(
center)
draw_gaussian(corner_heatmap[0], center_int, radius)
draw_gaussian(
corner_heatmap[0],
(corner_keypoints[0, 0] + corner_keypoints[0, 1]) / 2,
radius)
draw_gaussian(
corner_heatmap[0],
(corner_keypoints[0, 2] + corner_keypoints[0, 3]) / 2,
radius)
draw_gaussian(
corner_heatmap[0],
(corner_keypoints[0, 0] + corner_keypoints[0, 3]) / 2,
radius)
draw_gaussian(
corner_heatmap[0],
(corner_keypoints[0, 1] + corner_keypoints[0, 2]) / 2,
radius)
new_idx = k
x, y = center_int[0], center_int[1]
assert (y * feature_map_size[0] + x <
feature_map_size[0] * feature_map_size[1])
ind[new_idx] = y * feature_map_size[0] + x
mask[new_idx] = 1
cat_label[new_idx] = cls_id
# TODO: support other outdoor dataset
# vx, vy = task_boxes[idx][k][7:]
rot = task_boxes[idx][k][6]
box_dim = task_boxes[idx][k][3:6]
box_dim = box_dim.log()
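                # 8-dim regression target: sub-pixel center offset (dx, dy),
                # height z, log box dims (l, w, h), and yaw as (sin, cos).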
anno_box[new_idx] = torch.cat([
center - torch.tensor([x, y], device=device),
z.unsqueeze(0), box_dim,
torch.sin(rot).unsqueeze(0),
torch.cos(rot).unsqueeze(0)
])
heatmaps.append(heatmap)
corner_heatmaps.append(corner_heatmap)
anno_boxes.append(anno_box)
masks.append(mask)
inds.append(ind)
cat_labels.append(cat_label)
return heatmaps, anno_boxes, inds, masks, corner_heatmaps, cat_labels
@MODELS.register_module()
class MultiFrameDeformableDecoderRPN(BaseDecoderRPN):
"""The original implementation of CenterFormer modules.
The difference between this module and
`DeformableDecoderRPN` is that this module uses information from multi
frames.
TODO: split this module into backbone、neck and head.
"""
def __init__(
self,
layer_nums, # [2,2,2]
ds_num_filters, # [128,256,64]
num_input_features, # 256
transformer_config=None,
hm_head_layer=2,
corner_head_layer=2,
corner=False,
parametric_embedding=False,
assign_label_window_size=1,
classes=3,
use_gt_training=False,
norm_cfg=None,
logger=None,
init_bias=-2.19,
score_threshold=0.1,
obj_num=500,
frame=1,
**kwargs):
super(MultiFrameDeformableDecoderRPN, self).__init__(
layer_nums,
ds_num_filters,
num_input_features,
transformer_config,
hm_head_layer,
corner_head_layer,
corner,
assign_label_window_size,
classes,
use_gt_training,
norm_cfg,
logger,
init_bias,
score_threshold,
obj_num,
)
self.frame = frame
self.out = nn.Sequential(
nn.Conv2d(
self._num_filters[0] * frame,
self._num_filters[0],
3,
padding=1,
bias=False,
),
build_norm_layer(self._norm_cfg, self._num_filters[0])[1],
nn.ReLU(),
)
self.mtf_attention = MultiFrameSpatialAttention()
self.time_embedding = nn.Linear(1, self._num_filters[0])
self.transformer_decoder = DeformableTransformerDecoder(
self._num_filters[-1] * 2,
depth=transformer_config.depth,
n_heads=transformer_config.n_heads,
n_levels=2 + self.frame,
dim_single_head=transformer_config.dim_single_head,
dim_ffn=transformer_config.dim_ffn,
dropout=transformer_config.dropout,
out_attention=transformer_config.out_attn,
n_points=transformer_config.get('n_points', 9),
)
self.pos_embedding_type = transformer_config.get(
'pos_embedding_type', 'linear')
if self.pos_embedding_type == 'linear':
self.pos_embedding = nn.Linear(2, self._num_filters[-1] * 2)
else:
raise NotImplementedError()
self.parametric_embedding = parametric_embedding
if self.parametric_embedding:
self.query_embed = nn.Embedding(self.obj_num,
self._num_filters[-1] * 2)
nn.init.uniform_(self.query_embed.weight, -1.0, 1.0)
print_log('Finish RPN_transformer_deformable Initialization',
'current')
def forward(self, x, example=None):
# FPN
x = self.blocks[0](x)
x_down = self.blocks[1](x)
x_up = torch.cat([self.blocks[2](x_down), self.up(x)], dim=1)
# take out the BEV feature on current frame
x = torch.split(x, self.frame)
x_up = torch.split(x_up, self.frame)
x_down = torch.split(x_down, self.frame)
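        # The batch dim packs frames as B * frame; splitting into chunks of
        # `frame` yields one tuple entry per sample, with index 0 holding the
        # current frame and 1: the previous frames.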
x_prev = torch.stack([t[1:] for t in x_up], dim=0) # B,K,C,H,W
x = torch.stack([t[0] for t in x], dim=0)
x_down = torch.stack([t[0] for t in x_down], dim=0)
x_up = torch.stack([t[0] for t in x_up], dim=0) # B,C,H,W
# use spatial attention in current frame on previous feature
x_prev_cat = self.mtf_attention(
x_up,
x_prev.reshape(x_up.shape[0], -1, x_up.shape[2],
x_up.shape[3])) # B,K*C,H,W
# time embedding
x_up_fuse = torch.cat((x_up, x_prev_cat), dim=1) + self.time_embedding(
example['times'][:, :, None].to(x_up)).reshape(
x_up.shape[0], -1, 1, 1)
# fuse mtf feature
x_up_fuse = self.out(x_up_fuse)
# heatmap head
hm = self.hm_head(x_up_fuse)
if self.corner and self.corner_head.training:
corner_hm = self.corner_head(x_up_fuse)
corner_hm = torch.sigmoid(corner_hm)
# find top K center location
hm = torch.sigmoid(hm)
batch, num_cls, H, W = hm.size()
scores, labels = torch.max(
hm.reshape(batch, num_cls, H * W), dim=1) # b,H*W
self.batch_id = torch.from_numpy(np.indices(
(batch, self.obj_num))[0]).to(labels)
if self.use_gt_training and self.hm_head.training:
gt_inds = example['ind'][0][:, (self.window_size //
2)::self.window_size]
gt_masks = example['mask'][0][:, (self.window_size //
2)::self.window_size]
batch_id_gt = torch.from_numpy(
np.indices((batch, gt_inds.shape[1]))[0]).to(labels)
scores[batch_id_gt,
gt_inds] = scores[batch_id_gt, gt_inds] + gt_masks
order = scores.sort(1, descending=True)[1]
order = order[:, :self.obj_num]
scores[batch_id_gt,
gt_inds] = scores[batch_id_gt, gt_inds] - gt_masks
else:
order = scores.sort(1, descending=True)[1]
order = order[:, :self.obj_num]
scores = torch.gather(scores, 1, order)
labels = torch.gather(labels, 1, order)
mask = scores > self.score_threshold
ct_feat = (x_up.reshape(batch, -1,
H * W).transpose(2,
1).contiguous()[self.batch_id,
order]
) # B, 500, C
# create position embedding for each center
y_coor = order // W
x_coor = order - y_coor * W
y_coor, x_coor = y_coor.to(ct_feat), x_coor.to(ct_feat)
y_coor, x_coor = y_coor / H, x_coor / W
pos_features = torch.stack([x_coor, y_coor], dim=2)
if self.parametric_embedding:
ct_feat = self.query_embed.weight
ct_feat = ct_feat.unsqueeze(0).expand(batch, -1, -1)
# run transformer
src_list = [
x_up.reshape(batch, -1, H * W).transpose(2, 1).contiguous(),
x.reshape(batch, -1, (H * W) // 4).transpose(2, 1).contiguous(),
x_down.reshape(batch, -1, (H * W) // 16).transpose(2,
1).contiguous(),
]
for frame in range(x_prev.shape[1]):
src_list.append(x_prev[:, frame].reshape(batch,
-1, (H * W)).transpose(
2, 1).contiguous())
src = torch.cat(src_list, dim=1) # B ,sum(H*W), C
spatial_list = [(H, W), (H // 2, W // 2), (H // 4, W // 4)]
spatial_list += [(H, W) for frame in range(x_prev.shape[1])]
spatial_shapes = torch.as_tensor(
spatial_list, dtype=torch.long, device=ct_feat.device)
level_start_index = torch.cat((
spatial_shapes.new_zeros((1, )),
spatial_shapes.prod(1).cumsum(0)[:-1],
))
transformer_out = self.transformer_decoder(
ct_feat,
self.pos_embedding,
src,
spatial_shapes,
level_start_index,
center_pos=pos_features,
) # (B,N,C)
ct_feat = (transformer_out['ct_feat'].transpose(2, 1).contiguous()
) # B, C, 500
out_dict = {
'hm': hm,
'scores': scores,
'labels': labels,
'order': order,
'ct_feat': ct_feat,
'mask': mask,
}
if 'out_attention' in transformer_out:
out_dict.update(
{'out_attention': transformer_out['out_attention']})
if self.corner and self.corner_head.training:
out_dict.update({'corner_hm': corner_hm})
return out_dict
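# A hypothetical worked example (not in the upstream file) of the flattened
# multi-level token layout fed to the deformable decoder above: three pyramid
# levels plus one full-resolution level per previous frame.
def _demo_level_start_index(H=128, W=128, prev_frames=1):
    spatial_list = [(H, W), (H // 2, W // 2), (H // 4, W // 4)]
    spatial_list += [(H, W)] * prev_frames
    spatial_shapes = torch.as_tensor(spatial_list, dtype=torch.long)
    # Offset of each level inside the concatenated [B, sum(H*W), C] source.
    level_start_index = torch.cat((
        spatial_shapes.new_zeros((1, )),
        spatial_shapes.prod(1).cumsum(0)[:-1],
    ))
    # With the defaults this yields [0, 16384, 20480, 21504] over 37888 tokens.
    return spatial_shapes, level_start_index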