|
import torch |
|
import torch.nn as nn |
|
|
|
class MeanVFE(nn.Module):
    """Voxel feature encoder that averages the points stored in each voxel.

    The module holds no learnable parameters; ``forward`` reduces the padded
    per-voxel point tensor to one feature vector per voxel.
    """

    def __init__(self):
        super().__init__()

    def forward(self, batch_dict, **kwargs):
        """Compute the mean point feature of every voxel.

        Args:
            batch_dict: dict containing at least
                voxels: (num_voxels, max_points_per_voxel, C) tensor of
                    per-voxel point features.
                voxel_num_points: (num_voxels,) tensor with the number of
                    valid points in each voxel. Required: it is read
                    unconditionally, so a missing key raises ``KeyError``.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            The same ``batch_dict``, mutated in place with a new
            ``'voxel_features'`` entry of shape (num_voxels, C) whose dtype
            matches ``voxels``.
        """
        voxel_features = batch_dict['voxels']
        voxel_num_points = batch_dict['voxel_num_points']

        # Sum over the per-voxel point slots (dim 1). NOTE(review): this
        # assumes slots beyond voxel_num_points are zero-padded by the voxel
        # generator -- confirm; otherwise padding values leak into the mean.
        points_sum = voxel_features.sum(dim=1)

        # Cast the counts to the feature dtype *before* clamping, so the clamp
        # never mixes an integer tensor with a float scalar. Clamping to 1
        # prevents division by zero for empty voxels. reshape (unlike view)
        # also accepts non-contiguous count tensors.
        normalizer = (
            voxel_num_points.reshape(-1, 1)
            .type_as(voxel_features)
            .clamp_min(1.0)
        )

        batch_dict['voxel_features'] = (points_sum / normalizer).contiguous()
        return batch_dict
|
|