File size: 2,304 Bytes
c2ca15f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import torch
import torch.nn as nn
from mmengine.model import BaseModule
from mmdet3d.registry import MODELS
@MODELS.register_module()
class TPVFormerPositionalEncoding(BaseModule):
    """Learned positional encoding for one plane of an (h, w, z) grid.

    Three embedding tables are held, one per axis. ``forward`` builds the
    encoding for the plane spanned by the two axes that are *not* ignored;
    the ignored axis contributes an all-zero slice so the channel layout
    ``[h | w | z]`` is the same for every plane.

    Args:
        num_feats (int | list[int] | tuple[int]): Embedding width. A single
            value is used for all three axes; a sequence gives the widths
            for the h, w and z axes in that order.
        h (int): Number of entries in the h-axis embedding table.
        w (int): Number of entries in the w-axis embedding table.
        z (int): Number of entries in the z-axis embedding table.
        init_cfg (dict): Initialization config forwarded to ``BaseModule``.
    """

    def __init__(self,
                 num_feats,
                 h,
                 w,
                 z,
                 init_cfg=dict(type='Uniform', layer='Embedding')):
        super().__init__(init_cfg)
        # Accept tuples as well as lists for per-axis widths; a scalar is
        # broadcast to all three axes.
        if isinstance(num_feats, (list, tuple)):
            num_feats = list(num_feats)
        else:
            num_feats = [num_feats] * 3
        self.h_embed = nn.Embedding(h, num_feats[0])
        self.w_embed = nn.Embedding(w, num_feats[1])
        self.z_embed = nn.Embedding(z, num_feats[2])
        self.num_feats = num_feats
        self.h, self.w, self.z = h, w, z

    def forward(self, bs, device, ignore_axis='z'):
        """Build the positional encoding for one plane.

        Args:
            bs (int): Number of times the encoding is repeated along the
                leading dimension.
            device: Device on which index and zero tensors are created.
            ignore_axis (str): Axis left out of the plane, one of ``'h'``,
                ``'w'`` or ``'z'``. Defaults to ``'z'``.

        Returns:
            torch.Tensor: Tensor of shape ``(bs, A * B, sum(num_feats))``
            where ``(A, B)`` is ``(w, z)`` for ``'h'``, ``(z, h)`` for
            ``'w'`` and ``(h, w)`` for ``'z'``.

        Raises:
            ValueError: If ``ignore_axis`` is not ``'h'``, ``'w'`` or ``'z'``.
        """
        # Fail early with a clear message instead of hitting an unbound
        # local variable at the concatenation below.
        if ignore_axis not in ('h', 'w', 'z'):
            raise ValueError(
                f"ignore_axis must be 'h', 'w' or 'z', but got "
                f'{ignore_axis!r}')

        if ignore_axis == 'h':
            # Plane laid out as (w, z); the h slice is zero-filled.
            h_embed = torch.zeros(
                1, 1, self.num_feats[0],
                device=device).repeat(self.w, self.z, 1)  # w, z, d
            w_embed = self.w_embed(torch.arange(self.w, device=device))
            w_embed = w_embed.reshape(self.w, 1, -1).repeat(1, self.z, 1)
            z_embed = self.z_embed(torch.arange(self.z, device=device))
            z_embed = z_embed.reshape(1, self.z, -1).repeat(self.w, 1, 1)
        elif ignore_axis == 'w':
            # Plane laid out as (z, h); the w slice is zero-filled.
            h_embed = self.h_embed(torch.arange(self.h, device=device))
            h_embed = h_embed.reshape(1, self.h, -1).repeat(self.z, 1, 1)
            w_embed = torch.zeros(
                1, 1, self.num_feats[1],
                device=device).repeat(self.z, self.h, 1)  # z, h, d
            z_embed = self.z_embed(torch.arange(self.z, device=device))
            z_embed = z_embed.reshape(self.z, 1, -1).repeat(1, self.h, 1)
        else:  # ignore_axis == 'z'
            # Plane laid out as (h, w); the z slice is zero-filled.
            h_embed = self.h_embed(torch.arange(self.h, device=device))
            h_embed = h_embed.reshape(self.h, 1, -1).repeat(1, self.w, 1)
            w_embed = self.w_embed(torch.arange(self.w, device=device))
            w_embed = w_embed.reshape(1, self.w, -1).repeat(self.h, 1, 1)
            z_embed = torch.zeros(
                1, 1, self.num_feats[2],
                device=device).repeat(self.h, self.w, 1)  # h, w, d

        # Channels are always concatenated in h, w, z order; the two plane
        # dimensions are then flattened and the result repeated bs times.
        pos = torch.cat((h_embed, w_embed, z_embed),
                        dim=-1).flatten(0, 1).unsqueeze(0).repeat(bs, 1, 1)
        return pos
|