from typing import Optional, Union from torch import nn from mmdet3d.models import Base3DSegmentor from mmdet3d.registry import MODELS from mmdet3d.structures.det3d_data_sample import SampleList @MODELS.register_module() class TPVFormer(Base3DSegmentor): def __init__(self, data_preprocessor: Optional[Union[dict, nn.Module]] = None, backbone=None, neck=None, encoder=None, decode_head=None): super().__init__(data_preprocessor=data_preprocessor) self.backbone = MODELS.build(backbone) if neck is not None: self.neck = MODELS.build(neck) self.encoder = MODELS.build(encoder) self.decode_head = MODELS.build(decode_head) def extract_feat(self, img): """Extract features of images.""" B, N, C, H, W = img.size() img = img.view(B * N, C, H, W) img_feats = self.backbone(img) if hasattr(self, 'neck'): img_feats = self.neck(img_feats) img_feats_reshaped = [] for img_feat in img_feats: _, C, H, W = img_feat.size() img_feats_reshaped.append(img_feat.view(B, N, C, H, W)) return img_feats_reshaped def _forward(self, batch_inputs, batch_data_samples): """Forward training function.""" img_feats = self.extract_feat(batch_inputs['imgs']) outs = self.encoder(img_feats, batch_data_samples) outs = self.decode_head(outs, batch_inputs['voxels']['coors']) return outs def loss(self, batch_inputs: dict, batch_data_samples: SampleList) -> SampleList: img_feats = self.extract_feat(batch_inputs['imgs']) queries = self.encoder(img_feats, batch_data_samples) losses = self.decode_head.loss(queries, batch_data_samples) return losses def predict(self, batch_inputs: dict, batch_data_samples: SampleList) -> SampleList: """Forward predict function.""" img_feats = self.extract_feat(batch_inputs['imgs']) tpv_queries = self.encoder(img_feats, batch_data_samples) seg_logits = self.decode_head.predict(tpv_queries, batch_data_samples) seg_preds = [seg_logit.argmax(dim=1) for seg_logit in seg_logits] return self.postprocess_result(seg_preds, batch_data_samples) def aug_test(self, batch_inputs, batch_data_samples): pass def encode_decode(self, batch_inputs: dict, batch_data_samples: SampleList) -> SampleList: pass