Spaces:
Running
on
L40S
Running
on
L40S
# TODO: Adapted from cli | |
from typing import Callable, List, Optional | |
import numpy as np | |
def ordered_halving(val):
    """Return the bit-reversed 64-bit fraction of ``val``.

    ``val`` is written as a zero-padded 64-digit binary string. The digits
    are reversed, read back as an integer and divided by ``2**64``. Small
    consecutive integers therefore spread out over [0, 1):
    0 -> 0.0, 1 -> 0.5, 2 -> 0.25, 3 -> 0.75, 4 -> 0.125, ...

    For ``0 <= val < 2**64`` the result lies in [0, 1). Larger values produce
    more than 64 digits, so the result can exceed 1. A negative ``val``
    formats with a leading ``-``, which makes the base-2 parse raise
    ``ValueError``.
    """
    reversed_bits = format(val, "064b")[::-1]
    return int(reversed_bits, 2) / 2**64
def uniform(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Yield windows of frame indices for one scheduling step.

    Windows are produced at up to ``context_stride`` dilation levels with
    frame spacing 1, 2, 4, ... Each window holds ``context_size`` indices,
    taken modulo ``num_frames`` so they wrap around the clip. The starting
    offset is shifted according to ``ordered_halving(step)``, so different
    steps place the window boundaries at different frames.

    Args:
        step: Index of the current step; determines the window offset.
        num_steps: Not used by this scheduler (presumably kept so every
            scheduler shares one call signature).
        num_frames: Total number of frames.
        context_size: Frames per window. ``None``, or any value
            ``>= num_frames``, yields one window covering every frame.
        context_stride: Maximum number of dilation levels. It is further
            capped at ``ceil(log2(num_frames / context_size)) + 1``.
        context_overlap: Amount subtracted from the window span to get the
            hop between consecutive window starts at the same level.
        closed_loop: If False, the range of window start positions stops
            ``context_overlap`` frames earlier.

    Yields:
        list[int]: Frame indices for one window.

    Raises:
        ValueError: If ``context_overlap`` is not smaller than the window span
            at some dilation level. Such a hop would be zero (which ``range``
            rejects) or negative (which silently yields no windows).
    """
    if context_size is None or num_frames <= context_size:
        # No windowing requested, or the whole clip fits in a single window.
        yield list(range(num_frames))
        return

    # Use no more dilation levels than needed for the widest window to span
    # the whole clip.
    context_stride = min(
        context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
    )

    # The step-dependent offset is the same for every dilation level, so it
    # is computed once rather than on every loop iteration.
    phase = ordered_halving(step)
    pad = int(round(num_frames * phase))

    for context_step in 1 << np.arange(context_stride):
        span = context_size * context_step
        hop = span - context_overlap
        if hop <= 0:
            raise ValueError(
                f"context_overlap ({context_overlap}) must be smaller than the "
                f"window span ({span} frames) at dilation {context_step}"
            )
        for j in range(
            int(phase * context_step) + pad,
            num_frames + pad + (0 if closed_loop else -context_overlap),
            hop,
        ):
            yield [e % num_frames for e in range(j, j + span, context_step)]
def get_context_scheduler(name: str) -> Callable:
    """Look up a context-window scheduler function by name.

    Args:
        name: Scheduler name. Only ``"uniform"`` is currently recognised.

    Returns:
        The scheduler generator function (``uniform`` for ``"uniform"``).

    Raises:
        ValueError: If ``name`` is not a known scheduler.
    """
    if name == "uniform":
        return uniform
    # The old message said "context_overlap policy", which names the wrong
    # concept: this function looks up a scheduler, not an overlap setting.
    raise ValueError(f"Unknown context scheduler {name!r}; expected 'uniform'")
def get_total_steps(
    scheduler,
    timesteps: List[int],
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Count the context windows ``scheduler`` yields over all timesteps.

    ``scheduler`` is called once per entry of ``timesteps``. Each call gets
    the entry's position (0, 1, 2, ...) as ``step``; the timestep values
    themselves are not passed. The counts from all calls are summed.

    Args:
        scheduler: Callable taking ``(step, num_steps, num_frames,
            context_size, context_stride, context_overlap, closed_loop)``
            positionally and returning an iterable of windows.
        timesteps: Only its length is used, to decide how many steps to count.
        num_steps, num_frames, context_size, context_stride, context_overlap,
        closed_loop: Forwarded unchanged to ``scheduler``.

    Returns:
        int: Total number of windows over all steps.
    """
    return sum(
        # Count lazily instead of materialising each window list.
        sum(
            1
            for _ in scheduler(
                step,
                num_steps,
                num_frames,
                context_size,
                context_stride,
                context_overlap,
                # Must be forwarded so the count matches the windows the
                # scheduler actually produces when closed_loop=False.
                closed_loop,
            )
        )
        for step in range(len(timesteps))
    )