from typing import List, Union

import numpy as np
import torch


def _first_if_list(value):
    """Return ``value[0]`` when *value* is a list, otherwise *value* unchanged.

    Any ``list`` (including subclasses) is unwrapped to its first element;
    remaining elements are ignored. An empty list raises ``IndexError``.
    """
    if isinstance(value, list):
        return value[0]
    return value


class linear:
    """Linear interpolation between two values (tensors or scalars)."""

    def __init__(self):
        pass

    def execute(
        self,
        t: Union[float, List[float]],
        v0: Union[List[torch.Tensor], torch.Tensor],
        v1: Union[List[torch.Tensor], torch.Tensor],
        DOT_THRESHOLD: float = 0.9995,
        eps: float = 1e-8,
        densities=None,
    ):
        """Return ``t * v1 + (1 - t) * v0``.

        Args:
            t: Interpolation weight applied to ``v1``; ``1 - t`` is applied
                to ``v0``. If a list is given, only its first element is used.
            v0: Start value. If a list is given, only its first element is used.
            v1: End value. If a list is given, only its first element is used.
            DOT_THRESHOLD: Accepted but unused here — presumably kept so this
                method shares a call signature with other interpolation
                methods (e.g. a slerp variant); TODO confirm with callers.
            eps: Accepted but unused (see ``DOT_THRESHOLD``).
            densities: Accepted but unused (see ``DOT_THRESHOLD``).

        Returns:
            The weighted sum ``t * v1 + (1 - t) * v0``; its type follows
            ordinary arithmetic on the (unwrapped) inputs.

        Raises:
            IndexError: If any of ``t``, ``v0`` or ``v1`` is an empty list.
        """
        # Callers may pass single-element lists; only the first entry is used.
        t = _first_if_list(t)
        v0 = _first_if_list(v0)
        v1 = _first_if_list(v1)
        return t * v1 + (1.0 - t) * v0