# projects/TPVFormer/tpvformer/tpvformer_head.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
from mmdet3d.registry import MODELS
@MODELS.register_module()
class TPVFormerDecoder(BaseModule):
    """Decode tri-perspective-view (TPV) feature planes into semantic logits.

    The three inputs are flattened 2D feature planes:

    - ``tpv_list[0]``: (bs, h*w, c) — the HW (top) plane
    - ``tpv_list[1]``: (bs, z*h, c) — the ZH (side) plane
    - ``tpv_list[2]``: (bs, w*z, c) — the WZ (front) plane

    Each plane is reshaped, optionally upsampled, then either broadcast
    along its missing axis and summed into a dense voxel volume, or
    bilinearly sampled at point coordinates. An MLP decoder followed by a
    linear classifier produces per-voxel and/or per-point class logits.

    Args:
        tpv_h (int): Base height of the TPV grid.
        tpv_w (int): Base width of the TPV grid.
        tpv_z (int): Base depth of the TPV grid.
        num_classes (int): Number of semantic classes. Defaults to 20.
        in_dims (int): Channel number of the input TPV features.
        hidden_dims (int): Hidden width of the MLP decoder.
        out_dims (int, optional): Output width of the MLP decoder;
            falls back to ``in_dims`` when ``None``.
        scale_h (int): Upsampling factor along h. Defaults to 2.
        scale_w (int): Upsampling factor along w. Defaults to 2.
        scale_z (int): Upsampling factor along z. Defaults to 2.
        ignore_index (int): Label ignored by both losses. Defaults to 0.
        loss_lovasz (dict, optional): Config for the Lovasz loss.
        loss_ce (dict, optional): Config for the cross-entropy loss.
        lovasz_input (str): ``'points'`` or ``'voxel'`` — which logits feed
            the Lovasz loss.
        ce_input (str): ``'points'`` or ``'voxel'`` — which logits feed the
            cross-entropy loss.
        use_checkpoint (bool): Run decoder/classifier under activation
            checkpointing in :meth:`forward`. Defaults to False. (New
            optional flag: ``self.use_checkpoint`` was previously read in
            ``forward`` but never set, raising ``AttributeError``.)
    """

    def __init__(self,
                 tpv_h,
                 tpv_w,
                 tpv_z,
                 num_classes=20,
                 in_dims=64,
                 hidden_dims=128,
                 out_dims=None,
                 scale_h=2,
                 scale_w=2,
                 scale_z=2,
                 ignore_index=0,
                 loss_lovasz=None,
                 loss_ce=None,
                 lovasz_input='points',
                 ce_input='voxel',
                 use_checkpoint=False):
        super().__init__()
        self.tpv_h = tpv_h
        self.tpv_w = tpv_w
        self.tpv_z = tpv_z
        self.scale_h = scale_h
        self.scale_w = scale_w
        self.scale_z = scale_z
        # Fix: store the class count under the name the logits reshape in
        # `forward` uses, and make checkpointing configurable — both
        # attributes were previously missing and crashed `forward`.
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint

        out_dims = in_dims if out_dims is None else out_dims
        self.in_dims = in_dims
        self.decoder = nn.Sequential(
            nn.Linear(in_dims, hidden_dims), nn.Softplus(),
            nn.Linear(hidden_dims, out_dims))
        self.classifier = nn.Linear(out_dims, num_classes)

        self.loss_lovasz = MODELS.build(loss_lovasz)
        self.loss_ce = MODELS.build(loss_ce)
        self.ignore_index = ignore_index
        self.lovasz_input = lovasz_input
        self.ce_input = ce_input

    def _upsample_tpv(self, tpv_list):
        """Reshape the flattened planes to (bs, c, H, W) maps and upsample.

        Shared by ``forward`` / ``predict`` / ``loss`` (previously the same
        ~20 lines were triplicated in all three methods).
        """
        tpv_hw, tpv_zh, tpv_wz = tpv_list[0], tpv_list[1], tpv_list[2]
        bs, _, c = tpv_hw.shape
        tpv_hw = tpv_hw.permute(0, 2, 1).reshape(bs, c, self.tpv_h, self.tpv_w)
        tpv_zh = tpv_zh.permute(0, 2, 1).reshape(bs, c, self.tpv_z, self.tpv_h)
        tpv_wz = tpv_wz.permute(0, 2, 1).reshape(bs, c, self.tpv_w, self.tpv_z)
        if self.scale_h != 1 or self.scale_w != 1:
            tpv_hw = F.interpolate(
                tpv_hw,
                size=(self.tpv_h * self.scale_h, self.tpv_w * self.scale_w),
                mode='bilinear')
        if self.scale_z != 1 or self.scale_h != 1:
            tpv_zh = F.interpolate(
                tpv_zh,
                size=(self.tpv_z * self.scale_z, self.tpv_h * self.scale_h),
                mode='bilinear')
        if self.scale_w != 1 or self.scale_z != 1:
            tpv_wz = F.interpolate(
                tpv_wz,
                size=(self.tpv_w * self.scale_w, self.tpv_z * self.scale_z),
                mode='bilinear')
        return tpv_hw, tpv_zh, tpv_wz

    def _normalize_coors(self, coors):
        """Map grid coordinates (..., 3) in (w, h, z) order to [-1, 1].

        Returns a NEW tensor. The original code normalized in place, which
        could silently overwrite the caller's (or the data sample's)
        coordinate tensor, corrupting it for any later use.
        """
        coors = coors.float()
        scale = coors.new_tensor([
            self.tpv_w * self.scale_w,
            self.tpv_h * self.scale_h,
            self.tpv_z * self.scale_z,
        ])
        return coors / scale * 2 - 1

    def _sample_points(self, tpv_hw, tpv_zh, tpv_wz, coors):
        """Bilinearly sample the three planes at normalized coordinates.

        ``coors``: (bs, 1, n, 3) in [-1, 1], (w, h, z) order. Each plane is
        sampled with the pair of axes it spans; the three results are
        summed. Returns (bs, c, 1, n).
        """
        # align_corners=False is the PyTorch default and what predict/loss
        # always passed; made explicit everywhere for consistency.
        hw_pts = F.grid_sample(tpv_hw, coors[..., [0, 1]], align_corners=False)
        zh_pts = F.grid_sample(tpv_zh, coors[..., [1, 2]], align_corners=False)
        wz_pts = F.grid_sample(tpv_wz, coors[..., [2, 0]], align_corners=False)
        return hw_pts + zh_pts + wz_pts

    def _fuse_voxels(self, tpv_hw, tpv_zh, tpv_wz):
        """Broadcast each plane along its missing axis and sum them into a
        dense (bs, c, W, H, Z) voxel feature volume."""
        hw = tpv_hw.unsqueeze(-1).permute(0, 1, 3, 2, 4).expand(
            -1, -1, -1, -1, self.scale_z * self.tpv_z)
        zh = tpv_zh.unsqueeze(-1).permute(0, 1, 4, 3, 2).expand(
            -1, -1, self.scale_w * self.tpv_w, -1, -1)
        wz = tpv_wz.unsqueeze(-1).permute(0, 1, 2, 4, 3).expand(
            -1, -1, -1, self.scale_h * self.tpv_h, -1)
        return hw + zh + wz

    def _decode_classify(self, fused):
        """Apply the MLP decoder then the classifier (channels-last input),
        optionally under activation checkpointing to save memory."""
        if self.use_checkpoint:
            fused = torch.utils.checkpoint.checkpoint(self.decoder, fused)
            logits = torch.utils.checkpoint.checkpoint(self.classifier, fused)
        else:
            fused = self.decoder(fused)
            logits = self.classifier(fused)
        return logits

    def forward(self, tpv_list, points=None):
        """Fuse TPV planes into voxel (and optionally point) logits.

        Args:
            tpv_list (list[Tensor]): [(bs, h*w, c), (bs, z*h, c),
                (bs, w*z, c)] flattened TPV planes.
            points (Tensor, optional): (bs, n, 3) grid coordinates in
                (w, h, z) order; no longer modified in place.

        Returns:
            If ``points`` is given: ``(logits_vox, logits_pts)`` of shapes
            (bs, num_classes, W, H, Z) and (bs, num_classes, n, 1, 1);
            otherwise only the voxel logits.
        """
        tpv_hw, tpv_zh, tpv_wz = self._upsample_tpv(tpv_list)
        bs = tpv_hw.shape[0]

        if points is not None:
            _, n, _ = points.shape
            coors = self._normalize_coors(points.reshape(bs, 1, n, 3))
            fused_pts = self._sample_points(tpv_hw, tpv_zh, tpv_wz,
                                            coors).squeeze(2)  # bs, c, n
            fused_vox = self._fuse_voxels(tpv_hw, tpv_zh, tpv_wz).flatten(2)
            fused = torch.cat([fused_vox, fused_pts], dim=-1)  # bs, c, whz+n
            logits = self._decode_classify(fused.permute(0, 2, 1))
            logits = logits.permute(0, 2, 1)
            # Fix: `self.classes` never existed — use the stored class count.
            logits_vox = logits[:, :, :-n].reshape(bs, self.num_classes,
                                                   self.scale_w * self.tpv_w,
                                                   self.scale_h * self.tpv_h,
                                                   self.scale_z * self.tpv_z)
            logits_pts = logits[:, :, -n:].reshape(bs, self.num_classes, n, 1,
                                                   1)
            return logits_vox, logits_pts

        fused = self._fuse_voxels(tpv_hw, tpv_zh, tpv_wz)
        logits = self._decode_classify(fused.permute(0, 2, 3, 4, 1))
        return logits.permute(0, 4, 1, 2, 3)

    def predict(self, tpv_list, batch_data_samples):
        """Compute per-sample point logits for inference.

        Args:
            tpv_list (list[Tensor]): Flattened TPV planes (see ``forward``).
            batch_data_samples: Each sample must carry ``point_coors`` of
                shape (n_i, 3) in (w, h, z) grid order.

        Returns:
            list[Tensor]: One (n_i, num_classes) logits tensor per sample.
        """
        tpv_hw, tpv_zh, tpv_wz = self._upsample_tpv(tpv_list)
        logits = []
        for i, data_sample in enumerate(batch_data_samples):
            coors = self._normalize_coors(
                data_sample.point_coors.reshape(1, 1, -1, 3))
            fused_pts = self._sample_points(tpv_hw[i:i + 1], tpv_zh[i:i + 1],
                                            tpv_wz[i:i + 1], coors)
            # (1, c, 1, n) -> (n, c)
            fused_pts = fused_pts.squeeze(0).squeeze(1).transpose(0, 1)
            logits.append(self.classifier(self.decoder(fused_pts)))
        return logits

    def loss(self, tpv_list, batch_data_samples):
        """Compute cross-entropy and Lovasz losses.

        Point features are sampled at each sample's ``point_coors``; voxel
        features are gathered at each sample's ``voxel_coors`` (long
        indices in (w, h, z) order). Labels come from
        ``gt_pts_seg.pts_semantic_mask`` / ``gt_pts_seg.voxel_semantic_mask``.

        Returns:
            dict: ``loss_ce`` and ``loss_lovasz``.
        """
        tpv_hw, tpv_zh, tpv_wz = self._upsample_tpv(tpv_list)
        # Fix: the dense voxel expansion is loop-invariant — build it once
        # instead of re-expanding for every data sample.
        fused_vox = self._fuse_voxels(tpv_hw, tpv_zh, tpv_wz)

        batch_pts, batch_vox = [], []
        for i, data_sample in enumerate(batch_data_samples):
            coors = self._normalize_coors(
                data_sample.point_coors.reshape(1, 1, -1, 3))
            fused_pts = self._sample_points(tpv_hw[i:i + 1], tpv_zh[i:i + 1],
                                            tpv_wz[i:i + 1], coors)
            batch_pts.append(fused_pts.squeeze(0).squeeze(1))  # (c, n_i)

            voxel_coors = data_sample.voxel_coors.long()
            # Fix: gather only batch element i. The old code indexed the
            # whole batch with sample-i coordinates and relied on
            # `.squeeze(0)`, which is only correct for batch size 1.
            batch_vox.append(fused_vox[i, :, voxel_coors[:, 0],
                                       voxel_coors[:, 1],
                                       voxel_coors[:, 2]])  # (c, k_i)

        batch_pts = torch.cat(batch_pts, dim=1)
        batch_vox = torch.cat(batch_vox, dim=1)
        num_points = batch_pts.shape[1]

        logits = self.decoder(
            torch.cat([batch_pts, batch_vox], dim=1).transpose(0, 1))
        logits = self.classifier(logits)
        pts_logits = logits[:num_points, :]
        vox_logits = logits[num_points:, :]

        pts_seg_label = torch.cat([
            data_sample.gt_pts_seg.pts_semantic_mask
            for data_sample in batch_data_samples
        ])
        voxel_seg_label = torch.cat([
            data_sample.gt_pts_seg.voxel_semantic_mask
            for data_sample in batch_data_samples
        ])

        if self.ce_input == 'voxel':
            ce_input, ce_label = vox_logits, voxel_seg_label
        else:
            ce_input, ce_label = pts_logits, pts_seg_label
        if self.lovasz_input == 'voxel':
            lovasz_input, lovasz_label = vox_logits, voxel_seg_label
        else:
            lovasz_input, lovasz_label = pts_logits, pts_seg_label

        loss = dict()
        loss['loss_ce'] = self.loss_ce(
            ce_input, ce_label, ignore_index=self.ignore_index)
        loss['loss_lovasz'] = self.loss_lovasz(
            lovasz_input, lovasz_label, ignore_index=self.ignore_index)
        return loss