# nsd_model / point_pe.py
# huzey's picture
# Upload folder using huggingface_hub
# ec378c3 verified
# https://gist.github.com/xmodar/ae2d94681a6fda39f3c4f3ac91eef7b7
# %%
import torch
def sinusoidal(positions, features=16, periods=10000):
    """Compute sine/cosine positional features for a tensor of positions.

    Every position is multiplied by `features` geometrically spaced
    frequencies, `periods ** (-k / features)` for `k = 0 .. features - 1`,
    and the sine and cosine of each product are returned.

    Args:
        positions: tensor of positions (any shape)
        features: number of frequencies, i.e. half the number of features
            produced per position
        periods: base of the geometric frequency progression

    Returns:
        Tensor of shape `(*positions.shape, features, 2)` whose last axis
        holds the `(sin, cos)` pair of each frequency.
    """
    # Non-floating positions have no float dtype to inherit, so `None` lets
    # torch fall back to its default floating dtype for the frequencies.
    if positions.is_floating_point():
        float_dtype = positions.dtype
    else:
        float_dtype = None
    frequencies = torch.logspace(
        start=0,
        end=1 / features - 1,
        steps=features,
        base=periods,
        device=positions.device,
        dtype=float_dtype,
    )
    # A trailing axis on `positions` broadcasts against the frequency axis.
    angles = positions[..., None] * frequencies
    return torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)
def point_pe(points, low=0, high=1, steps=100, features=16, periods=10000):
    """Encode points in bounded space using sinusoidal positional encoding.

    Each coordinate is first mapped to a cell position
    `(points - low) * steps / (high - low)` and the result is passed to
    `sinusoidal`.

    Args:
        points: tensor of points; typically of shape (*, C)
        low: lower bound of the space; typically of shape (C,)
        high: upper bound of the space; typically of shape (C,)
        steps: number of cells that split the space; typically of shape (C,)
        features: half the number of features per position
        periods: used frequencies for the sinusoidal functions

    Returns:
        Positional encoded points of the following shape:
        `(*points.shape[:-1], points.shape[-1] * features * 2)`

    Note:
        The scale divides by `high - low`, so a dimension where
        `high == low` is a division by zero.
    """
    # The subtraction allocates a new tensor, so `points` is never modified.
    offsets = points - low
    scale = steps / (high - low)
    if offsets.is_floating_point():
        # Scaling in place keeps the dtype of `offsets` and avoids a copy.
        positions = offsets.mul_(scale)
    else:
        # An in-place multiply cannot store the float result in an integer
        # tensor (torch raises), so integer points are scaled out of place.
        positions = offsets * scale
    # Merge the trailing (C, features, 2) axes into one feature axis.
    return sinusoidal(positions, features, periods).flatten(-3)
def point_position_encoding(points, max_steps=100, features=16, periods=10000):
    """Encode points relative to their own axis-aligned bounding box.

    The bounds are the per-axis minimum and maximum over dimension 0 of
    `points`. The axis with the largest extent is split into `max_steps`
    cells and every other axis gets a proportional number of cells, so all
    axes share one cell size. The points are then encoded with `point_pe`.

    Args:
        points: tensor of points; reduced over dimension 0 to find the bounds
        max_steps: number of cells along the axis with the largest extent
        features: half the number of features per position
        periods: used frequencies for the sinusoidal functions

    Returns:
        The positional encoding returned by `point_pe`. An axis on which
        all points share the same coordinate is encoded at position 0
        instead of producing NaN.
    """
    low = points.min(0).values
    high = points.max(0).values
    extent = high - low

    # If every point is identical the largest extent is 0; substitute 1 so
    # the scale below is not `max_steps / 0`. All steps are 0 in that case.
    longest = extent.max()
    longest = torch.where(longest > 0, longest, torch.ones_like(longest))
    steps = extent * (max_steps / longest)

    # `point_pe` divides by `high - low`, which is 0 on a flat axis and
    # would give 0 / 0. Such an axis already has `steps == 0`, so swapping
    # its bounds for [0, 1] makes its scale 0 and every position on it 0.
    flat = extent == 0
    low = torch.where(flat, torch.zeros_like(low), low)
    high = torch.where(flat, torch.ones_like(high), high)

    return point_pe(points, low, high, steps, features, periods)
def test(num_points=1000, max_steps=100, features=32, periods=10000):
    """Encode a random 3-D point cloud with `point_pe`.

    Draws `num_points` points with `torch.rand`, uses their per-axis
    minimum and maximum as bounds, gives the widest axis `max_steps` cells
    (other axes proportionally fewer) and returns the resulting encoding.
    """
    cloud = torch.rand(num_points, 3)
    lower, upper = cloud.min(0).values, cloud.max(0).values
    extent = upper - lower
    # One common cell size: the widest axis is split into `max_steps` cells.
    cells = extent * (max_steps / extent.max())
    return point_pe(cloud, lower, upper, cells, features=features, periods=periods)
# %%
if __name__ == "__main__":
    # Visual check: show the encoding of a small random point cloud as an
    # image (one row per point, one column per feature).
    pe = test(num_points=20, max_steps=1000, periods=10000)
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(10, 10))
    plt.imshow(pe)
# %%
def pe_2d(num_points=14, max_steps=100, features=32, periods=10000):
    """Build a positional-encoding map for a regular 2-D grid on [0, 1]^2.

    Args:
        num_points: number of grid points along each axis
        max_steps: number of cells along the axis with the largest extent
        features: half the number of features per position
        periods: used frequencies for the sinusoidal functions

    Returns:
        Tensor of shape `(F, num_points, num_points)`, where `F` is the
        per-point feature size returned by `point_pe` for 2-D points.
        Entry `[:, i, j]` is the encoding of the point `(x[i], y[j])`.
    """
    x = torch.linspace(0, 1, num_points)
    y = torch.linspace(0, 1, num_points)
    # "ij" is the layout the legacy default produced; passing it explicitly
    # avoids the deprecation warning emitted when `indexing` is omitted.
    grid = torch.meshgrid(x, y, indexing="ij")
    points = torch.stack(grid, dim=-1).reshape(-1, 2)

    low = points.min(0).values
    high = points.max(0).values
    # Scale so the widest axis has `max_steps` cells and others match its cell size.
    steps = high - low
    steps *= max_steps / steps.max()

    pe = point_pe(points, low, high, steps, features=features, periods=periods)
    # Restore the grid layout, then move the feature axis to the front.
    pe = pe.reshape(num_points, num_points, -1)
    return pe.permute(2, 0, 1)
# %%
if __name__ == "__main__":
    # Visual check: show a single feature channel of a tiny 2-D encoding map.
    pe = pe_2d(num_points=3, max_steps=1000, features=32, periods=10000)
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(10, 10))
    plt.imshow(pe[64])
# %%