|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from mmengine.model import BaseModule |
|
|
|
from mmdet3d.registry import MODELS |
|
|
|
|
|
@MODELS.register_module()
class TPVFormerDecoder(BaseModule):
    """Decode tri-perspective-view (TPV) plane features into semantic logits.

    The three TPV planes (HW, ZH and WZ) are optionally upsampled by
    ``scale_h`` / ``scale_w`` / ``scale_z``. The feature of grid cell
    ``(w, h, z)`` is the sum of the three plane features it projects onto;
    point features are obtained the same way by bilinearly sampling each
    plane with ``F.grid_sample``. Fused features go through a two-layer MLP
    (``decoder``) followed by a linear ``classifier``.

    Args:
        tpv_h (int): Height of the TPV planes before upsampling.
        tpv_w (int): Width of the TPV planes before upsampling.
        tpv_z (int): Depth of the TPV planes before upsampling.
        num_classes (int): Number of output classes. Defaults to 20.
        in_dims (int): Channel dimension of the TPV features. Defaults to 64.
        hidden_dims (int): Hidden width of the decoder MLP. Defaults to 128.
        out_dims (int, optional): Output width of the decoder MLP. Falls back
            to ``in_dims`` when None.
        scale_h (int): Upsampling factor along H. Defaults to 2.
        scale_w (int): Upsampling factor along W. Defaults to 2.
        scale_z (int): Upsampling factor along Z. Defaults to 2.
        ignore_index (int): Label passed as ``ignore_index`` to both losses.
            Defaults to 0.
        loss_lovasz (dict, optional): Config of the Lovasz loss. The term is
            skipped in :meth:`loss` when None.
        loss_ce (dict, optional): Config of the cross-entropy loss. The term
            is skipped in :meth:`loss` when None.
        lovasz_input (str): ``'voxel'`` applies the Lovasz loss to voxel
            logits; any other value applies it to point logits.
            Defaults to ``'points'``.
        ce_input (str): ``'voxel'`` applies the CE loss to voxel logits; any
            other value applies it to point logits. Defaults to ``'voxel'``.
        use_checkpoint (bool): Whether :meth:`forward` runs the decoder and
            classifier under ``torch.utils.checkpoint``. Defaults to False.
    """

    def __init__(self,
                 tpv_h,
                 tpv_w,
                 tpv_z,
                 num_classes=20,
                 in_dims=64,
                 hidden_dims=128,
                 out_dims=None,
                 scale_h=2,
                 scale_w=2,
                 scale_z=2,
                 ignore_index=0,
                 loss_lovasz=None,
                 loss_ce=None,
                 lovasz_input='points',
                 ce_input='voxel',
                 use_checkpoint=False):
        super().__init__()
        self.tpv_h = tpv_h
        self.tpv_w = tpv_w
        self.tpv_z = tpv_z
        self.scale_h = scale_h
        self.scale_w = scale_w
        self.scale_z = scale_z
        # Both are read by ``forward``; they were previously never set.
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint

        out_dims = in_dims if out_dims is None else out_dims
        self.in_dims = in_dims
        self.decoder = nn.Sequential(
            nn.Linear(in_dims, hidden_dims), nn.Softplus(),
            nn.Linear(hidden_dims, out_dims))
        self.classifier = nn.Linear(out_dims, num_classes)

        # ``MODELS.build`` expects a config dict, so an unset loss is kept as
        # None (and skipped in ``loss``) instead of failing here.
        self.loss_lovasz = (
            MODELS.build(loss_lovasz) if loss_lovasz is not None else None)
        self.loss_ce = MODELS.build(loss_ce) if loss_ce is not None else None
        self.ignore_index = ignore_index
        self.lovasz_input = lovasz_input
        self.ce_input = ce_input

    def _grid_size(self):
        """Return the upsampled grid size as ``(W, H, Z)``."""
        return (self.tpv_w * self.scale_w, self.tpv_h * self.scale_h,
                self.tpv_z * self.scale_z)

    def _reshape_and_upsample(self, tpv_list):
        """Turn flattened TPV planes into (upsampled) 2D feature maps.

        Args:
            tpv_list (Sequence[Tensor]): Plane features
                ``tpv_list[0]: (bs, h*w, c)``, ``tpv_list[1]: (bs, z*h, c)``,
                ``tpv_list[2]: (bs, w*z, c)``.

        Returns:
            tuple[Tensor, Tensor, Tensor]: ``(tpv_hw, tpv_zh, tpv_wz)`` of
            shapes ``(bs, c, H, W)``, ``(bs, c, Z, H)`` and ``(bs, c, W, Z)``
            where ``H``, ``W`` and ``Z`` are the upsampled sizes.
        """
        tpv_hw, tpv_zh, tpv_wz = tpv_list[0], tpv_list[1], tpv_list[2]
        bs, _, c = tpv_hw.shape
        tpv_hw = tpv_hw.permute(0, 2, 1).reshape(bs, c, self.tpv_h, self.tpv_w)
        tpv_zh = tpv_zh.permute(0, 2, 1).reshape(bs, c, self.tpv_z, self.tpv_h)
        tpv_wz = tpv_wz.permute(0, 2, 1).reshape(bs, c, self.tpv_w, self.tpv_z)

        size_w, size_h, size_z = self._grid_size()
        if self.scale_h != 1 or self.scale_w != 1:
            tpv_hw = F.interpolate(
                tpv_hw, size=(size_h, size_w), mode='bilinear')
        if self.scale_z != 1 or self.scale_h != 1:
            tpv_zh = F.interpolate(
                tpv_zh, size=(size_z, size_h), mode='bilinear')
        if self.scale_w != 1 or self.scale_z != 1:
            tpv_wz = F.interpolate(
                tpv_wz, size=(size_w, size_z), mode='bilinear')
        return tpv_hw, tpv_zh, tpv_wz

    def _normalize_coors(self, coors):
        """Map ``(w, h, z)`` grid coordinates to ``[-1, 1]`` for grid_sample.

        Works out of place: ``.float()`` is a no-op on float32 input, so
        in-place writes would silently modify the caller's tensor (e.g.
        ``data_sample.point_coors``) and re-normalize it on the next call.
        """
        coors = coors.float()
        grid_size = coors.new_tensor(self._grid_size())
        return coors / grid_size * 2 - 1

    @staticmethod
    def _sample_planes(tpv_hw, tpv_zh, tpv_wz, grid):
        """Bilinearly sample the three planes at points and sum the results.

        Args:
            tpv_hw, tpv_zh, tpv_wz (Tensor): Plane maps as returned by
                ``_reshape_and_upsample`` with batch size ``B``.
            grid (Tensor): Normalized ``(w, h, z)`` coordinates of shape
                ``(B, 1, N, 3)``.

        Returns:
            Tensor: Fused point features of shape ``(B, c, N)``.
        """
        # grid_sample reads (x, y) as (last dim, second-to-last dim).
        hw = F.grid_sample(tpv_hw, grid[..., [0, 1]], align_corners=False)
        zh = F.grid_sample(tpv_zh, grid[..., [1, 2]], align_corners=False)
        wz = F.grid_sample(tpv_wz, grid[..., [2, 0]], align_corners=False)
        return (hw + zh + wz).squeeze(2)

    def _dense_volume(self, tpv_hw, tpv_zh, tpv_wz):
        """Broadcast-sum the planes into a dense ``(B, c, W, H, Z)`` volume."""
        size_w, size_h, size_z = self._grid_size()
        hw = tpv_hw.unsqueeze(-1).permute(0, 1, 3, 2, 4).expand(
            -1, -1, -1, -1, size_z)
        zh = tpv_zh.unsqueeze(-1).permute(0, 1, 4, 3, 2).expand(
            -1, -1, size_w, -1, -1)
        wz = tpv_wz.unsqueeze(-1).permute(0, 1, 2, 4, 3).expand(
            -1, -1, -1, size_h, -1)
        return hw + zh + wz

    def _decode(self, feats):
        """Apply decoder MLP and classifier, checkpointed if configured."""
        if self.use_checkpoint:
            from torch.utils.checkpoint import checkpoint
            feats = checkpoint(self.decoder, feats)
            return checkpoint(self.classifier, feats)
        return self.classifier(self.decoder(feats))

    def forward(self, tpv_list, points=None):
        """Predict logits for the dense voxel grid and, optionally, points.

        Args:
            tpv_list (Sequence[Tensor]): Plane features
                ``tpv_list[0]: (bs, h*w, c)``, ``tpv_list[1]: (bs, z*h, c)``,
                ``tpv_list[2]: (bs, w*z, c)``.
            points (Tensor, optional): Point coordinates of shape
                ``(bs, n, 3)`` in ``(w, h, z)`` order, expressed in units of
                the upsampled grid. Not modified.

        Returns:
            Tensor | tuple[Tensor, Tensor]: Without ``points``, voxel logits
            of shape ``(bs, num_classes, W, H, Z)``. With ``points``, a tuple
            ``(logits_vox, logits_pts)`` where ``logits_pts`` has shape
            ``(bs, num_classes, n, 1, 1)``.
        """
        tpv_hw, tpv_zh, tpv_wz = self._reshape_and_upsample(tpv_list)
        bs = tpv_hw.shape[0]
        size_w, size_h, size_z = self._grid_size()
        fused_vox = self._dense_volume(tpv_hw, tpv_zh, tpv_wz)

        if points is None:
            # (bs, c, W, H, Z) -> (bs, W, H, Z, c) so Linear acts on channels.
            logits = self._decode(fused_vox.permute(0, 2, 3, 4, 1))
            return logits.permute(0, 4, 1, 2, 3)

        n = points.shape[1]
        grid = self._normalize_coors(points.reshape(bs, 1, n, 3))
        fused_pts = self._sample_planes(tpv_hw, tpv_zh, tpv_wz, grid)
        # Decode voxels and points in one pass: (bs, c, W*H*Z + n).
        fused = torch.cat([fused_vox.flatten(2), fused_pts], dim=-1)
        logits = self._decode(fused.permute(0, 2, 1)).permute(0, 2, 1)

        # Split by voxel count rather than ``-n`` so that n == 0 works.
        num_vox = size_w * size_h * size_z
        logits_vox = logits[:, :, :num_vox].reshape(
            bs, self.num_classes, size_w, size_h, size_z)
        logits_pts = logits[:, :, num_vox:].reshape(
            bs, self.num_classes, n, 1, 1)
        return logits_vox, logits_pts

    def predict(self, tpv_list, batch_data_samples):
        """Predict per-point logits for every sample in the batch.

        Args:
            tpv_list (Sequence[Tensor]): Plane features
                ``tpv_list[0]: (bs, h*w, c)``, ``tpv_list[1]: (bs, z*h, c)``,
                ``tpv_list[2]: (bs, w*z, c)``.
            batch_data_samples (list): Samples whose ``point_coors`` can be
                reshaped to ``(-1, 3)`` ``(w, h, z)`` coordinates in units of
                the upsampled grid. ``point_coors`` is not modified.

        Returns:
            list[Tensor]: One ``(N_i, num_classes)`` logits tensor per sample.
        """
        tpv_hw, tpv_zh, tpv_wz = self._reshape_and_upsample(tpv_list)
        logits = []
        for i, data_sample in enumerate(batch_data_samples):
            grid = self._normalize_coors(
                data_sample.point_coors.reshape(1, 1, -1, 3))
            fused_pts = self._sample_planes(tpv_hw[i:i + 1], tpv_zh[i:i + 1],
                                            tpv_wz[i:i + 1], grid)[0]
            # (c, N) -> (N, c) so the Linear layers act on channels.
            logits.append(self.classifier(self.decoder(fused_pts.t())))
        return logits

    def loss(self, tpv_list, batch_data_samples):
        """Compute the configured cross-entropy and Lovasz losses.

        Point features are bilinearly sampled at each sample's
        ``point_coors``. Voxel features are gathered at ``voxel_coors``,
        which are integer ``(w, h, z)`` indices into the upsampled grid. Both
        pass through the same decoder and classifier.

        Args:
            tpv_list (Sequence[Tensor]): Plane features
                ``tpv_list[0]: (bs, h*w, c)``, ``tpv_list[1]: (bs, z*h, c)``,
                ``tpv_list[2]: (bs, w*z, c)``.
            batch_data_samples (list): Samples providing ``point_coors``,
                ``voxel_coors`` and ``gt_pts_seg`` with ``pts_semantic_mask``
                and ``voxel_semantic_mask``.

        Returns:
            dict[str, Tensor]: ``loss_ce`` and/or ``loss_lovasz`` for each
            loss that was configured.
        """
        tpv_hw, tpv_zh, tpv_wz = self._reshape_and_upsample(tpv_list)

        batch_pts, batch_vox = [], []
        for i, data_sample in enumerate(batch_data_samples):
            grid = self._normalize_coors(
                data_sample.point_coors.reshape(1, 1, -1, 3))
            batch_pts.append(
                self._sample_planes(tpv_hw[i:i + 1], tpv_zh[i:i + 1],
                                    tpv_wz[i:i + 1], grid)[0])

            # Feature of voxel (w, h, z) = hw[:, h, w] + zh[:, z, h]
            # + wz[:, w, z]. This equals indexing the dense (c, W, H, Z)
            # volume of sample ``i`` without materializing it, and it
            # uses only this sample's planes (not the whole batch).
            coor = data_sample.voxel_coors.long()
            w, h, z = coor[:, 0], coor[:, 1], coor[:, 2]
            batch_vox.append(tpv_hw[i][:, h, w] + tpv_zh[i][:, z, h] +
                             tpv_wz[i][:, w, z])

        batch_pts = torch.cat(batch_pts, dim=1)
        batch_vox = torch.cat(batch_vox, dim=1)
        num_points = batch_pts.shape[1]

        # (c, P + V) -> (P + V, c) so the Linear layers act on channels.
        logits = self.classifier(
            self.decoder(torch.cat([batch_pts, batch_vox], dim=1).t()))
        pts_logits = logits[:num_points]
        vox_logits = logits[num_points:]

        pts_seg_label = torch.cat([
            data_sample.gt_pts_seg.pts_semantic_mask
            for data_sample in batch_data_samples
        ])
        voxel_seg_label = torch.cat([
            data_sample.gt_pts_seg.voxel_semantic_mask
            for data_sample in batch_data_samples
        ])

        def select(which):
            # 'voxel' selects voxel logits/labels; anything else points.
            if which == 'voxel':
                return vox_logits, voxel_seg_label
            return pts_logits, pts_seg_label

        loss = dict()
        if self.loss_ce is not None:
            ce_input, ce_label = select(self.ce_input)
            loss['loss_ce'] = self.loss_ce(
                ce_input, ce_label, ignore_index=self.ignore_index)
        if self.loss_lovasz is not None:
            lovasz_input, lovasz_label = select(self.lovasz_input)
            loss['loss_lovasz'] = self.loss_lovasz(
                lovasz_input, lovasz_label, ignore_index=self.ignore_index)
        return loss
|
|