# Spaces: Running on L40S
import torch
# Module-level slot holding the currently selected interpolation function.
# It starts out as None and is filled in by set_tensor_interpolation_method().
tensor_interpolation = None


def get_tensor_interpolation_method():
    """Return the interpolation function currently stored in this module.

    Returns:
        Whatever was last stored by ``set_tensor_interpolation_method``
        (``slerp`` or ``linear``), or ``None`` if it has not been called.
    """
    current = tensor_interpolation
    return current
def set_tensor_interpolation_method(is_slerp):
    """Choose the module-wide interpolation function.

    Args:
        is_slerp: If truthy, ``slerp`` is installed; otherwise ``linear``.
            The choice is stored in the module-level ``tensor_interpolation``
            and can be read back with ``get_tensor_interpolation_method()``.
    """
    global tensor_interpolation
    if is_slerp:
        tensor_interpolation = slerp
    else:
        tensor_interpolation = linear
def linear(v1, v2, t):
    """Linearly interpolate from ``v1`` (at ``t == 0``) to ``v2`` (at ``t == 1``).

    Computes ``(1.0 - t) * v1 + t * v2``. Works for any operands that
    support scalar ``*`` and ``+``, e.g. Python floats or torch tensors.
    """
    start_weight = 1.0 - t
    end_weight = t
    return start_weight * v1 + end_weight * v2
def slerp(
    v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
) -> torch.Tensor:
    """Spherically interpolate between two tensors.

    The angle between ``v0`` and ``v1`` is measured treating each tensor as a
    single flat vector: ``Tensor.norm()`` with no ``dim`` is the 2-norm over
    all elements. The result blends the *original* (un-normalized) ``v0`` and
    ``v1`` with the standard slerp weights
    ``sin((1 - t) * omega) / sin(omega)`` and ``sin(t * omega) / sin(omega)``.
    Up to rounding, ``t == 0`` yields ``v0`` and ``t == 1`` yields ``v1``.

    Plain linear interpolation is used instead in two cases:

    * the inputs are nearly parallel or anti-parallel
      (``|cos(omega)| > DOT_THRESHOLD``), where ``sin(omega)`` approaches
      zero and the slerp formula is ill-conditioned;
    * the cosine is not finite, e.g. an all-zero input makes the
      normalization 0/0 = NaN.

    Args:
        v0: Start tensor.
        v1: End tensor; must be broadcast-compatible with ``v0`` for the
            element-wise product below (presumably the same shape — confirm
            against callers).
        t: Interpolation factor (presumably in [0, 1]).
        DOT_THRESHOLD: Cosine magnitude above which linear interpolation is
            used instead of slerp.

    Returns:
        The interpolated tensor.
    """
    u0 = v0 / v0.norm()
    u1 = v1 / v1.norm()
    dot = (u0 * u1).sum()
    # A zero-norm input turns ``dot`` into NaN. NaN comparisons are always
    # False, so without the isfinite check it would slip past the threshold
    # test and the whole result would become NaN.
    if not torch.isfinite(dot) or dot.abs() > DOT_THRESHOLD:
        return (1.0 - t) * v0 + t * v1
    omega = dot.acos()
    return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()