File size: 3,317 Bytes
ec378c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# https://gist.github.com/xmodar/ae2d94681a6fda39f3c4f3ac91eef7b7
# %%
import torch


def sinusoidal(positions, features=16, periods=10000):
    """Sinusoidal positional encoding of `positions`.

    Args:
        positions: tensor of positions
        features: half the number of features per position
        periods: base of the geometric progression of frequencies

    Returns:
        Tensor of shape `(*positions.shape, features, 2)` holding the sine
        and cosine components stacked along the last dimension.
    """
    # Keep a floating dtype; integer positions fall back to the default dtype.
    out_dtype = positions.dtype if positions.is_floating_point() else None
    # Frequencies form a geometric progression from periods**0 down to
    # periods**(1/features - 1).
    freqs = torch.logspace(
        0, 1 / features - 1, features, periods,
        device=positions.device, dtype=out_dtype,
    )
    angles = positions.unsqueeze(-1) * freqs
    return torch.stack((angles.sin(), angles.cos()), dim=-1)


def point_pe(points, low=0, high=1, steps=100, features=16, periods=10000):
    """Sinusoidally encode points that live in a bounded box.

    Args:
        points: tensor of points; typically of shape (*, C)
        low: lower bound of the space; typically of shape (C,)
        high: upper bound of the space; typically of shape (C,)
        steps: number of cells that split the space; typically of shape (C,)
        features: half the number of features per position
        periods: base of the geometric progression of frequencies

    Returns:
        Positional encoded points of the following shape:
        `(*points.shape[:-1], points.shape[-1] * features * 2)`
    """
    # Map each coordinate from [low, high] onto [0, steps] per axis.
    scale = steps / (high - low)
    positions = (points - low) * scale
    # (*, C, features, 2) -> (*, C * features * 2)
    return sinusoidal(positions, features, periods).flatten(-3)


def point_position_encoding(points, max_steps=100, features=16, periods=10000):
    """Positionally encode a point cloud using its own bounding box.

    The grid resolution per axis is proportional to the cloud's extent
    along that axis; the largest axis gets `max_steps` cells.
    """
    low, high = points.min(0).values, points.max(0).values
    extent = high - low
    steps = extent * (max_steps / extent.max())
    return point_pe(points, low, high, steps, features, periods)


def test(num_points=1000, max_steps=100, features=32, periods=10000):
    """Smoke-test `point_pe` on a random point cloud in the unit cube."""
    cloud = torch.rand(num_points, 3)
    lo, hi = cloud.min(0).values, cloud.max(0).values
    span = hi - lo
    span *= max_steps / span.max()
    return point_pe(cloud, lo, hi, span, features=features, periods=periods)


# %%
if __name__ == "__main__":
    # Visualize the encoding of a small random cloud as an image.
    import matplotlib.pyplot as plt

    pe = test(20, 1000, periods=10000)
    fig = plt.figure(figsize=(10, 10))
    plt.imshow(pe)
# %%


def pe_2d(num_points=14, max_steps=100, features=32, periods=10000):
    """Positional encoding of a regular 2D grid over the unit square.

    Args:
        num_points: number of grid points per axis
        max_steps: number of cells along the largest axis
        features: half the number of features per coordinate
        periods: base of the geometric progression of frequencies

    Returns:
        Tensor of shape `(2 * features * 2, num_points, num_points)` with
        the encoding channels first (one image per channel).
    """
    x = torch.linspace(0, 1, num_points)
    y = torch.linspace(0, 1, num_points)
    # indexing="ij" is the historical default; passing it explicitly
    # silences the deprecation warning torch.meshgrid emits without it.
    points = torch.stack(torch.meshgrid(x, y, indexing="ij"), dim=-1).reshape(-1, 2)
    low = points.min(0).values
    high = points.max(0).values
    steps = high - low
    steps *= max_steps / steps.max()
    pe = point_pe(points, low, high, steps, features=features, periods=periods)
    # (num_points**2, C) -> (C, num_points, num_points), channels first
    pe = pe.reshape(num_points, num_points, -1)
    pe = pe.permute(2, 0, 1)
    return pe


# %%
if __name__ == "__main__":
    # Show one channel of the encoding of a tiny 3x3 grid.
    import matplotlib.pyplot as plt

    pe = pe_2d(3, max_steps=1000, periods=10000, features=32)
    fig = plt.figure(figsize=(10, 10))
    plt.imshow(pe[64, :, :])
# %%