giantmonkeyTC
mm2
c2ca15f
raw
history blame
2.57 kB
from typing import Optional, Union
from torch import nn
from mmdet3d.models import Base3DSegmentor
from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import SampleList
@MODELS.register_module()
class TPVFormer(Base3DSegmentor):
def __init__(self,
data_preprocessor: Optional[Union[dict, nn.Module]] = None,
backbone=None,
neck=None,
encoder=None,
decode_head=None):
super().__init__(data_preprocessor=data_preprocessor)
self.backbone = MODELS.build(backbone)
if neck is not None:
self.neck = MODELS.build(neck)
self.encoder = MODELS.build(encoder)
self.decode_head = MODELS.build(decode_head)
def extract_feat(self, img):
"""Extract features of images."""
B, N, C, H, W = img.size()
img = img.view(B * N, C, H, W)
img_feats = self.backbone(img)
if hasattr(self, 'neck'):
img_feats = self.neck(img_feats)
img_feats_reshaped = []
for img_feat in img_feats:
_, C, H, W = img_feat.size()
img_feats_reshaped.append(img_feat.view(B, N, C, H, W))
return img_feats_reshaped
def _forward(self, batch_inputs, batch_data_samples):
"""Forward training function."""
img_feats = self.extract_feat(batch_inputs['imgs'])
outs = self.encoder(img_feats, batch_data_samples)
outs = self.decode_head(outs, batch_inputs['voxels']['coors'])
return outs
def loss(self, batch_inputs: dict,
batch_data_samples: SampleList) -> SampleList:
img_feats = self.extract_feat(batch_inputs['imgs'])
queries = self.encoder(img_feats, batch_data_samples)
losses = self.decode_head.loss(queries, batch_data_samples)
return losses
def predict(self, batch_inputs: dict,
batch_data_samples: SampleList) -> SampleList:
"""Forward predict function."""
img_feats = self.extract_feat(batch_inputs['imgs'])
tpv_queries = self.encoder(img_feats, batch_data_samples)
seg_logits = self.decode_head.predict(tpv_queries, batch_data_samples)
seg_preds = [seg_logit.argmax(dim=1) for seg_logit in seg_logits]
return self.postprocess_result(seg_preds, batch_data_samples)
def aug_test(self, batch_inputs, batch_data_samples):
pass
def encode_decode(self, batch_inputs: dict,
batch_data_samples: SampleList) -> SampleList:
pass