Cletrason committed
Commit adbb9b1
1 parent: 5bd153a

Upload 7 files

Files changed (7):
  1. app_canny_db.py +103 -0
  2. app_text_to_video.py +97 -0
  3. config.py +1 -0
  4. gradio_utils.py +98 -0
  5. hf_utils.py +39 -0
  6. style (1).css +3 -0
  7. utils (1).py +207 -0
app_canny_db.py ADDED
@@ -0,0 +1,103 @@
+ import gradio as gr
+ from model import Model
+ import gradio_utils
+ import os
+ on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"
+
+
+ examples = [
+     ['Anime DB', "woman1", "Portrait of detailed 1girl, feminine, soldier cinematic shot on canon 5d ultra realistic skin intricate clothes accurate hands Rory Lewis Artgerm WLOP Jeremy Lipking Jane Ansell studio lighting"],
+     ['Arcane DB', "woman1", "Oil painting of a beautiful girl arcane style, masterpiece, a high-quality, detailed, and professional photo"],
+     ['GTA-5 DB', "man1", "gtav style"],
+     ['GTA-5 DB', "woman3", "gtav style"],
+     ['Avatar DB', "woman2", "oil painting of a beautiful girl avatar style"],
+ ]
+
+
+ def load_db_model(evt: gr.SelectData):
+     db_name = gradio_utils.get_db_name_from_id(evt.index)
+     return db_name
+
+
+ def canny_select(evt: gr.SelectData):
+     canny_name = gradio_utils.get_canny_name_from_id(evt.index)
+     return canny_name
+
+
+ def create_demo(model: Model):
+
+     with gr.Blocks() as demo:
+         with gr.Row():
+             gr.Markdown(
+                 '## Text, Canny-Edge and DreamBooth Conditional Video Generation')
+         with gr.Row():
+             gr.HTML(
+                 """
+                 <div style="text-align: left; auto;">
+                 <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
+                 Description: Our current release supports only four predefined DreamBooth models and four "motion edges". Choose one DreamBooth model and one of the "motion edges" shown below, or use the examples. The keywords <b>1girl</b>, <b>arcane style</b>, <b>gtav</b>, and <b>avatar style</b> correspond to the models from left to right.
+                 </h2>
+                 </div>
+                 """)
+         with gr.Row():
+             with gr.Column():
+                 # input_video_path = gr.Video(source='upload', format="mp4", visible=False)
+                 gr.Markdown("## Selection")
+                 db_text_field = gr.Markdown('DB Model: **Anime DB**')
+                 canny_text_field = gr.Markdown('Motion: **woman1**')
+                 prompt = gr.Textbox(label='Prompt')
+                 run_button = gr.Button(label='Run')
+                 with gr.Accordion('Advanced options', open=False):
+                     watermark = gr.Radio(["Picsart AI Research", "Text2Video-Zero", "None"],
+                                          label="Watermark", value='Picsart AI Research')
+                     chunk_size = gr.Slider(
+                         label="Chunk size", minimum=2, maximum=16,
+                         value=12 if on_huggingspace else 8, step=1,
+                         visible=not on_huggingspace)
+             with gr.Column():
+                 result = gr.Image(label="Generated Video").style(height=400)
+
+         with gr.Row():
+             gallery_db = gr.Gallery(label="Db models", value=[
+                 ('__assets__/db_files/anime.jpg', "anime"),
+                 ('__assets__/db_files/arcane.jpg', "Arcane"),
+                 ('__assets__/db_files/gta.jpg', "GTA-5 (Man)"),
+                 ('__assets__/db_files/avatar.jpg', "Avatar DB"),
+             ]).style(grid=[4], height=50)
+         with gr.Row():
+             gallery_canny = gr.Gallery(label="Motions", value=[
+                 ('__assets__/db_files/woman1.gif', "woman1"),
+                 ('__assets__/db_files/woman2.gif', "woman2"),
+                 ('__assets__/db_files/man1.gif', "man1"),
+                 ('__assets__/db_files/woman3.gif', "woman3"),
+             ]).style(grid=[4], height=50)
+
+         db_selection = gr.Textbox(label="DB Model", visible=False)
+         canny_selection = gr.Textbox(
+             label="One of the above defined motions", visible=False)
+
+         gallery_db.select(load_db_model, None, db_selection)
+         gallery_canny.select(canny_select, None, canny_selection)
+
+         db_selection.change(on_db_selection_update, None, db_text_field)
+         canny_selection.change(on_canny_selection_update, None, canny_text_field)
+
+         inputs = [
+             db_selection,
+             canny_selection,
+             prompt,
+             chunk_size,
+             watermark,
+         ]
+
+         gr.Examples(examples=examples,
+                     inputs=inputs,
+                     outputs=result,
+                     fn=model.process_controlnet_canny_db,
+                     cache_examples=on_huggingspace,
+                     )
+
+         run_button.click(fn=model.process_controlnet_canny_db,
+                          inputs=inputs,
+                          outputs=result)
+     return demo
+
+
+ def on_db_selection_update(evt: gr.EventData):
+     return f"DB model: **{evt._data}**"
+
+
+ def on_canny_selection_update(evt: gr.EventData):
+     return f"Motion: **{evt._data}**"
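For orientation, here is a minimal sketch of how this tab could be mounted in a top-level Gradio app. The `app.py` entry point, the `Model` constructor arguments, and the tab label are assumptions for illustration and are not part of this commit:

    import torch
    import gradio as gr
    from model import Model          # provided elsewhere in this repo
    import app_canny_db

    # Hypothetical wiring; the constructor signature is assumed, not confirmed by this upload.
    model = Model(device='cuda', dtype=torch.float16)

    with gr.Blocks() as app:
        with gr.Tab('Canny + DreamBooth'):
            app_canny_db.create_demo(model)

    app.launch()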
app_text_to_video.py ADDED
@@ -0,0 +1,97 @@
+ import gradio as gr
+ from model import Model
+ import os
+ from hf_utils import get_model_list
+
+ on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"
+
+ examples = [
+     ["an astronaut waving the arm on the moon"],
+     ["a sloth surfing on a wakeboard"],
+     ["an astronaut walking on a street"],
+     ["a cute cat walking on grass"],
+     ["a horse is galloping on a street"],
+     ["an astronaut is skiing down the hill"],
+     ["a gorilla walking alone down the street"],
+     ["a gorilla dancing on times square"],
+     ["A panda dancing like crazy on Times Square"],
+ ]
+
+
+ def create_demo(model: Model):
+
+     with gr.Blocks() as demo:
+         with gr.Row():
+             gr.Markdown('## Text2Video-Zero: Video Generation')
+         with gr.Row():
+             gr.HTML(
+                 """
+                 <div style="text-align: left; auto;">
+                 <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
+                 Description: Simply input <b>any textual prompt</b> to generate videos right away and unleash your creativity and imagination! You can also select from the examples below. For performance reasons, our current preview release generates up to 16 frames; the frame count can be configured in the Advanced Options.
+                 </h2>
+                 </div>
+                 """)
+
+         with gr.Row():
+             with gr.Column():
+                 model_name = gr.Dropdown(
+                     label="Model",
+                     choices=get_model_list(),
+                     value="dreamlike-art/dreamlike-photoreal-2.0",
+                 )
+                 prompt = gr.Textbox(label='Prompt')
+                 run_button = gr.Button(label='Run')
+                 with gr.Accordion('Advanced options', open=False):
+                     watermark = gr.Radio(["Picsart AI Research", "Text2Video-Zero", "None"],
+                                          label="Watermark", value='Picsart AI Research')
+
+                     if on_huggingspace:
+                         video_length = gr.Slider(
+                             label="Video length", minimum=8, maximum=16, step=1)
+                     else:
+                         video_length = gr.Number(
+                             label="Video length", value=8, precision=0)
+                     chunk_size = gr.Slider(
+                         label="Chunk size", minimum=2, maximum=16,
+                         value=12 if on_huggingspace else 8, step=1,
+                         visible=not on_huggingspace)
+
+                     motion_field_strength_x = gr.Slider(
+                         label=r'Global Translation $\delta_{x}$',
+                         minimum=-20, maximum=20, value=12, step=1)
+                     motion_field_strength_y = gr.Slider(
+                         label=r'Global Translation $\delta_{y}$',
+                         minimum=-20, maximum=20, value=12, step=1)
+
+                     t0 = gr.Slider(label="Timestep t0", minimum=0,
+                                    maximum=49, value=44, step=1)
+                     t1 = gr.Slider(label="Timestep t1", minimum=0,
+                                    maximum=49, value=47, step=1)
+
+                     n_prompt = gr.Textbox(
+                         label="Optional Negative Prompt", value='')
+             with gr.Column():
+                 result = gr.Video(label="Generated Video")
+
+         inputs = [
+             prompt,
+             model_name,
+             motion_field_strength_x,
+             motion_field_strength_y,
+             t0,
+             t1,
+             n_prompt,
+             chunk_size,
+             video_length,
+             watermark,
+         ]
+
+         gr.Examples(examples=examples,
+                     inputs=inputs,
+                     outputs=result,
+                     fn=model.process_text2video,
+                     run_on_click=False,
+                     cache_examples=on_huggingspace,
+                     )
+
+         run_button.click(fn=model.process_text2video,
+                          inputs=inputs,
+                          outputs=result)
+     return demo
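The `inputs` list above fixes the argument order that Gradio passes to `model.process_text2video`. A hedged sketch of an equivalent direct call; the real signature lives in `model.py`, which is not part of this upload, so the positional order below is an assumption read off the `inputs` list:

    video_path = model.process_text2video(
        "a horse is galloping on a street",          # prompt
        "dreamlike-art/dreamlike-photoreal-2.0",     # model_name
        12,                                          # motion_field_strength_x
        12,                                          # motion_field_strength_y
        44,                                          # t0
        47,                                          # t1
        "",                                          # negative prompt
        8,                                           # chunk_size
        8,                                           # video_length
        "Picsart AI Research",                       # watermark
    )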
config.py ADDED
@@ -0,0 +1 @@
+ save_memory = False
gradio_utils.py ADDED
@@ -0,0 +1,98 @@
+ import os
+
+ # App Canny utils
+
+
+ def edge_path_to_video_path(edge_path):
+     video_path = edge_path
+
+     vid_name = edge_path.split("/")[-1]
+     if vid_name == "butterfly.mp4":
+         video_path = "__assets__/canny_videos_mp4_2fps/butterfly.mp4"
+     elif vid_name == "deer.mp4":
+         video_path = "__assets__/canny_videos_mp4_2fps/deer.mp4"
+     elif vid_name == "fox.mp4":
+         video_path = "__assets__/canny_videos_mp4_2fps/fox.mp4"
+     elif vid_name == "girl_dancing.mp4":
+         video_path = "__assets__/canny_videos_mp4_2fps/girl_dancing.mp4"
+     elif vid_name == "girl_turning.mp4":
+         video_path = "__assets__/canny_videos_mp4_2fps/girl_turning.mp4"
+     elif vid_name == "halloween.mp4":
+         video_path = "__assets__/canny_videos_mp4_2fps/halloween.mp4"
+     elif vid_name == "santa.mp4":
+         video_path = "__assets__/canny_videos_mp4_2fps/santa.mp4"
+
+     assert os.path.isfile(video_path)
+     return video_path
+
+
+ # App Pose utils
+ def motion_to_video_path(motion):
+     videos = [
+         "__assets__/poses_skeleton_gifs/dance1_corr.mp4",
+         "__assets__/poses_skeleton_gifs/dance2_corr.mp4",
+         "__assets__/poses_skeleton_gifs/dance3_corr.mp4",
+         "__assets__/poses_skeleton_gifs/dance4_corr.mp4",
+         "__assets__/poses_skeleton_gifs/dance5_corr.mp4"
+     ]
+     if len(motion.split(" ")) > 1 and motion.split(" ")[1].isnumeric():
+         id = int(motion.split(" ")[1]) - 1
+         return videos[id]
+     else:
+         return motion
+
+
+ # App Canny Dreambooth utils
+ def get_video_from_canny_selection(canny_selection):
+     if canny_selection == "woman1":
+         input_video_path = "__assets__/db_files_2fps/woman1.mp4"
+     elif canny_selection == "woman2":
+         input_video_path = "__assets__/db_files_2fps/woman2.mp4"
+     elif canny_selection == "man1":
+         input_video_path = "__assets__/db_files_2fps/man1.mp4"
+     elif canny_selection == "woman3":
+         input_video_path = "__assets__/db_files_2fps/woman3.mp4"
+     else:
+         input_video_path = canny_selection
+
+     assert os.path.isfile(input_video_path)
+     return input_video_path
+
+
+ def get_model_from_db_selection(db_selection):
+     if db_selection == "Anime DB":
+         input_video_path = 'PAIR/text2video-zero-controlnet-canny-anime'
+     elif db_selection == "Avatar DB":
+         input_video_path = 'PAIR/text2video-zero-controlnet-canny-avatar'
+     elif db_selection == "GTA-5 DB":
+         input_video_path = 'PAIR/text2video-zero-controlnet-canny-gta5'
+     elif db_selection == "Arcane DB":
+         input_video_path = 'PAIR/text2video-zero-controlnet-canny-arcane'
+     else:
+         input_video_path = db_selection
+
+     return input_video_path
+
+
+ def get_db_name_from_id(id):
+     db_names = ["Anime DB", "Arcane DB", "GTA-5 DB", "Avatar DB"]
+     return db_names[id]
+
+
+ def get_canny_name_from_id(id):
+     canny_names = ["woman1", "woman2", "man1", "woman3"]
+     return canny_names[id]
+
+
+ def logo_name_to_path(name):
+     logo_paths = {
+         'Picsart AI Research': '__assets__/pair_watermark.png',
+         'Text2Video-Zero': '__assets__/t2v-z_watermark.png',
+         'None': None
+     }
+     if name in logo_paths:
+         return logo_paths[name]
+     return name
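Taken together, these helpers map a gallery click in app_canny_db.py to the assets the model needs. An illustrative chain (index 2 is just an example value):

    db_name = get_db_name_from_id(2)                         # "GTA-5 DB"
    canny_name = get_canny_name_from_id(2)                   # "man1"
    db_model = get_model_from_db_selection(db_name)          # "PAIR/text2video-zero-controlnet-canny-gta5"
    video_path = get_video_from_canny_selection(canny_name)  # "__assets__/db_files_2fps/man1.mp4"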
hf_utils.py ADDED
@@ -0,0 +1,39 @@
+ from bs4 import BeautifulSoup
+ import requests
+
+
+ def model_url_list():
+     url_list = []
+     for i in range(0, 5):
+         url_list.append(
+             f"https://huggingface.co/models?p={i}&sort=downloads&search=dreambooth")
+     return url_list
+
+
+ def data_scraping(url_list):
+     model_list = []
+     for url in url_list:
+         response = requests.get(url)
+         soup = BeautifulSoup(response.text, "html.parser")
+         div_class = 'grid grid-cols-1 gap-5 2xl:grid-cols-2'
+         div = soup.find('div', {'class': div_class})
+         if div is None:  # page markup changed or request failed; skip this page
+             continue
+         for a in div.find_all('a', href=True):
+             model_list.append(a['href'])
+     return model_list
+
+
+ def get_model_list():
+     model_list = data_scraping(model_url_list())
+     # hrefs come back as "/owner/model"; strip the leading slash
+     for i in range(len(model_list)):
+         model_list[i] = model_list[i][1:]
+
+     best_model_list = [
+         "dreamlike-art/dreamlike-photoreal-2.0",
+         "dreamlike-art/dreamlike-diffusion-1.0",
+         "runwayml/stable-diffusion-v1-5",
+         "CompVis/stable-diffusion-v1-4",
+         "prompthero/openjourney",
+     ]
+
+     model_list = best_model_list + model_list
+     return model_list
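Note that `get_model_list` scrapes the Hugging Face model search pages and depends on one specific CSS class, so it degrades to returning only `best_model_list` if that markup changes. A possible non-scraping alternative (not used by this commit, shown only as a sketch; attribute names may differ across `huggingface_hub` versions):

    from huggingface_hub import HfApi

    def get_model_list_via_api(limit=100):
        # Hypothetical helper: query the Hub API instead of scraping HTML.
        api = HfApi()
        models = api.list_models(search="dreambooth", sort="downloads", direction=-1, limit=limit)
        return [m.modelId for m in models]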
style (1).css ADDED
@@ -0,0 +1,3 @@
+ h1 {
+     text-align: center;
+ }
utils (1).py ADDED
@@ -0,0 +1,207 @@
+ import os
+
+ import PIL.Image
+ import numpy as np
+ import torch
+ import torchvision
+ from torchvision.transforms import Resize, InterpolationMode
+ import imageio
+ from einops import rearrange
+ import cv2
+ from PIL import Image
+ from annotator.util import resize_image, HWC3
+ from annotator.canny import CannyDetector
+ from annotator.openpose import OpenposeDetector
+ import decord
+ # decord.bridge.set_bridge('torch')
+
+ apply_canny = CannyDetector()
+ apply_openpose = OpenposeDetector()
+
+
+ def add_watermark(image, watermark_path, wm_rel_size=1/16, boundary=5):
+     '''
+     Creates a watermark on the saved inference image.
+     We request that you do not remove this to properly assign credit to
+     Shi-Lab's work.
+     '''
+     watermark = Image.open(watermark_path)
+     w_0, h_0 = watermark.size
+     H, W, _ = image.shape
+     wmsize = int(max(H, W) * wm_rel_size)
+     aspect = h_0 / w_0
+     if aspect > 1.0:
+         watermark = watermark.resize((wmsize, int(aspect * wmsize)), Image.LANCZOS)
+     else:
+         watermark = watermark.resize((int(wmsize / aspect), wmsize), Image.LANCZOS)
+     w, h = watermark.size
+     loc_h = H - h - boundary
+     loc_w = W - w - boundary
+     image = Image.fromarray(image)
+     mask = watermark if watermark.mode in ('RGBA', 'LA') else None
+     image.paste(watermark, (loc_w, loc_h), mask)
+     return image
+
+
+ def pre_process_canny(input_video, low_threshold=100, high_threshold=200):
+     detected_maps = []
+     for frame in input_video:
+         img = rearrange(frame, 'c h w -> h w c').cpu().numpy().astype(np.uint8)
+         detected_map = apply_canny(img, low_threshold, high_threshold)
+         detected_map = HWC3(detected_map)
+         detected_maps.append(detected_map[None])
+     detected_maps = np.concatenate(detected_maps)
+     control = torch.from_numpy(detected_maps.copy()).float() / 255.0
+     return rearrange(control, 'f h w c -> f c h w')
+
+
+ def pre_process_pose(input_video, apply_pose_detect: bool = True):
+     detected_maps = []
+     for frame in input_video:
+         img = rearrange(frame, 'c h w -> h w c').cpu().numpy().astype(np.uint8)
+         img = HWC3(img)
+         if apply_pose_detect:
+             detected_map, _ = apply_openpose(img)
+         else:
+             detected_map = img
+         detected_map = HWC3(detected_map)
+         H, W, C = img.shape
+         detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
+         detected_maps.append(detected_map[None])
+     detected_maps = np.concatenate(detected_maps)
+     control = torch.from_numpy(detected_maps.copy()).float() / 255.0
+     return rearrange(control, 'f h w c -> f c h w')
+
+
+ def create_video(frames, fps, rescale=False, path=None, watermark=None):
+     if path is None:
+         dir = "temporal"
+         os.makedirs(dir, exist_ok=True)
+         path = os.path.join(dir, 'movie.mp4')
+
+     outputs = []
+     for i, x in enumerate(frames):
+         x = torchvision.utils.make_grid(torch.Tensor(x), nrow=4)
+         if rescale:
+             x = (x + 1.0) / 2.0  # -1,1 -> 0,1
+         x = (x * 255).numpy().astype(np.uint8)
+
+         if watermark is not None:
+             x = add_watermark(x, watermark)
+         outputs.append(x)
+         # imageio.imsave(os.path.join(dir, os.path.splitext(name)[0] + f'_{i}.jpg'), x)
+
+     imageio.mimsave(path, outputs, fps=fps)
+     return path
+
+
+ def create_gif(frames, fps, rescale=False, path=None, watermark=None):
+     if path is None:
+         dir = "temporal"
+         os.makedirs(dir, exist_ok=True)
+         path = os.path.join(dir, 'canny_db.gif')
+
+     outputs = []
+     for i, x in enumerate(frames):
+         x = torchvision.utils.make_grid(torch.Tensor(x), nrow=4)
+         if rescale:
+             x = (x + 1.0) / 2.0  # -1,1 -> 0,1
+         x = (x * 255).numpy().astype(np.uint8)
+         if watermark is not None:
+             x = add_watermark(x, watermark)
+         outputs.append(x)
+         # imageio.imsave(os.path.join(dir, os.path.splitext(name)[0] + f'_{i}.jpg'), x)
+
+     imageio.mimsave(path, outputs, fps=fps)
+     return path
+
+
+ def prepare_video(video_path: str, resolution: int, device, dtype, normalize=True, start_t: float = 0, end_t: float = -1, output_fps: int = -1):
+     vr = decord.VideoReader(video_path)
+     initial_fps = vr.get_avg_fps()
+     if output_fps == -1:
+         output_fps = int(initial_fps)
+     if end_t == -1:
+         end_t = len(vr) / initial_fps
+     else:
+         end_t = min(len(vr) / initial_fps, end_t)
+     assert 0 <= start_t < end_t
+     assert output_fps > 0
+     start_f_ind = int(start_t * initial_fps)
+     end_f_ind = int(end_t * initial_fps)
+     num_f = int((end_t - start_t) * output_fps)
+     sample_idx = np.linspace(start_f_ind, end_f_ind, num_f, endpoint=False).astype(int)
+     video = vr.get_batch(sample_idx)
+     if torch.is_tensor(video):
+         video = video.detach().cpu().numpy()
+     else:
+         video = video.asnumpy()
+     _, h, w, _ = video.shape
+     video = rearrange(video, "f h w c -> f c h w")
+     video = torch.Tensor(video).to(device).to(dtype)
+     # Resize the short side to `resolution`, keeping both sides multiples of 8
+     if h > w:
+         w = int(w * resolution / h)
+         w = w - w % 8
+         h = resolution - resolution % 8
+     else:
+         h = int(h * resolution / w)
+         h = h - h % 8
+         w = resolution - resolution % 8
+     video = Resize((h, w), interpolation=InterpolationMode.BILINEAR, antialias=True)(video)
+     if normalize:
+         video = video / 127.5 - 1.0
+     return video, output_fps
+
+
+ def post_process_gif(list_of_results, image_resolution):
+     output_file = "/tmp/ddxk.gif"
+     imageio.mimsave(output_file, list_of_results, fps=4)
+     return output_file
+
+
+ class CrossFrameAttnProcessor:
+     def __init__(self, unet_chunk_size=2):
+         self.unet_chunk_size = unet_chunk_size
+
+     def __call__(
+             self,
+             attn,
+             hidden_states,
+             encoder_hidden_states=None,
+             attention_mask=None):
+         batch_size, sequence_length, _ = hidden_states.shape
+         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+         query = attn.to_q(hidden_states)
+
+         is_cross_attention = encoder_hidden_states is not None
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         elif attn.cross_attention_norm:
+             encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+         # Sparse (cross-frame) attention: in self-attention layers, every frame
+         # attends to the keys/values of the first frame
+         if not is_cross_attention:
+             video_length = key.size()[0] // self.unet_chunk_size
+             # former_frame_index = torch.arange(video_length) - 1
+             # former_frame_index[0] = 0
+             former_frame_index = [0] * video_length
+             key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
+             key = key[:, former_frame_index]
+             key = rearrange(key, "b f d c -> (b f) d c")
+             value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
+             value = value[:, former_frame_index]
+             value = rearrange(value, "b f d c -> (b f) d c")
+
+         query = attn.head_to_batch_dim(query)
+         key = attn.head_to_batch_dim(key)
+         value = attn.head_to_batch_dim(value)
+
+         attention_probs = attn.get_attention_scores(query, key, attention_mask)
+         hidden_states = torch.bmm(attention_probs, value)
+         hidden_states = attn.batch_to_head_dim(hidden_states)
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         return hidden_states
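`CrossFrameAttnProcessor` follows the diffusers attention-processor calling convention: it replaces each frame's self-attention keys and values with those of the first frame so that generated frames stay visually consistent. The actual wiring lives in `model.py`, which is not part of this upload; a hedged sketch of how such a processor is typically attached to a UNet:

    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    # unet_chunk_size=2 assumes classifier-free guidance, which doubles the batch,
    # so the (b f) dimension splits into video_length = batch_size / 2 frames per prompt.
    pipe.unet.set_attn_processor(CrossFrameAttnProcessor(unet_chunk_size=2))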