import os
import random
from typing import List, Union

import imageio
import numpy as np
import torch
import torchvision
import torchvision.transforms.functional as F
import PIL.Image
import decord
from tqdm import tqdm
from einops import rearrange

decord.bridge.set_bridge('torch')

def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)

def save_videos_grid_pil(videos: List[PIL.Image.Image], path: str, rescale=False, n_rows=4, fps=8):
    # NOTE: despite the annotation, the body expects the same "b c t h w" tensor layout as save_videos_grid.
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)

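# Usage sketch (hypothetical shapes and paths, not part of the original API): given a
# tensor of shape (batch, channels, frames, height, width) with values in [-1, 1],
#   save_videos_grid(samples, "outputs/sample.gif", rescale=True, fps=8)
# writes a GIF in which every frame is a grid of the batch items.
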
def read_video(video_path, video_length, width=512, height=512, frame_rate=None):
    vr = decord.VideoReader(video_path, width=width, height=height)
    if frame_rate is None:
        frame_rate = max(1, len(vr) // video_length)
    sample_index = list(range(0, len(vr), frame_rate))[:video_length]
    video = vr.get_batch(sample_index)
    video = rearrange(video, "f h w c -> f c h w")
    video = (video / 127.5 - 1.0)
    return video

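# Usage sketch (hypothetical path): read_video("input.mp4", video_length=16) returns a
# (16, 3, 512, 512) float tensor scaled to [-1, 1]; pass frame_rate to override the
# automatically chosen temporal stride.
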
# DDIM Inversion
@torch.no_grad()
def init_prompt(prompt, pipeline):
    uncond_input = pipeline.tokenizer(
        [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
        return_tensors="pt"
    )
    uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
    text_input = pipeline.tokenizer(
        [prompt],
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
    context = torch.cat([uncond_embeddings, text_embeddings])
    return context

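# init_prompt returns the unconditional ("" prompt) and conditional text embeddings stacked
# along the batch dimension, shape (2, max_length, hidden_dim); ddim_loop below keeps only
# the conditional half for inversion.
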
def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
              sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
    timestep, next_timestep = min(
        timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
    alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
    alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
    beta_prod_t = 1 - alpha_prod_t
    next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
    next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
    return next_sample

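# The update above is the deterministic DDIM step run in reverse: the clean-sample estimate
# x0 = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t) is re-noised to the next (higher) timestep as
# x_{t+1} = sqrt(a_{t+1}) * x0 + sqrt(1 - a_{t+1}) * eps, where a_t = alphas_cumprod[t] and
# eps is the model's noise prediction.
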
def get_noise_pred_single(latents, t, context, unet):
    noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
    return noise_pred

@torch.no_grad()
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
    context = init_prompt(prompt, pipeline)
    uncond_embeddings, cond_embeddings = context.chunk(2)
    all_latent = [latent]
    latent = latent.clone().detach()
    for i in tqdm(range(num_inv_steps)):
        t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
        noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
        latent = next_step(noise_pred, t, latent, ddim_scheduler)
        all_latent.append(latent)
    return all_latent

@torch.no_grad()
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
    ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
    return ddim_latents

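# Usage sketch (illustrative names; assumes a diffusers Stable Diffusion pipeline and a
# DDIMScheduler on which set_timesteps(num_inv_steps) has already been called):
#   ddim_inv_latents = ddim_inversion(pipeline, ddim_scheduler, video_latent=latents,
#                                     num_inv_steps=50, prompt="")
#   start_latent = ddim_inv_latents[-1]  # fully inverted latent used to start sampling
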
"""optical flow and trajectories sampling"""
def preprocess(img1_batch, img2_batch, transforms):
    img1_batch = F.resize(img1_batch, size=[512, 512], antialias=False)
    img2_batch = F.resize(img2_batch, size=[512, 512], antialias=False)
    return transforms(img1_batch, img2_batch)

def keys_with_same_value(dictionary):
    """Return {value: [keys]} for every value that is mapped to by more than one key."""
    result = {}
    for key, value in dictionary.items():
        if value not in result:
            result[value] = [key]
        else:
            result[value].append(key)

    conflict_points = {}
    for k in result.keys():
        if len(result[k]) > 1:
            conflict_points[k] = result[k]
    return conflict_points

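# Example: if two source points flow to the same target,
#   keys_with_same_value({(0, 1, 1): (1, 2, 2), (0, 1, 3): (1, 2, 2), (0, 0, 0): (1, 0, 0)})
# returns {(1, 2, 2): [(0, 1, 1), (0, 1, 3)]}, i.e. each conflicting target with its sources.
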
def find_duplicates(input_list):
    """Return the items that appear more than once in input_list."""
    seen = set()
    duplicates = set()
    for item in input_list:
        if item in seen:
            duplicates.add(item)
        else:
            seen.add(item)
    return list(duplicates)

def neighbors_index(point, window_size, H, W):
    """Return the in-bounds spatial neighbour indices of `point` within a (2*window_size+1)^2 window."""
    t, x, y = point
    neighbors = []
    for i in range(-window_size, window_size + 1):
        for j in range(-window_size, window_size + 1):
            if i == 0 and j == 0:
                continue
            if x + i < 0 or x + i >= H or y + j < 0 or y + j >= W:
                continue
            neighbors.append((t, x + i, y + j))
    return neighbors

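# Example: neighbors_index((0, 0, 0), 1, 64, 64) keeps only the in-bounds part of the 3x3
# window around (0, 0) and excludes the centre, giving [(0, 0, 1), (0, 1, 0), (0, 1, 1)].
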
@torch.no_grad()
def sample_trajectories(frames, device):
    from torchvision.models.optical_flow import Raft_Large_Weights
    from torchvision.models.optical_flow import raft_large

    weights = Raft_Large_Weights.DEFAULT
    transforms = weights.transforms()
    # frames, _, _ = torchvision.io.read_video(str(video_path), output_format="TCHW")
    clips = list(range(len(frames)))

    model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device)
    model = model.eval()

    finished_trajectories = []

    # estimate dense optical flow between consecutive frames with RAFT
    current_frames, next_frames = preprocess(frames[clips[:-1]], frames[clips[1:]], transforms)
    list_of_flows = model(current_frames.to(device), next_frames.to(device))
    predicted_flows = list_of_flows[-1]
    predicted_flows = predicted_flows / 512  # normalise flow to fractions of the 512-px frame

    resolutions = [64, 32, 16, 8]
    res = {}
    window_sizes = {64: 2,
                    32: 1,
                    16: 1,
                    8: 1}

    for resolution in resolutions:
        print("=" * 30)
        trajectories = {}
        predicted_flow_resolu = torch.round(resolution * torch.nn.functional.interpolate(predicted_flows, scale_factor=(resolution / 512, resolution / 512)))

        T = predicted_flow_resolu.shape[0] + 1
        H = predicted_flow_resolu.shape[2]
        W = predicted_flow_resolu.shape[3]

        is_activated = torch.zeros([T, H, W], dtype=torch.bool)

        for t in range(T - 1):
            flow = predicted_flow_resolu[t]
            for h in range(H):
                for w in range(W):
                    if not is_activated[t, h, w]:
                        is_activated[t, h, w] = True
                        # this point has not been traversed, start a new trajectory
                        x = h + int(flow[1, h, w])
                        y = w + int(flow[0, h, w])
                        if x >= 0 and x < H and y >= 0 and y < W:
                            # trajectories.append([(t, h, w), (t+1, x, y)])
                            trajectories[(t, h, w)] = (t + 1, x, y)

        # resolve conflicts: if several points map to the same target, keep one at random
        conflict_points = keys_with_same_value(trajectories)
        for k in conflict_points:
            index_to_pop = random.randint(0, len(conflict_points[k]) - 1)
            conflict_points[k].pop(index_to_pop)
            for point in conflict_points[k]:
                if point[0] != T - 1:
                    trajectories[point] = (-1, -1, -1)  # sentinel padding with (-1, -1, -1)

        # chain the one-step correspondences into full trajectories
        active_traj = []
        all_traj = []
        for t in range(T):
            pixel_set = {(t, x // H, x % H): 0 for x in range(H * W)}
            new_active_traj = []
            for traj in active_traj:
                if traj[-1] in trajectories:
                    v = trajectories[traj[-1]]
                    new_active_traj.append(traj + [v])
                    pixel_set[v] = 1
                else:
                    all_traj.append(traj)
            active_traj = new_active_traj
            active_traj += [[pixel] for pixel in pixel_set if pixel_set[pixel] == 0]
        all_traj += active_traj

        useful_traj = [i for i in all_traj if len(i) > 1]
        for idx in range(len(useful_traj)):
            if useful_traj[idx][-1] == (-1, -1, -1):
                useful_traj[idx] = useful_traj[idx][:-1]
        print("how many points in all trajectories for resolution{}?".format(resolution), sum([len(i) for i in useful_traj]))
        print("how many points in the video for resolution{}?".format(resolution), T * H * W)

        # validate that there are no duplicates in the trajectories
        trajs = []
        for traj in useful_traj:
            trajs = trajs + traj
        assert len(find_duplicates(trajs)) == 0, "There should not be duplicates in the useful trajectories."

        # check that non-appearing points + appearing points = all the points in the video
        all_points = set([(t, x, y) for t in range(T) for x in range(H) for y in range(W)])
        left_points = all_points - set(trajs)
        print("How many points not in the trajectories for resolution{}?".format(resolution), len(left_points))
        for p in list(left_points):
            useful_traj.append([p])
        print("how many points in all trajectories for resolution{} after padding?".format(resolution), sum([len(i) for i in useful_traj]))

seqs = []
masks = []
# create a dictionary to facilitate checking the trajectories to which each point belongs.
point_to_traj = {}
for traj in useful_traj:
for p in traj:
point_to_traj[p] = traj
for t in range(T):
for x in range(H):
for y in range(W):
neighbours = neighbors_index((t,x,y), window_sizes[resolution], H, W)
sequence = [(t,x,y)]+neighbours + [(0,0,0) for i in range((window_sizes[resolution]*2+1)**2-1-len(neighbours))]
sequence_mask = torch.zeros(sequence_length, dtype=torch.bool)
sequence_mask[:len(neighbours)+1] = True
traj = point_to_traj[(t,x,y)].copy()
traj.remove((t,x,y))
sequence = sequence + traj + [(0,0,0) for k in range(longest_length-1-len(traj))]
sequence_mask[(window_sizes[resolution]*2+1)**2: (window_sizes[resolution]*2+1)**2 + len(traj)] = True
seqs.append(sequence)
masks.append(sequence_mask)
seqs = torch.tensor(seqs)
masks = torch.stack(masks)
res["traj{}".format(resolution)] = seqs
res["mask{}".format(resolution)] = masks
return res
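# Usage sketch (illustrative; assumes `frames` is a (T, C, H, W) uint8 tensor, e.g. from
# torchvision.io.read_video(video_path, output_format="TCHW")):
#   trajectories = sample_trajectories(frames, device="cuda")
#   # trajectories["traj64"] / trajectories["mask64"] (likewise for 32, 16, 8) hold the
#   # per-point sequences and validity masks for the corresponding latent resolution.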