ZJUPeng's picture
add continuous
d6682b6
raw
history blame
581 Bytes
import numpy as np
import torch
from typing import Union, List
class linear:
    """Linear interpolation ("lerp") between two values.

    ``execute`` computes ``t * v1 + (1.0 - t) * v0``, which yields ``v0`` at
    ``t == 0`` and ``v1`` at ``t == 1``. No clamping of ``t`` is performed.
    """

    @staticmethod
    def _first(value):
        """Return ``value[0]`` if *value* is a list or tuple, else *value* unchanged."""
        # Arguments may arrive wrapped in a container; only the first item is
        # used and any further items are ignored. An empty container raises
        # IndexError, matching the previous ``value[0]`` behavior.
        if isinstance(value, (list, tuple)):
            return value[0]
        return value

    def execute(
        self,
        t: Union[float, List[float]],
        v0: Union[List[torch.Tensor], torch.Tensor],
        v1: Union[List[torch.Tensor], torch.Tensor],
        DOT_THRESHOLD: float = 0.9995,
        eps: float = 1e-8,
        densities=None,
    ):
        """Linearly interpolate between *v0* and *v1* by factor *t*.

        Args:
            t: Interpolation factor. If a list/tuple, only ``t[0]`` is used.
            v0: Start value (returned when ``t == 0``). If a list/tuple, only
                its first element is used.
            v1: End value (returned when ``t == 1``). If a list/tuple, only
                its first element is used.
            DOT_THRESHOLD: Accepted but ignored by this method.
            eps: Accepted but ignored by this method.
            densities: Accepted but ignored by this method.

        Returns:
            ``t * v1 + (1.0 - t) * v0``, computed with the operands' own
            arithmetic (e.g. a ``torch.Tensor`` for tensor inputs).

        Raises:
            IndexError: If any list/tuple argument is empty.
        """
        # NOTE(review): DOT_THRESHOLD / eps / densities are unused here;
        # presumably kept so this class shares a call signature with other
        # interpolation methods (the name suggests slerp) — confirm with callers.
        v0 = self._first(v0)
        t = self._first(t)
        v1 = self._first(v1)
        return t * v1 + (1.0 - t) * v0