import os
import random
from typing import List, Union

import imageio
import numpy as np
import PIL
import torch
import torchvision
import torchvision.transforms.functional as F
from einops import rearrange
from tqdm import tqdm

import decord
decord.bridge.set_bridge('torch')  # make decord return torch tensors
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)
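# Example (sketch): saving a batch of generated clips. Assumes `videos` is a float
# tensor shaped (batch, channels, frames, height, width) in [0, 1] (or in [-1, 1]
# with rescale=True); the tensor values and output path below are illustrative.
#
#   videos = torch.rand(2, 3, 8, 64, 64)  # hypothetical output batch
#   save_videos_grid(videos, "samples/grid.gif", rescale=False, n_rows=2, fps=8)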
def save_videos_grid_pil(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
    # Despite the name, this mirrors save_videos_grid and expects a
    # (b, c, t, h, w) tensor rather than a list of PIL images.
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)
def read_video(video_path, video_length, width=512, height=512, frame_rate=None):
    vr = decord.VideoReader(video_path, width=width, height=height)
    if frame_rate is None:
        frame_rate = max(1, len(vr) // video_length)
    # sample every `frame_rate`-th frame, keeping at most `video_length` frames
    sample_index = list(range(0, len(vr), frame_rate))[:video_length]
    video = vr.get_batch(sample_index)
    video = rearrange(video, "f h w c -> f c h w")
    video = (video / 127.5 - 1.0)  # uint8 [0, 255] -> float [-1, 1]
    return video
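# Example (sketch): loading 16 frames at 512x512 from a local clip. The file name
# is illustrative; the returned tensor is (frames, channels, h, w) in [-1, 1].
#
#   video = read_video("data/example.mp4", video_length=16)  # (16, 3, 512, 512)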
# DDIM Inversion
def init_prompt(prompt, pipeline):
    uncond_input = pipeline.tokenizer(
        [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
        return_tensors="pt"
    )
    uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
    text_input = pipeline.tokenizer(
        [prompt],
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
    context = torch.cat([uncond_embeddings, text_embeddings])
    return context
def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
              sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
    # the input timestep becomes the "next" (larger) timestep; the current one
    # is obtained by stepping back by the scheduler stride
    timestep, next_timestep = min(
        timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
    alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
    alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
    beta_prod_t = 1 - alpha_prod_t
    next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
    next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
    return next_sample
def get_noise_pred_single(latents, t, context, unet):
    noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
    return noise_pred
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
    context = init_prompt(prompt, pipeline)
    uncond_embeddings, cond_embeddings = context.chunk(2)
    all_latent = [latent]
    latent = latent.clone().detach()
    for i in tqdm(range(num_inv_steps)):
        # walk the scheduler timesteps from smallest to largest (inversion order)
        t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
        noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
        latent = next_step(noise_pred, t, latent, ddim_scheduler)
        all_latent.append(latent)
    return all_latent
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
    ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
    return ddim_latents
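# Example (sketch): inverting a clip's latents back towards noise. Assumes `pipeline`
# is a loaded diffusers-style pipeline (tokenizer / text_encoder / unet / vae attributes)
# whose UNet takes (b, c, f, h, w) latents, and `ddim_scheduler` is a diffusers
# DDIMScheduler; the file name, step count, and latent scale below are illustrative.
#
#   ddim_scheduler.set_timesteps(50)
#   video = read_video("data/example.mp4", video_length=16).to(pipeline.device)
#   latents = pipeline.vae.encode(video).latent_dist.sample() * 0.18215
#   latents = rearrange(latents, "(b f) c h w -> b c f h w", b=1)
#   inv_latents = ddim_inversion(pipeline, ddim_scheduler, latents,
#                                num_inv_steps=50, prompt="")[-1]  # last = fully inverted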
"""optical flow and trajectories sampling""" | |
def preprocess(img1_batch, img2_batch, transforms): | |
img1_batch = F.resize(img1_batch, size=[512, 512], antialias=False) | |
img2_batch = F.resize(img2_batch, size=[512, 512], antialias=False) | |
return transforms(img1_batch, img2_batch) | |
def keys_with_same_value(dictionary):
    """Return {value: [keys]} for every value that more than one key maps to."""
    result = {}
    for key, value in dictionary.items():
        if value not in result:
            result[value] = [key]
        else:
            result[value].append(key)

    conflict_points = {}
    for k in result.keys():
        if len(result[k]) > 1:
            conflict_points[k] = result[k]
    return conflict_points
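# Example (sketch): two source points mapping to the same target are reported as a
# conflict; values reached by a single key are ignored.
#
#   keys_with_same_value({(0, 0, 0): (1, 2, 2), (0, 0, 1): (1, 2, 2), (0, 1, 0): (1, 3, 3)})
#   # -> {(1, 2, 2): [(0, 0, 0), (0, 0, 1)]}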
def find_duplicates(input_list):
    seen = set()
    duplicates = set()
    for item in input_list:
        if item in seen:
            duplicates.add(item)
        else:
            seen.add(item)
    return list(duplicates)
def neighbors_index(point, window_size, H, W):
    """return the spatial neighbor indices"""
    t, x, y = point
    neighbors = []
    for i in range(-window_size, window_size + 1):
        for j in range(-window_size, window_size + 1):
            if i == 0 and j == 0:
                continue
            if x + i < 0 or x + i >= H or y + j < 0 or y + j >= W:
                continue
            neighbors.append((t, x + i, y + j))
    return neighbors
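# Example (sketch): an interior point gets its full 8-neighborhood with window_size=1,
# while a corner point keeps only the in-bounds neighbors.
#
#   neighbors_index((0, 1, 1), window_size=1, H=4, W=4)
#   # -> [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0), (0, 1, 2), (0, 2, 0), (0, 2, 1), (0, 2, 2)]
#   neighbors_index((0, 0, 0), window_size=1, H=4, W=4)
#   # -> [(0, 0, 1), (0, 1, 0), (0, 1, 1)]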
def sample_trajectories(frames, device):
    from torchvision.models.optical_flow import Raft_Large_Weights
    from torchvision.models.optical_flow import raft_large

    weights = Raft_Large_Weights.DEFAULT
    transforms = weights.transforms()

    # frames, _, _ = torchvision.io.read_video(str(video_path), output_format="TCHW")
    clips = list(range(len(frames)))

    model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device)
    model = model.eval()

    finished_trajectories = []

    # RAFT is only used for inference here, so skip autograd bookkeeping
    current_frames, next_frames = preprocess(frames[clips[:-1]], frames[clips[1:]], transforms)
    with torch.no_grad():
        list_of_flows = model(current_frames.to(device), next_frames.to(device))
    predicted_flows = list_of_flows[-1]

    predicted_flows = predicted_flows / 512

    resolutions = [64, 32, 16, 8]
    res = {}
    window_sizes = {64: 2,
                    32: 1,
                    16: 1,
                    8: 1}
    for resolution in resolutions:
        print("=" * 30)
        trajectories = {}
        predicted_flow_resolu = torch.round(resolution * torch.nn.functional.interpolate(predicted_flows, scale_factor=(resolution / 512, resolution / 512)))

        T = predicted_flow_resolu.shape[0] + 1
        H = predicted_flow_resolu.shape[2]
        W = predicted_flow_resolu.shape[3]

        is_activated = torch.zeros([T, H, W], dtype=torch.bool)

        for t in range(T - 1):
            flow = predicted_flow_resolu[t]
            for h in range(H):
                for w in range(W):
                    if not is_activated[t, h, w]:
                        is_activated[t, h, w] = True
                        # this point has not been traversed, start a new trajectory
                        x = h + int(flow[1, h, w])
                        y = w + int(flow[0, h, w])
                        if x >= 0 and x < H and y >= 0 and y < W:
                            # trajectories.append([(t, h, w), (t+1, x, y)])
                            trajectories[(t, h, w)] = (t + 1, x, y)

        # if several points land on the same target, keep one at random and terminate
        # the rest, marking them with the (-1, -1, -1) padding value
        conflict_points = keys_with_same_value(trajectories)
        for k in conflict_points:
            index_to_pop = random.randint(0, len(conflict_points[k]) - 1)
            conflict_points[k].pop(index_to_pop)
            for point in conflict_points[k]:
                if point[0] != T - 1:
                    trajectories[point] = (-1, -1, -1)
        # grow trajectories frame by frame; any pixel not reached by an existing
        # trajectory starts a new one
        active_traj = []
        all_traj = []
        for t in range(T):
            pixel_set = {(t, x // H, x % H): 0 for x in range(H * W)}
            new_active_traj = []
            for traj in active_traj:
                if traj[-1] in trajectories:
                    v = trajectories[traj[-1]]
                    new_active_traj.append(traj + [v])
                    pixel_set[v] = 1
                else:
                    all_traj.append(traj)
            active_traj = new_active_traj
            active_traj += [[pixel] for pixel in pixel_set if pixel_set[pixel] == 0]
        all_traj += active_traj

        useful_traj = [i for i in all_traj if len(i) > 1]
        for idx in range(len(useful_traj)):
            if useful_traj[idx][-1] == (-1, -1, -1):
                useful_traj[idx] = useful_traj[idx][:-1]
print("how many points in all trajectories for resolution{}?".format(resolution), sum([len(i) for i in useful_traj])) | |
print("how many points in the video for resolution{}?".format(resolution), T*H*W) | |
# validate if there are no duplicates in the trajectories | |
trajs = [] | |
for traj in useful_traj: | |
trajs = trajs + traj | |
assert len(find_duplicates(trajs)) == 0, "There should not be duplicates in the useful trajectories." | |
# check if non-appearing points + appearing points = all the points in the video | |
all_points = set([(t, x, y) for t in range(T) for x in range(H) for y in range(W)]) | |
left_points = all_points- set(trajs) | |
print("How many points not in the trajectories for resolution{}?".format(resolution), len(left_points)) | |
for p in list(left_points): | |
useful_traj.append([p]) | |
print("how many points in all trajectories for resolution{} after pending?".format(resolution), sum([len(i) for i in useful_traj])) | |
        longest_length = max([len(i) for i in useful_traj])
        sequence_length = (window_sizes[resolution] * 2 + 1) ** 2 + longest_length - 1

        seqs = []
        masks = []

        # create a dictionary to facilitate checking the trajectories to which each point belongs
        point_to_traj = {}
        for traj in useful_traj:
            for p in traj:
                point_to_traj[p] = traj

        for t in range(T):
            for x in range(H):
                for y in range(W):
                    neighbours = neighbors_index((t, x, y), window_sizes[resolution], H, W)
                    sequence = [(t, x, y)] + neighbours + [(0, 0, 0) for i in range((window_sizes[resolution] * 2 + 1) ** 2 - 1 - len(neighbours))]
                    sequence_mask = torch.zeros(sequence_length, dtype=torch.bool)
                    sequence_mask[:len(neighbours) + 1] = True

                    traj = point_to_traj[(t, x, y)].copy()
                    traj.remove((t, x, y))
                    sequence = sequence + traj + [(0, 0, 0) for k in range(longest_length - 1 - len(traj))]
                    sequence_mask[(window_sizes[resolution] * 2 + 1) ** 2: (window_sizes[resolution] * 2 + 1) ** 2 + len(traj)] = True

                    seqs.append(sequence)
                    masks.append(sequence_mask)

        seqs = torch.tensor(seqs)
        masks = torch.stack(masks)
        res["traj{}".format(resolution)] = seqs
        res["mask{}".format(resolution)] = masks
    return res
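# Example (sketch): precomputing trajectory indices and masks for a 16-frame clip.
# Assumes `frames` is a uint8 (T, C, H, W) tensor, e.g. from torchvision.io.read_video
# with output_format="TCHW"; the file name and frame count are illustrative, and the
# per-pixel Python loops make this slow for long clips.
#
#   frames, _, _ = torchvision.io.read_video("data/example.mp4", output_format="TCHW")
#   traj_dict = sample_trajectories(frames[:16], "cuda")
#   traj_dict["traj16"].shape   # (16*16*16, sequence_length, 3) point indices
#   traj_dict["mask16"].shape   # (16*16*16, sequence_length) validity mask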