"""Copyright (c) Meta Platforms, Inc. and affiliates.""" import math import torch from geoopt.manifolds import Sphere as geoopt_Sphere class Sphere(geoopt_Sphere): def transp(self, x, y, v): denom = 1 + self.inner(x, x, y, keepdim=True) res = v - self.inner(x, y, v, keepdim=True) / denom * (x + y) cond = denom.gt(1e-3) return torch.where(cond, res, -v) def uniform_logprob(self, x): dim = x.shape[-1] return torch.full_like( x[..., 0], math.lgamma(dim / 2) - (math.log(2) + (dim / 2) * math.log(math.pi)), ) def random_base(self, *args, **kwargs): return self.random_uniform(*args, **kwargs) def base_logprob(self, *args, **kwargs): return self.uniform_logprob(*args, **kwargs) def geodesic(manifold, start_point, end_point): shooting_tangent_vec = manifold.logmap(start_point, end_point) def path(t): """Generate parameterized function for geodesic curve. Parameters ---------- t : array-like, shape=[n_points,] Times at which to compute points of the geodesics. """ tangent_vecs = torch.einsum("i,...k->...ik", t, shooting_tangent_vec) points_at_time_t = manifold.expmap(start_point.unsqueeze(-2), tangent_vecs) return points_at_time_t return path