import torch import torch.nn as nn from mmengine.model import BaseModule from mmdet3d.registry import MODELS @MODELS.register_module() class TPVFormerPositionalEncoding(BaseModule): def __init__(self, num_feats, h, w, z, init_cfg=dict(type='Uniform', layer='Embedding')): super().__init__(init_cfg) if not isinstance(num_feats, list): num_feats = [num_feats] * 3 self.h_embed = nn.Embedding(h, num_feats[0]) self.w_embed = nn.Embedding(w, num_feats[1]) self.z_embed = nn.Embedding(z, num_feats[2]) self.num_feats = num_feats self.h, self.w, self.z = h, w, z def forward(self, bs, device, ignore_axis='z'): if ignore_axis == 'h': h_embed = torch.zeros( 1, 1, self.num_feats[0], device=device).repeat(self.w, self.z, 1) # w, z, d w_embed = self.w_embed(torch.arange(self.w, device=device)) w_embed = w_embed.reshape(self.w, 1, -1).repeat(1, self.z, 1) z_embed = self.z_embed(torch.arange(self.z, device=device)) z_embed = z_embed.reshape(1, self.z, -1).repeat(self.w, 1, 1) elif ignore_axis == 'w': h_embed = self.h_embed(torch.arange(self.h, device=device)) h_embed = h_embed.reshape(1, self.h, -1).repeat(self.z, 1, 1) w_embed = torch.zeros( 1, 1, self.num_feats[1], device=device).repeat(self.z, self.h, 1) z_embed = self.z_embed(torch.arange(self.z, device=device)) z_embed = z_embed.reshape(self.z, 1, -1).repeat(1, self.h, 1) elif ignore_axis == 'z': h_embed = self.h_embed(torch.arange(self.h, device=device)) h_embed = h_embed.reshape(self.h, 1, -1).repeat(1, self.w, 1) w_embed = self.w_embed(torch.arange(self.w, device=device)) w_embed = w_embed.reshape(1, self.w, -1).repeat(self.h, 1, 1) z_embed = torch.zeros( 1, 1, self.num_feats[2], device=device).repeat(self.h, self.w, 1) pos = torch.cat((h_embed, w_embed, z_embed), dim=-1).flatten(0, 1).unsqueeze(0).repeat(bs, 1, 1) return pos