MonoScene / monoscene /unet3d_kitti.py
anhquancao's picture
up
4d85df4
# encoding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
from monoscene.modules import SegmentationHead
from monoscene.CRP3D import CPMegaVoxels
from monoscene.modules import Process, Upsample, Downsample
class UNet3D(nn.Module):
    """3D encoder/decoder with two downsampling stages and additive skips.

    The encoder applies two ``Process`` + ``Downsample`` stages. The decoder
    applies three ``Upsample`` layers, adding the matching encoder feature
    back after each of the first two, and finishes with a
    ``SegmentationHead``. When ``context_prior`` is truthy, a
    ``CPMegaVoxels`` module is applied to the coarsest feature map.
    """

    def __init__(
        self,
        class_num,
        norm_layer,
        full_scene_size,
        feature,
        project_scale,
        context_prior=None,
        bn_momentum=0.1,
    ):
        """Build the encoder, decoder, head and optional context-prior layers.

        Args:
            class_num: Forwarded to ``SegmentationHead`` as its third
                argument (presumably the number of output classes).
            norm_layer: Forwarded unchanged to every ``Process``,
                ``Downsample`` and ``Upsample`` layer.
            full_scene_size: Indexable with at least three entries; each is
                divided by ``project_scale`` to obtain the level-1 size.
            feature: Base width passed to the level-1 layers. Deeper levels
                are built with ``feature * 2`` and ``feature * 4``, and the
                head with ``feature // 2``.
            project_scale: Divisor applied to ``full_scene_size``.
            context_prior: When truthy, a ``CPMegaVoxels`` module is created
                and used in ``forward``.
            bn_momentum: Forwarded to every layer that accepts it.
        """
        super().__init__()
        # Not read anywhere in this class; kept so external code that
        # accesses the attribute keeps working.
        self.business_layer = []
        self.project_scale = project_scale
        self.full_scene_size = full_scene_size
        self.feature = feature

        # Sizes of the three pyramid levels: level 1 is the full scene
        # divided by project_scale, each further level halves every axis.
        size_l1 = tuple(
            int(self.full_scene_size[axis] / project_scale) for axis in range(3)
        )
        size_l2 = tuple(extent // 2 for extent in size_l1)
        size_l3 = tuple(extent // 2 for extent in size_l2)

        # Single source of truth for the dilation rates used by both
        # Process stages and by the segmentation head.
        dilations = [1, 2, 3]

        # NOTE: submodules are created in the same order and under the same
        # attribute names as before so state_dict keys and parameter order
        # are unchanged.
        self.process_l1 = nn.Sequential(
            Process(self.feature, norm_layer, bn_momentum, dilations=dilations),
            Downsample(self.feature, norm_layer, bn_momentum),
        )
        self.process_l2 = nn.Sequential(
            Process(self.feature * 2, norm_layer, bn_momentum, dilations=dilations),
            Downsample(self.feature * 2, norm_layer, bn_momentum),
        )

        self.up_13_l2 = Upsample(
            self.feature * 4, self.feature * 2, norm_layer, bn_momentum
        )
        self.up_12_l1 = Upsample(
            self.feature * 2, self.feature, norm_layer, bn_momentum
        )
        self.up_l1_lfull = Upsample(
            self.feature, self.feature // 2, norm_layer, bn_momentum
        )

        self.ssc_head = SegmentationHead(
            self.feature // 2, self.feature // 2, class_num, dilations
        )

        self.context_prior = context_prior
        if context_prior:
            self.CP_mega_voxels = CPMegaVoxels(
                self.feature * 4, size_l3, bn_momentum=bn_momentum
            )

    def forward(self, input_dict):
        """Run the network on ``input_dict["x3d"]``.

        Args:
            input_dict: Mapping that must contain the key ``"x3d"`` holding
                the level-1 feature volume (presumably shaped
                ``(batch, feature, X, Y, Z)`` -- TODO confirm with caller).

        Returns:
            A dict with the head output under ``"ssc_logit"``. When
            ``context_prior`` is truthy, every entry returned by
            ``CP_mega_voxels`` (including ``"x"``) is copied into it as well.
        """
        res = {}

        # Encoder.
        x3d_l1 = input_dict["x3d"]
        x3d_l2 = self.process_l1(x3d_l1)
        x3d_l3 = self.process_l2(x3d_l2)

        if self.context_prior:
            ret = self.CP_mega_voxels(x3d_l3)
            # The refined coarsest feature replaces the encoder output.
            x3d_l3 = ret["x"]
            # Expose everything the context-prior module returned.
            res.update(ret)

        # Decoder with additive skip connections from the encoder.
        x3d_up_l2 = self.up_13_l2(x3d_l3) + x3d_l2
        x3d_up_l1 = self.up_12_l1(x3d_up_l2) + x3d_l1
        x3d_up_lfull = self.up_l1_lfull(x3d_up_l1)

        res["ssc_logit"] = self.ssc_head(x3d_up_lfull)
        return res