Text2Video-Zero / utils.py
kirch's picture
Duplicate from PAIR/Text2Video-Zero
508927a
raw
history blame
7 kB
import os
import numpy as np
import torch
import torchvision
from torchvision.transforms import Resize
import imageio
from einops import rearrange
import cv2
from PIL import Image
from annotator.util import resize_image, HWC3
from annotator.canny import CannyDetector
from annotator.openpose import OpenposeDetector
import decord
# Switch decord's output bridge to 'torch': VideoReader indexing / get_batch
# then hand back torch tensors instead of decord NDArrays.
decord.bridge.set_bridge('torch')
# Module-level annotator instances, constructed once at import time and shared
# by the pre-processing helpers in this module.
apply_canny = CannyDetector()
apply_openpose = OpenposeDetector()
def add_watermark(image, im_size_h, im_size_w, watermark_path="__assets__/picsart_watermark.jpg",
                  wmsize=16, bbuf=5, opacity=0.9):
    '''
    Creates a watermark on the saved inference image.
    We request that you do not remove this to properly assign credit to
    Shi-Lab's work.

    The watermark file is resized to ``wmsize`` x ``wmsize`` pixels and pasted
    into the bottom-right corner, ``bbuf`` pixels away from the bottom and
    right edges.

    Args:
        image: Writable (height, width, channels) array. It is modified in
            place and also returned. Presumably uint8 with the same channel
            count as the watermark file -- TODO confirm.
        im_size_h: Height of ``image`` in pixels.
        im_size_w: Width of ``image`` in pixels.
        watermark_path: Path of the watermark image file.
        wmsize: Side length, in pixels, that the watermark is resized to.
        bbuf: Margin in pixels between the watermark and the bottom/right edges.
        opacity: Currently unused; the watermark is pasted fully opaque.

    Returns:
        ``image`` with the watermark written into it.

    Raises:
        ValueError: If the image is too small to hold the watermark plus margin.
    '''
    loc_h = im_size_h - wmsize - bbuf
    loc_w = im_size_w - wmsize - bbuf
    if loc_h < 0 or loc_w < 0:
        raise ValueError(
            f"image of size {im_size_h}x{im_size_w} is too small for a "
            f"{wmsize}px watermark with a {bbuf}px margin"
        )
    # Use a context manager so the file handle is released deterministically
    # instead of whenever the garbage collector gets to it.
    with Image.open(watermark_path) as watermark_file:
        watermark = np.asarray(watermark_file.resize((wmsize, wmsize)))
    # Explicit end indices. A `-bbuf` slice end would turn into an empty
    # slice (`x:-0` == `x:0`) when bbuf == 0.
    image[loc_h:loc_h + wmsize, loc_w:loc_w + wmsize, :] = watermark
    return image
def pre_process_canny(input_video, low_threshold=100, high_threshold=200):
    """Compute a Canny edge control map for every frame of a video.

    Args:
        input_video: Iterable of (channels, height, width) frames that
            support ``.cpu()`` (presumably torch tensors -- TODO confirm).
        low_threshold: Lower threshold forwarded to ``apply_canny``.
        high_threshold: Upper threshold forwarded to ``apply_canny``.

    Returns:
        Float tensor laid out as (frames, channels, height, width), holding
        the ``HWC3``-converted edge maps divided by 255.
    """
    edge_maps = []
    for chw_frame in input_video:
        hwc_frame = rearrange(chw_frame, 'c h w -> h w c').cpu().numpy().astype(np.uint8)
        edges = HWC3(apply_canny(hwc_frame, low_threshold, high_threshold))
        edge_maps.append(edges)
    # Stack along a new leading frame axis -> (f, h, w, c).
    stacked = np.stack(edge_maps)
    control = torch.from_numpy(stacked).float() / 255.0
    return rearrange(control, 'f h w c -> f c h w')
def pre_process_pose(input_video, apply_pose_detect: bool = True):
    """Build a pose control map for every frame of a video.

    Args:
        input_video: Iterable of (channels, height, width) frames that
            support ``.cpu()`` (presumably torch tensors -- TODO confirm).
        apply_pose_detect: When True, each frame is passed through
            ``apply_openpose``. When False, the frame itself is used as the
            map (e.g. when the input already contains pose renderings).

    Returns:
        Float tensor laid out as (frames, channels, height, width), holding
        the maps divided by 255. Each map is resized (nearest neighbour) back
        to its source frame's size.
    """
    pose_maps = []
    for chw_frame in input_video:
        frame_hwc = HWC3(rearrange(chw_frame, 'c h w -> h w c').cpu().numpy().astype(np.uint8))
        height, width, _ = frame_hwc.shape
        if not apply_pose_detect:
            raw_map = frame_hwc
        else:
            raw_map, _ = apply_openpose(frame_hwc)
        resized = cv2.resize(HWC3(raw_map), (width, height), interpolation=cv2.INTER_NEAREST)
        pose_maps.append(resized)
    control = torch.from_numpy(np.stack(pose_maps)).float() / 255.0
    return rearrange(control, 'f h w c -> f c h w')
def create_video(frames, fps, rescale=False, path=None):
    """Encode frames into a watermarked video file with imageio.

    Args:
        frames: Iterable of frames. Each is converted with ``torch.Tensor``
            and passed through ``make_grid``, and must end up as a
            (height, width, channels) array. Presumably floats in [0, 1], or
            [-1, 1] when ``rescale`` is set -- TODO confirm against callers.
        fps: Frame rate forwarded to ``imageio.mimsave``.
        rescale: If True, map values from [-1, 1] to [0, 1] before encoding.
        path: Output file path. Defaults to ``temporal/movie.mp4``; the
            ``temporal`` directory is created if needed.

    Returns:
        The path the video was written to.
    """
    if path is None:
        out_dir = "temporal"
        os.makedirs(out_dir, exist_ok=True)
        path = os.path.join(out_dir, 'movie.mp4')

    outputs = []
    for frame in frames:
        x = torchvision.utils.make_grid(torch.Tensor(frame), nrow=4)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        # Clamp before the uint8 cast. Otherwise slightly out-of-range values
        # wrap around (e.g. 1.01 * 255 -> 1) instead of saturating.
        x = (x.clamp(0.0, 1.0) * 255).numpy().astype(np.uint8)
        h_, w_, _ = x.shape
        x = add_watermark(x, im_size_h=h_, im_size_w=w_)
        outputs.append(x)

    imageio.mimsave(path, outputs, fps=fps)
    return path
def create_gif(frames, fps, rescale=False, path=None):
    """Encode frames into a watermarked GIF with imageio.

    Args:
        frames: Iterable of frames. Each is converted with ``torch.Tensor``
            and passed through ``make_grid``, and must end up as a
            (height, width, channels) array. Presumably floats in [0, 1], or
            [-1, 1] when ``rescale`` is set -- TODO confirm against callers.
        fps: Frame rate forwarded to ``imageio.mimsave``.
        rescale: If True, map values from [-1, 1] to [0, 1] before encoding.
        path: Output file path. Defaults to ``temporal/canny_db.gif`` (the
            previous fixed location); the ``temporal`` directory is created
            if needed.

    Returns:
        The path the GIF was written to.
    """
    if path is None:
        out_dir = "temporal"
        os.makedirs(out_dir, exist_ok=True)
        path = os.path.join(out_dir, 'canny_db.gif')

    outputs = []
    for frame in frames:
        x = torchvision.utils.make_grid(torch.Tensor(frame), nrow=4)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        # Clamp before the uint8 cast. Otherwise slightly out-of-range values
        # wrap around (e.g. 1.01 * 255 -> 1) instead of saturating.
        x = (x.clamp(0.0, 1.0) * 255).numpy().astype(np.uint8)
        h_, w_, _ = x.shape
        x = add_watermark(x, im_size_h=h_, im_size_w=w_)
        outputs.append(x)

    imageio.mimsave(path, outputs, fps=fps)
    return path
def prepare_video(video_path: str, resolution: int, device, dtype, normalize=True,
                  start_t: float = 0, end_t: float = -1, output_fps: int = -1):
    """Load a video clip, resample it in time, and resize it.

    Args:
        video_path: Path of the video file to read with decord.
        resolution: Target size of the longer side (rounded down to a
            multiple of 8).
        device: Device the returned tensor is moved to.
        dtype: dtype the returned tensor is cast to.
        normalize: If True, map pixel values ``v`` to ``v / 127.5 - 1``.
        start_t: Clip start time in seconds.
        end_t: Clip end time in seconds. -1 means the end of the video;
            larger values are clamped to the video duration.
        output_fps: Sampling rate of the returned clip. -1 means the
            source's average fps (truncated to int).

    Returns:
        ``(video, output_fps)``. ``video`` is laid out as
        (frames, channels, height, width). Its aspect ratio is preserved and
        both sides are rounded down to multiples of 8.
    """
    vr = decord.VideoReader(video_path)
    initial_fps = vr.get_avg_fps()
    if output_fps == -1:
        output_fps = int(initial_fps)
    duration = len(vr) / initial_fps
    end_t = duration if end_t == -1 else min(duration, end_t)
    assert 0 <= start_t < end_t
    assert output_fps > 0

    start_f_ind = int(start_t * initial_fps)
    end_f_ind = int(end_t * initial_fps)
    num_f = int((end_t - start_t) * output_fps)
    # A clip shorter than one output frame would yield an empty batch.
    assert num_f > 0
    sample_idx = np.linspace(start_f_ind, end_f_ind, num_f, endpoint=False).astype(int)

    # Decode only the sampled frames rather than the whole file.
    video = vr.get_batch(sample_idx.tolist())
    # With decord's bridge set to 'torch', get_batch returns a torch.Tensor,
    # which has no `.asnumpy()`. Handle both bridge outputs.
    if torch.is_tensor(video):
        video = video.detach().cpu().numpy()
    else:
        video = video.asnumpy()

    _, h, w, _ = video.shape
    video = rearrange(video, "f h w c -> f c h w")
    video = torch.Tensor(video).to(device).to(dtype)

    # Scale the longer side to `resolution` and the shorter side
    # proportionally. Round both down to a multiple of 8 (presumably needed
    # by the downstream latent model -- TODO confirm).
    if h > w:
        w = int(w * resolution / h)
        w = w - w % 8
        h = resolution - resolution % 8
    else:
        h = int(h * resolution / w)
        h = h - h % 8
        w = resolution - resolution % 8
    video = Resize((h, w))(video)

    if normalize:
        video = video / 127.5 - 1.0
    return video, output_fps
def post_process_gif(list_of_results, image_resolution, fps=4):
    """Write frames to a GIF in the system temp directory and return its path.

    Args:
        list_of_results: Frames accepted by ``imageio.mimsave``.
        image_resolution: Unused; kept so existing call sites keep working.
        fps: Playback frame rate (defaults to the previous fixed value, 4).

    Returns:
        Path of the written GIF.
    """
    import tempfile

    # Use the platform temp dir instead of a hard-coded "/tmp" (not portable).
    # NOTE(review): the file name is still fixed, so concurrent calls
    # overwrite each other's output.
    output_file = os.path.join(tempfile.gettempdir(), "ddxk.gif")
    imageio.mimsave(output_file, list_of_results, fps=fps)
    return output_file
class CrossFrameAttnProcessor:
    """Attention processor whose self-attention attends to the first frame.

    For self-attention calls (no ``encoder_hidden_states``), every frame's
    keys and values are replaced by those of frame 0 within the same chunk.
    All frames therefore query one shared reference frame. Cross-attention
    calls are computed without that substitution.

    Args:
        unet_chunk_size: Number of chunks the batch dimension is split into.
            The batch is treated as laid out ``(chunk, frame)``, so the frame
            count is ``batch // unet_chunk_size``. The default of 2
            presumably covers the unconditional/conditional halves of
            classifier-free guidance -- TODO confirm against the caller.
    """

    def __init__(self, unet_chunk_size=2):
        self.unet_chunk_size = unet_chunk_size

    def __call__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None):
        """Compute attention for one call of the ``attn`` module.

        Args:
            attn: Attention module providing ``prepare_attention_mask``,
                ``to_q``/``to_k``/``to_v``/``to_out``, ``head_to_batch_dim``,
                ``batch_to_head_dim`` and ``get_attention_scores``
                (presumably a diffusers attention class -- TODO confirm
                which version exposes ``cross_attention_norm``).
            hidden_states: 3-D tensor unpacked as
                (batch_size, sequence_length, dim).
            encoder_hidden_states: Optional context tensor; ``None`` means
                this is a self-attention call.
            attention_mask: Optional mask, passed through
                ``attn.prepare_attention_mask``.

        Returns:
            Attention output after ``to_out[0]`` and ``to_out[1]``.
        """
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)

        # The call is cross-attention exactly when an external context is given.
        is_cross_attention = encoder_hidden_states is not None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.cross_attention_norm:
            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Sparse Attention
        # Self-attention only: split the batch into (chunk, frame), point
        # every frame at frame index 0 of its chunk, and flatten back. Each
        # frame's query thus attends to the first frame's keys/values.
        if not is_cross_attention:
            video_length = key.size()[0] // self.unet_chunk_size
            # Disabled alternative: attend to the previous frame instead.
            # former_frame_index = torch.arange(video_length) - 1
            # former_frame_index[0] = 0
            former_frame_index = [0] * video_length
            key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
            key = key[:, former_frame_index]
            key = rearrange(key, "b f d c -> (b f) d c")
            value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
            value = value[:, former_frame_index]
            value = rearrange(value, "b f d c -> (b f) d c")

        # Standard scaled attention over heads folded into the batch dimension.
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states