Spaces:
Running
Running
File size: 778 Bytes
7c3ff16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import torch
# Interpolation function chosen by set_tensor_interpolation_method() (either
# slerp or linear) and handed out by get_tensor_interpolation_method().
# Starts as None, so the getter returns None until the setter has been called.
tensor_interpolation = None
def get_tensor_interpolation_method():
    """Return the currently selected tensor interpolation function.

    This is whatever the module-level ``tensor_interpolation`` holds: ``slerp``
    or ``linear`` after ``set_tensor_interpolation_method`` has run, and
    ``None`` if nothing has assigned it yet. Callers should make sure the
    setter has been invoked before calling the returned object.
    """
    return tensor_interpolation
def set_tensor_interpolation_method(is_slerp):
    """Choose which interpolation function the module hands out.

    Args:
        is_slerp: truthy selects spherical interpolation (``slerp``);
            falsy selects plain linear interpolation (``linear``).
    """
    global tensor_interpolation
    if is_slerp:
        tensor_interpolation = slerp
    else:
        tensor_interpolation = linear
def linear(v1, v2, t):
    """Linearly interpolate between ``v1`` (at ``t == 0``) and ``v2`` (at ``t == 1``).

    Works for any operands that support scalar multiplication and addition,
    e.g. Python numbers or tensors. Evaluated as ``(1 - t) * v1 + t * v2``.
    """
    weight_start = 1.0 - t
    weight_end = t
    return weight_start * v1 + weight_end * v2
def slerp(
    v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
) -> torch.Tensor:
    """Spherically interpolate between two tensors.

    Both tensors are treated as single flattened vectors: ``norm()`` and
    ``sum()`` are taken over all elements. The angle ``omega`` is measured
    between the normalized inputs. The slerp weights ``sin((1 - t) * omega) /
    sin(omega)`` and ``sin(t * omega) / sin(omega)`` are then applied to the
    original (un-normalized) ``v0`` and ``v1``, so the inputs' magnitudes carry
    through to the result. ``v0`` and ``v1`` are presumably the same shape.

    Args:
        v0: Start tensor (returned, up to rounding, when ``t == 0``).
        v1: End tensor (returned, up to rounding, when ``t == 1``).
        t: Interpolation fraction, normally in ``[0, 1]``.
        DOT_THRESHOLD: If the absolute cosine between the inputs exceeds
            this value, linear interpolation is used instead.

    Returns:
        The interpolated tensor.
    """
    norm0 = v0.norm()
    norm1 = v1.norm()
    # A zero-norm input has no direction. Normalizing it would divide 0 by 0,
    # and the resulting NaN would propagate through acos into the whole
    # output. Linear interpolation is the well-defined fallback.
    if norm0 == 0 or norm1 == 0:
        return (1.0 - t) * v0 + t * v1
    dot = ((v0 / norm0) * (v1 / norm1)).sum()
    # Nearly parallel or anti-parallel inputs make sin(omega) approach 0, which
    # makes the slerp weights unstable, so interpolate linearly instead.
    if dot.abs() > DOT_THRESHOLD:
        return (1.0 - t) * v0 + t * v1
    omega = dot.acos()
    return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()
|