# Spaces: Running on L40S
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch | |
import pytorch_lightning as pl | |
import numpy as np | |
class SpatialEncoder(pl.LightningModule):
    """Encode query points by their depth offset to a set of keypoints.

    For every (point, keypoint) pair, ``forward`` takes the signed offset of
    channel 2 (``dz = cxyz[..., 2] - kptxyz[..., 2]``) and expands it with a
    sin/cos positional encoding. The result is multiplied by a Gaussian weight
    ``exp(-||cxyz - kptxyz||^2 / (2 * sigma^2))``, so keypoints far from the
    point contribute little.

    Args:
        sp_level: number of frequency levels used by the positional encoding.
        sp_type: encoding variant name. Only ``get_dim`` reads it; ``forward``
            always computes the relative-z, distance-decayed encoding.
        scale: frequency multiplier (stored; see NOTE in ``__init__``).
        n_kpt: number of keypoints assumed by ``get_dim``.
        sigma: standard deviation of the Gaussian distance weight.
    """

    def __init__(self,
                 sp_level=1,
                 sp_type="rel_z_decay",
                 scale=1.0,
                 n_kpt=24,
                 sigma=0.2):
        super().__init__()
        self.sp_type = sp_type
        self.sp_level = sp_level
        self.n_kpt = n_kpt
        # NOTE(review): stored but not read by forward(), which calls
        # position_embedding with its default scale of 1.0. Passing it through
        # would change outputs for configs with scale != 1.0; confirm intent.
        self.scale = scale
        self.sigma = sigma

    @staticmethod
    def position_embedding(x, nlevels, scale=1.0):
        """Append a sin/cos positional encoding to the last axis of ``x``.

        Args:
            x: tensor of shape (B, N, C).
            nlevels: number of frequency levels; if <= 0, ``x`` is returned
                unchanged.
            scale: multiplier applied to every frequency.

        Returns:
            Tensor of shape (B, N, C * (1 + 2 * nlevels)), laid out as
            ``[x, sin(f_0 x), cos(f_0 x), sin(f_1 x), cos(f_1 x), ...]``,
            where each group has C channels and ``f_i = scale * pi * 2**i``.
        """
        if nlevels <= 0:
            return x
        freqs = SpatialEncoder.pe_vector(nlevels, x.device, scale)
        B, N, _ = x.shape
        # (B, N, 1, C) * (1, 1, L, 1) -> (B, N, L, C)
        y = x[:, :, None, :] * freqs[None, None, :, None]
        # Per level: [sin (C), cos (C)] -> flattened to (B, N, 2 * L * C).
        z = torch.cat((torch.sin(y), torch.cos(y)), dim=-1).view(B, N, -1)
        return torch.cat([x, z], dim=-1)

    @staticmethod
    def pe_vector(nlevels, device, scale=1.0):
        """Return the float32 frequencies ``scale * pi * 2**i`` for
        ``i in range(nlevels)`` as a 1-D tensor on ``device``.

        Returns an empty tensor when ``nlevels <= 0``.
        """
        # Multiplying by a power of two is exact, so this matches the
        # element-by-element doubling loop bit-for-bit.
        freqs = scale * np.pi * (2.0 ** np.arange(max(nlevels, 0)))
        return torch.from_numpy(np.asarray(freqs, dtype=np.float32)).to(device)

    def get_dim(self):
        """Return the number of output channels implied by ``sp_type``.

        Per encoded scalar the positional encoding yields ``1 + 2 * sp_level``
        channels. "rel" variants produce one set per keypoint; "xyz" variants
        encode three coordinates. Unrecognized ``sp_type`` values give 0.
        """
        per_value = 1 + 2 * self.sp_level
        if self.sp_type in ["z", "rel_z", "rel_z_decay"]:
            if "rel" in self.sp_type:
                return per_value * self.n_kpt
            return per_value
        if "xyz" in self.sp_type:
            if "rel" in self.sp_type:
                return per_value * 3 * self.n_kpt
            return per_value * 3
        return 0

    def forward(self, cxyz, kptxyz):
        """Encode query points against keypoints.

        Args:
            cxyz: query points, shape (B, N, D) with D >= 3; channel 2 is the
                coordinate whose offset is encoded.
            kptxyz: keypoints, shape (B, K, D) with the same D as ``cxyz``.

        Returns:
            Tensor of shape (B, K * (1 + 2 * sp_level), N).
        """
        B, N = cxyz.shape[:2]
        K = kptxyz.shape[1]
        # Signed channel-2 offset of every point to every keypoint: (B, N, K, 1)
        dz = cxyz[:, :, None, 2:3] - kptxyz[:, None, :, 2:3]
        # Full offset vector: (B, N, K, D)
        dxyz = cxyz[:, :, None] - kptxyz[:, None, :]
        # Gaussian decay on squared distance: (B, N, K)
        weight = torch.exp(-(dxyz**2).sum(-1) / (2.0 * (self.sigma**2)))
        # (B, N, K * (1 + 2L)), grouped as [dz, sin_0, cos_0, ...] of K each
        out = self.position_embedding(dz.view(B, N, K), self.sp_level)
        # (B, N, 1 + 2L, K) * (B, N, 1, K) -> (B, N, (1 + 2L) * K)
        # -> permute to channels-first (B, (1 + 2L) * K, N)
        out = (out.view(B, N, -1, K) * weight[:, :, None]).view(B, N, -1).permute(0, 2, 1)
        return out
if __name__ == "__main__":
    # Smoke test: encode random points against random keypoints and print the
    # resulting shape. Falls back to CPU when no CUDA device is available,
    # instead of crashing on the hard-coded "cuda" device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pts = torch.randn(2, 10000, 3).to(device)
    kpts = torch.randn(2, 24, 3).to(device)
    sp_encoder = SpatialEncoder(sp_level=3,
                                sp_type="rel_z_decay",
                                scale=1.0,
                                n_kpt=24,
                                sigma=0.1).to(device)
    out = sp_encoder(pts, kpts)
    # Expected: (B, K * (1 + 2 * sp_level), N) = torch.Size([2, 168, 10000])
    print(out.shape)