# monoscene/monoscene.py
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

from monoscene.flosp import FLoSP
from monoscene.unet2d import UNet2D
from monoscene.unet3d_kitti import UNet3D as UNet3DKitti
from monoscene.unet3d_nyu import UNet3D as UNet3DNYU


class MonoScene(pl.LightningModule):
    """MonoScene: 3D semantic scene completion (SSC) from a single RGB image.

    Multi-scale 2D image features are lifted into a 3D voxel volume with
    FLoSP and decoded by a dataset-specific 3D UNet.
    """

    def __init__(
        self,
        n_classes,
        feature,
        project_scale,
        full_scene_size,
        dataset,
        project_res=["1", "2", "4", "8"],
        n_relations=4,
        context_prior=True,
        fp_loss=True,
        frustum_size=4,
        relation_loss=False,
        CE_ssc_loss=True,
        geo_scal_loss=True,
        sem_scal_loss=True,
        lr=1e-4,
        weight_decay=1e-4,
    ):
        super().__init__()
        self.project_res = project_res
        self.fp_loss = fp_loss
        self.dataset = dataset
        self.context_prior = context_prior
        self.frustum_size = frustum_size
        self.relation_loss = relation_loss
        self.CE_ssc_loss = CE_ssc_loss
        self.sem_scal_loss = sem_scal_loss
        self.geo_scal_loss = geo_scal_loss
        self.project_scale = project_scale
        self.lr = lr
        self.weight_decay = weight_decay

        # One FLoSP projection module per 2D feature scale
        projects = {}
        self.scale_2ds = [1, 2, 4, 8]  # 2D scales
        for scale_2d in self.scale_2ds:
            projects[str(scale_2d)] = FLoSP(
                full_scene_size, project_scale=self.project_scale, dataset=self.dataset
            )
        self.projects = nn.ModuleDict(projects)
        self.n_classes = n_classes
        # Dataset-specific 3D decoder
        if self.dataset == "NYU":
            self.net_3d_decoder = UNet3DNYU(
                self.n_classes,
                nn.BatchNorm3d,
                n_relations=n_relations,
                feature=feature,
                full_scene_size=full_scene_size,
                context_prior=context_prior,
            )
        elif self.dataset == "kitti":
            self.net_3d_decoder = UNet3DKitti(
                self.n_classes,
                nn.BatchNorm3d,
                project_scale=project_scale,
                feature=feature,
                full_scene_size=full_scene_size,
                context_prior=context_prior,
            )
        else:
            raise ValueError("dataset must be 'NYU' or 'kitti', got {!r}".format(dataset))

        # 2D image encoder-decoder producing multi-scale feature maps
        self.net_rgb = UNet2D.build(out_feature=feature, use_decoder=True)
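
    # Note on the expected input (inferred from forward() below): `batch` is a
    # dict holding "img" plus, for the configured project_scale, per-sample
    # "projected_pix_{scale}" voxel-to-pixel coordinates and "fov_mask_{scale}"
    # masks of voxels inside the camera frustum, computed by the dataloader.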
    def forward(self, batch):
        img = batch["img"]
        bs = len(img)

        # Multi-scale 2D features, keyed "1_1", "1_2", "1_4", "1_8"
        # by downsampling factor
        x_rgb = self.net_rgb(img)

        x3ds = []
        for i in range(bs):
            x3d = None
            for scale_2d in self.project_res:
                scale_2d = int(scale_2d)
                projected_pix = batch["projected_pix_{}".format(self.project_scale)][i]
                fov_mask = batch["fov_mask_{}".format(self.project_scale)][i]

                # Lift the 2D features at this scale to the target 3D scale;
                # floor division downsamples the pixel coordinates to match
                # the feature-map resolution. Sum over all 2D scales.
                x3d_scale = self.projects[str(scale_2d)](
                    x_rgb["1_" + str(scale_2d)][i],
                    projected_pix // scale_2d,
                    fov_mask,
                )
                x3d = x3d_scale if x3d is None else x3d + x3d_scale
            x3ds.append(x3d)

        input_dict = {"x3d": torch.stack(x3ds)}
        out_dict = self.net_3d_decoder(input_dict)
        ssc_pred = out_dict["ssc_logit"]

        # This demo variant returns hard class predictions on the CPU
        # instead of the raw logits.
        y_pred = ssc_pred.detach().cpu().numpy()
        y_pred = np.argmax(y_pred, axis=1)
        return y_pred
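

# ---------------------------------------------------------------------------
# Illustrative usage sketch: instantiating the model with NYU-like settings.
# The values below are assumptions based on the public MonoScene NYU demo
# configuration, not a canonical config shipped with this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = MonoScene(
        n_classes=12,                  # 11 NYU semantic classes + empty space
        feature=200,                   # channel width of the lifted 2D features
        project_scale=1,               # 3D scale the features are projected to
        full_scene_size=(60, 36, 60),  # NYU output voxel grid
        dataset="NYU",
    )
    print(model)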