Spaces:
Running
Running
"""Copyright (c) Meta Platforms, Inc. and affiliates.""" | |
import math | |
import torch | |
from geoopt.manifolds import Sphere as geoopt_Sphere | |
class Sphere(geoopt_Sphere):
    """Unit sphere manifold built on ``geoopt.manifolds.Sphere``.

    Adds a closed-form parallel transport and a uniform base distribution
    (sampling via ``random_base`` and log-density via ``base_logprob``).
    """

    def transp(self, x, y, v):
        """Parallel-transport tangent vector ``v`` from ``x`` to ``y``.

        Uses the closed form for the unit sphere::

            P(v) = v - <y, v> / (1 + <x, y>) * (x + y)

        This is undefined when ``x`` and ``y`` are antipodal
        (``1 + <x, y> == 0``). Wherever ``1 + <x, y> <= 1e-3`` the result
        falls back to ``-v``.

        Parameters
        ----------
        x, y : torch.Tensor
            Source and target points. Presumably unit-norm along the last
            dimension; this is not checked here.
        v : torch.Tensor
            Tangent vector at ``x``.

        Returns
        -------
        torch.Tensor
            The transported vector (broadcast by ``torch.where`` over the
            inputs' leading dimensions).
        """
        denom = 1 + self.inner(x, x, y, keepdim=True)
        cond = denom.gt(1e-3)
        # torch.where back-propagates through the unselected branch too.
        # Dividing by an exactly-zero denom there would produce NaN
        # gradients even though the forward value is replaced by -v.
        # Substitute a harmless denominator in the fallback region.
        safe_denom = torch.where(cond, denom, torch.ones_like(denom))
        res = v - self.inner(x, y, v, keepdim=True) / safe_denom * (x + y)
        return torch.where(cond, res, -v)

    def uniform_logprob(self, x):
        """Log-density of the uniform distribution on the unit sphere.

        For points in R^d (``d = x.shape[-1]``), the sphere S^{d-1} has
        surface area ``2 * pi^(d/2) / Gamma(d/2)``. The uniform density is
        the reciprocal of that area, so the log-density is the constant
        ``lgamma(d/2) - log(2) - (d/2) * log(pi)``.

        Returns
        -------
        torch.Tensor
            Shaped like ``x[..., 0]`` (same dtype/device as ``x``) and filled
            with that constant.
        """
        dim = x.shape[-1]
        return torch.full_like(
            x[..., 0],
            math.lgamma(dim / 2) - (math.log(2) + (dim / 2) * math.log(math.pi)),
        )

    def random_base(self, *args, **kwargs):
        """Sample from the base distribution.

        Delegates to geoopt's ``random_uniform``; all arguments are forwarded
        unchanged.
        """
        return self.random_uniform(*args, **kwargs)

    def base_logprob(self, *args, **kwargs):
        """Log-density of the base distribution.

        Delegates to ``uniform_logprob``; all arguments are forwarded
        unchanged.
        """
        return self.uniform_logprob(*args, **kwargs)
def geodesic(manifold, start_point, end_point):
    """Build the geodesic from ``start_point`` to ``end_point`` as a function of time.

    The returned curve is
    ``path(t) = manifold.expmap(start_point, t * manifold.logmap(start_point, end_point))``.
    For a manifold whose ``expmap`` inverts ``logmap``, ``t = 0`` gives
    ``start_point`` and ``t = 1`` gives ``end_point``.

    Parameters
    ----------
    manifold : object
        Any object providing ``logmap(x, y)`` and ``expmap(x, u)``
        (e.g. a geoopt manifold).
    start_point, end_point : torch.Tensor
        Points whose last dimension holds the coordinates.

    Returns
    -------
    callable
        ``path(t)`` evaluating the geodesic at the given times.
    """
    shooting_tangent_vec = manifold.logmap(start_point, end_point)

    def path(t):
        """Generate parameterized function for geodesic curve.

        Parameters
        ----------
        t : array-like or scalar, shape=[n_points,]
            Times at which to compute points of the geodesics. Tensors, lists,
            numpy arrays and Python scalars are accepted. Values are cast to
            the dtype/device of the shooting tangent vector. A scalar is
            treated as a single time, giving ``n_points == 1``.

        Returns
        -------
        torch.Tensor
            Points with shape ``[..., n_points, dim]``, assuming
            ``manifold.expmap`` broadcasts its arguments.
        """
        # einsum requires a 1-D tensor matching the tangent vector's
        # dtype/device. Coerce here so "array-like" inputs actually work.
        t = torch.atleast_1d(
            torch.as_tensor(
                t,
                dtype=shooting_tangent_vec.dtype,
                device=shooting_tangent_vec.device,
            )
        )
        # Outer product over time:
        # tangent_vecs[..., i, :] = t[i] * shooting_tangent_vec[..., :]
        tangent_vecs = torch.einsum("i,...k->...ik", t, shooting_tangent_vec)
        # Add a time axis to the base point so it broadcasts against tangent_vecs.
        return manifold.expmap(start_point.unsqueeze(-2), tangent_vecs)

    return path