Spaces:
Running
Running
File size: 1,362 Bytes
c4c7cee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
"""Copyright (c) Meta Platforms, Inc. and affiliates."""
import math
import torch
from geoopt.manifolds import Sphere as geoopt_Sphere
class Sphere(geoopt_Sphere):
    """geoopt ``Sphere`` extended with a closed-form transport and base-density helpers."""

    def transp(self, x, y, v):
        """Transport the vector ``v`` from ``x`` to ``y``.

        Evaluates ``v - s / (1 + c) * (x + y)`` with
        ``c = self.inner(x, x, y)`` and ``s = self.inner(x, y, v)``.
        Wherever ``1 + c`` is not greater than ``1e-3`` the division is
        unreliable, so ``-v`` is returned for those entries instead.
        """
        one_plus_xy = 1 + self.inner(x, x, y, keepdim=True)
        y_dot_v = self.inner(x, y, v, keepdim=True)
        transported = v - y_dot_v / one_plus_xy * (x + y)
        # Guard the near-zero denominator (presumably near-antipodal x and y).
        well_conditioned = one_plus_xy.gt(1e-3)
        return torch.where(well_conditioned, transported, -v)

    def uniform_logprob(self, x):
        """Return a constant log-density tensor shaped like ``x[..., 0]``.

        With ``d = x.shape[-1]`` every entry equals
        ``lgamma(d / 2) - (log(2) + (d / 2) * log(pi))``, i.e. minus the log
        of ``2 * pi ** (d / 2) / Gamma(d / 2)``.
        """
        half_dim = x.shape[-1] / 2
        log_normalizer = math.log(2) + half_dim * math.log(math.pi)
        fill_value = math.lgamma(half_dim) - log_normalizer
        return torch.full_like(x[..., 0], fill_value)

    def random_base(self, *args, **kwargs):
        """Draw base samples by forwarding every argument to ``random_uniform``."""
        samples = self.random_uniform(*args, **kwargs)
        return samples

    def base_logprob(self, *args, **kwargs):
        """Evaluate the base log-density by forwarding to ``uniform_logprob``."""
        log_density = self.uniform_logprob(*args, **kwargs)
        return log_density
def geodesic(manifold, start_point, end_point):
    """Build a function that evaluates the curve from ``start_point`` towards ``end_point``.

    The curve is obtained by shooting: the tangent vector
    ``manifold.logmap(start_point, end_point)`` is scaled by each requested
    time and pushed back onto the manifold with ``manifold.expmap``.

    Parameters
    ----------
    manifold : object
        Must provide ``logmap(x, y)`` and ``expmap(x, u)``.
    start_point : torch.Tensor, shape=[..., dim]
        Base point of the curve.
    end_point : torch.Tensor, shape=[..., dim]
        Target point handed to ``logmap``.

    Returns
    -------
    path : callable
        Function of the times ``t`` returning the points on the curve.
    """
    # Initial velocity of the curve, computed once and shared by every call.
    shooting_tangent_vec = manifold.logmap(start_point, end_point)

    def path(t):
        """Generate parameterized function for geodesic curve.

        Parameters
        ----------
        t : array-like, shape=[n_points,]
            Times at which to compute points of the geodesics. A tensor is
            used as given; a Python scalar or sequence is converted to the
            dtype and device of the tangent vector. A single (0-dim) time
            is treated as ``n_points == 1``.

        Returns
        -------
        points_at_time_t : torch.Tensor
            Result of ``expmap``; assuming ``expmap`` keeps the broadcast
            shape, this is ``[..., n_points, dim]``.
        """
        # einsum needs a tensor operand, so lift lists / floats first.
        if not torch.is_tensor(t):
            t = torch.as_tensor(
                t,
                dtype=shooting_tangent_vec.dtype,
                device=shooting_tangent_vec.device,
            )
        # The "i" subscript requires exactly one time axis.
        if t.ndim == 0:
            t = t.unsqueeze(0)
        # Scale the tangent by every time: [n_points] x [..., dim] -> [..., n_points, dim].
        tangent_vecs = torch.einsum("i,...k->...ik", t, shooting_tangent_vec)
        # unsqueeze(-2) lets the base point broadcast over the n_points axis.
        points_at_time_t = manifold.expmap(start_point.unsqueeze(-2), tangent_vecs)
        return points_at_time_t

    return path
|