Plonk / utils /manifolds.py
nicolas-dufour's picture
squash: merge all unpushed commits
c4c7cee
raw
history blame
1.36 kB
"""Copyright (c) Meta Platforms, Inc. and affiliates."""
import math
import torch
from geoopt.manifolds import Sphere as geoopt_Sphere
class Sphere(geoopt_Sphere):
    """geoopt ``Sphere`` with a closed-form parallel transport and a uniform
    base distribution (sampling + log-density)."""

    def transp(self, x, y, v):
        """Parallel-transport the tangent vector ``v`` from ``x`` to ``y``.

        Computes ``v - inner(y, v) / (1 + inner(x, y)) * (x + y)``, with inner
        products taken via ``self.inner``. When ``1 + inner(x, y)`` is not
        greater than ``1e-3`` (presumably ``x`` and ``y`` nearly antipodal
        unit vectors), the formula is ill-conditioned and ``-v`` is returned
        instead.

        Returns a tensor broadcast from ``v``, ``x`` and ``y``.
        """
        denom = 1 + self.inner(x, x, y, keepdim=True)
        cond = denom.gt(1e-3)
        # Replace the masked-out denominators *before* dividing.
        # torch.where only selects forward values. An inf/nan in the discarded
        # branch would still propagate NaN gradients during backprop.
        safe_denom = torch.where(cond, denom, torch.ones_like(denom))
        res = v - self.inner(x, y, v, keepdim=True) / safe_denom * (x + y)
        return torch.where(cond, res, -v)

    def uniform_logprob(self, x):
        """Log-density of the uniform distribution on the sphere at ``x``.

        With ``d = x.shape[-1]`` this is
        ``lgamma(d/2) - log(2) - (d/2) * log(pi)``. That equals minus the
        log of ``2 * pi**(d/2) / Gamma(d/2)``, the surface area of the unit
        sphere in ``R^d``. The value is constant, broadcast to the shape of
        ``x[..., 0]`` with ``x``'s dtype and device.
        """
        dim = x.shape[-1]
        return torch.full_like(
            x[..., 0],
            math.lgamma(dim / 2) - (math.log(2) + (dim / 2) * math.log(math.pi)),
        )

    def random_base(self, *args, **kwargs):
        """Sample from the base distribution by delegating to ``self.random_uniform``."""
        return self.random_uniform(*args, **kwargs)

    def base_logprob(self, *args, **kwargs):
        """Log-density of the base distribution by delegating to ``self.uniform_logprob``."""
        return self.uniform_logprob(*args, **kwargs)
def geodesic(manifold, start_point, end_point):
    """Build a callable that evaluates the geodesic from ``start_point`` to ``end_point``.

    The curve is ``manifold.expmap(start_point, t * manifold.logmap(start_point, end_point))``.
    The shooting tangent vector is computed once, here, and reused for every
    call of the returned function.

    Parameters
    ----------
    manifold : object
        Must provide ``logmap(x, y)`` and ``expmap(x, u)``. These are assumed
        to broadcast over leading dimensions.
    start_point, end_point : torch.Tensor
        Presumably shape=[..., dim] (the last axis is treated as the
        coordinate axis by the einsum below).

    Returns
    -------
    path : callable
        ``path(t)`` maps times of shape=[n_points,] to points of
        shape=[..., n_points, dim].
    """
    shooting_tangent_vec = manifold.logmap(start_point, end_point)

    def path(t):
        """Generate parameterized function for geodesic curve.

        Parameters
        ----------
        t : array-like, shape=[n_points,]
            Times at which to compute points of the geodesics. Lists, NumPy
            arrays and tensors of any dtype/device are accepted. They are
            converted to the dtype and device of the shooting tangent vector.

        Returns
        -------
        points_at_time_t : torch.Tensor, shape=[..., n_points, dim]
        """
        # torch.einsum only accepts tensors of matching dtype. Honour the
        # "array-like" contract by converting here. This is a no-op (no copy)
        # when t already matches.
        t = torch.as_tensor(
            t,
            dtype=shooting_tangent_vec.dtype,
            device=shooting_tangent_vec.device,
        )
        # Scale the shooting vector by each time: [n], [..., k] -> [..., n, k].
        tangent_vecs = torch.einsum("i,...k->...ik", t, shooting_tangent_vec)
        # Insert the n_points axis so the base point broadcasts against tangent_vecs.
        points_at_time_t = manifold.expmap(start_point.unsqueeze(-2), tangent_vecs)
        return points_at_time_t

    return path