# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import pytorch_lightning as pl
import numpy as np


class SpatialEncoder(pl.LightningModule):
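    """Relative spatial encoding of query points with respect to body keypoints.

    For the "rel_z_decay" variant implemented in forward(), the z-offset from
    each query point to each keypoint is sinusoidally embedded and weighted by
    a Gaussian falloff on the full 3D distance.
    """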

    def __init__(self,
                 sp_level=1,
                 sp_type="rel_z_decay",
                 scale=1.0,
                 n_kpt=24,
                 sigma=0.2):

        super().__init__()

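        # sp_level: number of sinusoidal frequency octaves in the embedding
        # sp_type:  encoding variant; forward() implements the "rel_z_decay" path
        # n_kpt:    number of keypoints K (24 is consistent with an SMPL-style skeleton)
        # sigma:    std-dev of the Gaussian distance decay used in forward()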
        self.sp_type = sp_type
        self.sp_level = sp_level
        self.n_kpt = n_kpt
        self.scale = scale
        self.sigma = sigma

    @staticmethod
    def position_embedding(x, nlevels, scale=1.0):
        """
        args:
            x: (B, N, C)
        return:
            (B, N, C * n_levels * 2)
        """
        if nlevels <= 0:
            return x
        vec = SpatialEncoder.pe_vector(nlevels, x.device, scale)

        B, N, _ = x.shape
        y = x[:, :, None, :] * vec[None, None, :, None]
        z = torch.cat((torch.sin(y), torch.cos(y)), dim=-1).view(B, N, -1)

        return torch.cat([x, z], -1)

    @staticmethod
    def pe_vector(nlevels, device, scale=1.0):
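        """Frequency ladder [scale*pi*2^0, ..., scale*pi*2^(nlevels-1)] as a float32 tensor."""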
        v, val = [], 1
        for _ in range(nlevels):
            v.append(scale * np.pi * val)
            val *= 2
        return torch.from_numpy(np.asarray(v, dtype=np.float32)).to(device)

    def get_dim(self):
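        # Number of output channels produced by forward() for each sp_type variant.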
        if self.sp_type in ["z", "rel_z", "rel_z_decay"]:
            if "rel" in self.sp_type:
                return (1 + 2 * self.sp_level) * self.n_kpt
            else:
                return 1 + 2 * self.sp_level
        elif "xyz" in self.sp_type:
            if "rel" in self.sp_type:
                return (1 + 2 * self.sp_level) * 3 * self.n_kpt
            else:
                return (1 + 2 * self.sp_level) * 3

        return 0

    def forward(self, cxyz, kptxyz):
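        """
        args:
            cxyz:   (B, N, 3) query point coordinates
            kptxyz: (B, K, 3) keypoint coordinates
        return:
            (B, K * (2 * sp_level + 1), N) distance-weighted positional features
        """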

        B, N = cxyz.shape[:2]
        K = kptxyz.shape[1]

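        # per-point offsets to each keypoint: dz is (B, N, K, 1), dxyz is (B, N, K, 3)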
        dz = cxyz[:, :, None, 2:3] - kptxyz[:, None, :, 2:3]
        dxyz = cxyz[:, :, None] - kptxyz[:, None, :]
        
        # Gaussian falloff on squared 3D distance -> (B, N, K)
        weight = torch.exp(-(dxyz**2).sum(-1) / (2.0 * (self.sigma**2)))

        # sinusoidal embedding of the z-offsets: (B, N, K * (2*sp_level + 1))
        out = self.position_embedding(dz.view(B, N, K), self.sp_level)
        
        # (B, N, 2*sp_level+1, K) * (B, N, 1, K) -> (B, N, K*(2*sp_level+1)) -> (B, K*(2*sp_level+1), N)
        out = (out.view(B, N, -1, K) * weight[:, :, None]).view(B, N, -1).permute(0, 2, 1)

        return out


if __name__ == "__main__":
    # use the GPU when available so the example also runs on CPU-only machines
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pts = torch.randn(2, 10000, 3, device=device)
    kpts = torch.randn(2, 24, 3, device=device)

    sp_encoder = SpatialEncoder(sp_level=3,
                                sp_type="rel_z_decay",
                                scale=1.0,
                                n_kpt=24,
                                sigma=0.1).to(device)
    out = sp_encoder(pts, kpts)
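    # expected: (B, K * (2*sp_level + 1), N) = (2, 168, 10000)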
    print(out.shape)