|
|
|
|
|
import torch |
|
|
|
|
|
def sinusoidal(positions, features=16, periods=10000):
    """Apply a sinusoidal positional encoding to every entry of `positions`.

    Each position ``p`` is paired with ``features`` angular frequencies
    ``omega_k = periods ** (-k / features)`` for ``k = 0 .. features - 1``
    (from 1 down to ``periods ** (1 / features - 1)``), and encoded as the
    pair ``(sin(p * omega_k), cos(p * omega_k))``.

    Args:
        positions: tensor of positions (any shape; integer or floating point)
        features: number of frequencies, i.e. half the number of output
            values per position
        periods: base of the geometric frequency progression

    Returns:
        Tensor of shape ``(*positions.shape, features, 2)`` whose last axis
        holds ``(sin, cos)``. Floating-point inputs keep their dtype; integer
        inputs produce the default floating dtype.
    """
    # Integer positions fall back to torch's default float dtype (dtype=None).
    float_dtype = positions.dtype if positions.is_floating_point() else None

    # Exponents run linearly from 0 to 1/features - 1 in `features` steps,
    # i.e. -k / features, giving a geometric progression of frequencies.
    last_exponent = 1 / features - 1
    frequencies = torch.logspace(
        start=0,
        end=last_exponent,
        steps=features,
        base=periods,
        device=positions.device,
        dtype=float_dtype,
    )

    # Broadcast: (*shape, 1) * (features,) -> (*shape, features).
    angles = positions.unsqueeze(-1) * frequencies
    return torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)
|
|
|
|
|
def point_pe(points, low=0, high=1, steps=100, features=16, periods=10000):
    """Encode points in bounded space using sinusoidal positional encoding.

    Each coordinate is mapped linearly from ``[low, high]`` to ``[0, steps]``
    and the resulting positions are encoded with ``sinusoidal``.

    Args:
        points: tensor of points; typically of shape ``(*, C)``; integer or
            floating point
        low: lower bound of the space; scalar or tensor of shape ``(C,)``
        high: upper bound of the space; scalar or tensor of shape ``(C,)``
        steps: number of cells that split the space; scalar or tensor of
            shape ``(C,)``
        features: half the number of features per coordinate
        periods: base of the geometric frequency progression used by
            ``sinusoidal``

    Returns:
        Positional encoded points of the following shape:
        ``(*points.shape[:-1], points.shape[-1] * features * 2)``.
        When the bounds are tensors, a dimension with ``high == low`` is
        encoded at position 0 instead of producing NaN.

    Raises:
        ZeroDivisionError: if ``steps``, ``high`` and ``low`` are plain Python
            numbers with ``high == low``.
    """
    span = high - low
    scale = steps / span
    if torch.is_tensor(scale):
        # A zero-width dimension yields 0/0 (NaN) or x/0 (inf); every point in
        # such a dimension lies at `low`, so encode it at position 0 instead.
        degenerate = torch.as_tensor(span == 0, device=scale.device)
        scale = torch.where(degenerate, torch.zeros_like(scale), scale)
    # Out-of-place multiply so integer `points` are promoted to floating
    # point; an in-place `mul_` cannot store a float result in an int tensor.
    positions = (points - low) * scale
    return sinusoidal(positions, features, periods).flatten(-3)
|
|
|
|
|
def point_position_encoding(points, max_steps=100, features=16, periods=10000):
    """Encode a point cloud relative to its own axis-aligned bounding box.

    The per-axis minimum and maximum over the first dimension define the
    bounds. The longest axis is split into ``max_steps`` cells and the other
    axes use the same cell size, so the aspect ratio is preserved.

    Args:
        points: point cloud, expected shape ``(N, C)``
        max_steps: number of cells along the longest axis
        features: half the number of features per coordinate
        periods: base of the geometric frequency progression

    Returns:
        Result of ``point_pe`` for these bounds, shape
        ``(N, C * features * 2)``. Axes along which all points coincide are
        encoded at position 0; a cloud of a single repeated point no longer
        yields NaN.
    """
    low = points.min(0).values
    high = points.max(0).values
    extent = high - low
    longest = extent.max()
    # Out-of-place so integer clouds work; skip scaling when every axis has
    # zero extent (max_steps / 0 would give inf/NaN).
    steps = extent * (max_steps / longest) if longest > 0 else extent
    # A zero-width axis would reach point_pe as 0/0. Widen its upper bound:
    # steps is 0 there, so every point on that axis maps to position 0.
    high = torch.where(extent > 0, high, low + 1)
    pe = point_pe(points, low, high, steps, features, periods)
    return pe
|
|
|
|
|
def test(num_points=1000, max_steps=100, features=32, periods=10000):
    """Encode a random 3-D point cloud as a smoke test of ``point_pe``.

    Draws ``num_points`` points uniformly from the unit cube with
    ``torch.rand`` and encodes them relative to their own bounding box using
    ``point_position_encoding``. That function shares this exact bounds and
    steps computation, so it is reused rather than duplicated.

    Returns:
        Tensor of shape ``(num_points, 3 * features * 2)``.
    """
    point_cloud = torch.rand(num_points, 3)
    return point_position_encoding(point_cloud, max_steps, features, periods)
|
|
|
|
|
|
|
if __name__ == "__main__":
    pe = test(20, 1000, periods=10000)

    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(10, 10))
    # One row per point, one column per feature (20 x 192 here). "auto"
    # stretches the matrix to fill the figure instead of a thin strip.
    plt.imshow(pe, aspect="auto")
    # Without show() the figure is never displayed when run as a script.
    plt.show()
|
|
|
|
|
|
|
def pe_2d(num_points=14, max_steps=100, features=32, periods=10000):
    """Positional encoding of a regular grid on the unit square, as images.

    Builds a ``num_points x num_points`` grid of 2-D points spaced evenly on
    ``[0, 1]^2`` and encodes them with ``point_position_encoding``.

    Args:
        num_points: grid resolution along each axis
        max_steps: number of cells along the longest axis
        features: half the number of features per coordinate
        periods: base of the geometric frequency progression

    Returns:
        Tensor of shape ``(2 * features * 2, num_points, num_points)``: one
        image channel per encoding feature, indexed as ``[channel, i, j]``.
    """
    axis = torch.linspace(0, 1, num_points)
    # "ij" is torch's historical default. Passing it explicitly silences the
    # meshgrid deprecation warning and keeps grid[i, j] == (axis[i], axis[j]).
    grid = torch.stack(torch.meshgrid(axis, axis, indexing="ij"), dim=-1)
    points = grid.reshape(-1, 2)

    pe = point_position_encoding(points, max_steps, features, periods)

    # (num_points**2, F) -> (num_points, num_points, F) -> (F, num_points, num_points)
    return pe.reshape(num_points, num_points, -1).permute(2, 0, 1)
|
|
|
|
|
|
|
if __name__ == "__main__":
    pe = pe_2d(3, max_steps=1000, periods=10000, features=32)

    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(10, 10))
    # pe has 2 * 32 * 2 = 128 channels ordered (coordinate, feature, sin/cos).
    # Channel 64 is therefore the sine of the first frequency of the second
    # coordinate, shown as an image over the 3 x 3 grid.
    plt.imshow(pe[64, :, :])
    # Without show() the figure is never displayed when run as a script.
    plt.show()
|
|
|
|