lev1 commited on
Commit
714bf26
β€’
1 Parent(s): ccfd1d5

T2V, Video Pix2Pix and Pose-Guided Gen

Browse files
Files changed (12) hide show
  1. README.md +5 -7
  2. app.py +73 -0
  3. app_pix2pix_video.py +70 -0
  4. app_pose.py +62 -0
  5. app_text_to_video.py +44 -0
  6. config.py +1 -0
  7. gradio_utils.py +77 -0
  8. model.py +296 -0
  9. requirements.txt +34 -0
  10. share.py +8 -0
  11. style.css +3 -0
  12. utils.py +187 -0
README.md CHANGED
@@ -1,12 +1,10 @@
1
  ---
2
- title: Text2Video Zero
3
- emoji: πŸ’»
4
- colorFrom: blue
5
- colorTo: pink
6
  sdk: gradio
7
  sdk_version: 3.23.0
8
  app_file: app.py
9
  pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Text2Video-Zero
3
+ emoji: πŸš€
4
+ colorFrom: green
5
+ colorTo: blue
6
  sdk: gradio
7
  sdk_version: 3.23.0
8
  app_file: app.py
9
  pinned: false
10
+ ---
 
 
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+
4
+ from model import Model, ModelType
5
+
6
+ # from app_canny import create_demo as create_demo_canny
7
+ from app_pose import create_demo as create_demo_pose
8
+ from app_text_to_video import create_demo as create_demo_text_to_video
9
+ from app_pix2pix_video import create_demo as create_demo_pix2pix_video
10
+ # from app_canny_db import create_demo as create_demo_canny_db
11
+
12
+
13
+ model = Model(device='cuda', dtype=torch.float16)
14
+
15
+ with gr.Blocks(css='style.css') as demo:
16
+ gr.HTML(
17
+ """
18
+ <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
19
+ <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
20
+ Text2Video-Zero
21
+ </h1>
22
+ <h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
23
+ We propose <b>Text2Video-Zero, the first zero-shot text-to-video syntenes framework</b>, that also natively supports, Video Instruct Pix2Pix, Pose Conditional, Edge Conditional
24
+ and, Edge Conditional and DreamBooth Specialized applications.
25
+ </h2>
26
+ <h3 style="font-weight: 450; font-size: 1rem; margin: 0rem">
27
+ Levon Khachatryan, Andranik Movsisyan, Vahram Tadevosyan, Roberto Henschel, Atlas Wang, Shant Navasardyan
28
+ and <a href="https://www.humphreyshi.com/home">Humphrey Shi</a>
29
+ [<a href="" style="color:blue;">arXiv</a>]
30
+ [<a href="" style="color:blue;">GitHub</a>]
31
+ </h3>
32
+ </div>
33
+ """)
34
+
35
+ with gr.Tab('Zero-Shot Text2Video'):
36
+ # pass
37
+ create_demo_text_to_video(model)
38
+ with gr.Tab('Video Instruct Pix2Pix'):
39
+ # pass
40
+ create_demo_pix2pix_video(model)
41
+ with gr.Tab('Pose Conditional'):
42
+ # pass
43
+ create_demo_pose(model)
44
+ with gr.Tab('Edge Conditional'):
45
+ pass
46
+ # create_demo_canny(model)
47
+ with gr.Tab('Edge Conditional and Dreambooth Specialized'):
48
+ pass
49
+ # create_demo_canny_db(model)
50
+
51
+ gr.HTML(
52
+ """
53
+ <div style="text-align: justify; max-width: 1200px; margin: 20px auto;">
54
+ <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
55
+ <b>Version: v1.0</b>
56
+ </h3>
57
+ <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
58
+ <b>Caution</b>:
59
+ We would like the raise the awareness of users of this demo of its potential issues and concerns.
60
+ Like previous large foundation models, Text2Video-Zero could be problematic in some cases, partially we use pretrained Stable Diffusion, therefore Text2Video-Zero can Inherit Its Imperfections.
61
+ So far, we keep all features available for research testing both to show the great potential of the Text2Video-Zero framework and to collect important feedback to improve the model in the future.
62
+ We welcome researchers and users to report issues with the HuggingFace community discussion feature or email the authors.
63
+ </h3>
64
+ <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
65
+ <b>Biases and content acknowledgement</b>:
66
+ Beware that Text2Video-Zero may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography, and violence.
67
+ Text2Video-Zero in this demo is meant only for research purposes.
68
+ </h3>
69
+ </div>
70
+ """)
71
+
72
+ demo.launch(debug=True)
73
+ # demo.queue(api_open=False).launch(file_directories=['temporal'], share=True)
app_pix2pix_video.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from model import Model
3
+
4
+
5
+ def create_demo(model: Model):
6
+ examples = [
7
+ ['__assets__/pix2pix video/camel.mp4', 'make it Van Gogh Starry Night style'],
8
+ ['__assets__/pix2pix video/mini-cooper.mp4', 'make it Picasso style'],
9
+ ['__assets__/pix2pix video/snowboard.mp4', 'replace man with robot'],
10
+ ['__assets__/pix2pix video/white-swan.mp4', 'replace swan with mallard'],
11
+ ]
12
+ with gr.Blocks() as demo:
13
+ with gr.Row():
14
+ gr.Markdown('## Video Instruct Pix2Pix')
15
+ with gr.Row():
16
+ with gr.Column():
17
+ input_image = gr.Video(label="Input Video",source='upload', type='numpy', format="mp4", visible=True).style(height="auto")
18
+ with gr.Column():
19
+ prompt = gr.Textbox(label='Prompt')
20
+ run_button = gr.Button(label='Run')
21
+ with gr.Accordion('Advanced options', open=False):
22
+ image_resolution = gr.Slider(label='Image Resolution',
23
+ minimum=256,
24
+ maximum=1024,
25
+ value=512,
26
+ step=64)
27
+ seed = gr.Slider(label='Seed',
28
+ minimum=0,
29
+ maximum=65536,
30
+ value=0,
31
+ step=1)
32
+ start_t = gr.Slider(label='Starting time in seconds',
33
+ minimum=0,
34
+ maximum=10,
35
+ value=0,
36
+ step=1)
37
+ end_t = gr.Slider(label='End time in seconds (-1 corresponds to uploaded video duration)',
38
+ minimum=0,
39
+ maximum=10,
40
+ value=-1,
41
+ step=1)
42
+ out_fps = gr.Slider(label='Output video fps (-1 corresponds to uploaded video fps)',
43
+ minimum=1,
44
+ maximum=30,
45
+ value=-1,
46
+ step=1)
47
+ with gr.Column():
48
+ result = gr.Video(label='Output',
49
+ show_label=True)
50
+ inputs = [
51
+ input_image,
52
+ prompt,
53
+ image_resolution,
54
+ seed,
55
+ start_t,
56
+ end_t,
57
+ out_fps
58
+ ]
59
+
60
+ gr.Examples(examples=examples,
61
+ inputs=inputs,
62
+ outputs=result,
63
+ # cache_examples=os.getenv('SYSTEM') == 'spaces',
64
+ run_on_click=False,
65
+ )
66
+
67
+ run_button.click(fn=model.process_pix2pix,
68
+ inputs=inputs,
69
+ outputs=result)
70
+ return demo
app_pose.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+
4
+ from model import Model
5
+
6
+ examples = [
7
+ ['Motion 1', "A Robot is dancing in Sahara desert"],
8
+ ['Motion 2', "A Robot is dancing in Sahara desert"],
9
+ ['Motion 3', "A Robot is dancing in Sahara desert"],
10
+ ['Motion 4', "A Robot is dancing in Sahara desert"],
11
+ ['Motion 5', "A Robot is dancing in Sahara desert"],
12
+ ]
13
+
14
+ def create_demo(model: Model):
15
+ with gr.Blocks() as demo:
16
+ with gr.Row():
17
+ gr.Markdown('## Text and Pose Conditional Video Generation')
18
+
19
+ with gr.Row():
20
+ gr.Markdown('### You must select one pose sequence shown below, or use the examples')
21
+ with gr.Column():
22
+ gallery_pose_sequence = gr.Gallery(label="Pose Sequence", value=[('__assets__/poses_skeleton_gifs/dance1.gif', "Motion 1"), ('__assets__/poses_skeleton_gifs/dance2.gif', "Motion 2"), ('__assets__/poses_skeleton_gifs/dance3.gif', "Motion 3"), ('__assets__/poses_skeleton_gifs/dance4.gif', "Motion 4"), ('__assets__/poses_skeleton_gifs/dance5.gif', "Motion 5")]).style(grid=[2], height="auto")
23
+ input_video_path = gr.Textbox(label="Pose Sequence",visible=False,value="Motion 1")
24
+ gr.Markdown("## Selection")
25
+ pose_sequence_selector = gr.Markdown('Pose Sequence: **Motion 1**')
26
+ with gr.Column():
27
+ prompt = gr.Textbox(label='Prompt')
28
+ run_button = gr.Button(label='Run')
29
+ with gr.Column():
30
+ result = gr.Image(label="Generated Video")
31
+
32
+ input_video_path.change(on_video_path_update, None, pose_sequence_selector)
33
+ gallery_pose_sequence.select(pose_gallery_callback, None, input_video_path)
34
+ inputs = [
35
+ input_video_path,
36
+ #pose_sequence,
37
+ prompt,
38
+ ]
39
+
40
+ gr.Examples(examples=examples,
41
+ inputs=inputs,
42
+ outputs=result,
43
+ # cache_examples=os.getenv('SYSTEM') == 'spaces',
44
+ fn=model.process_controlnet_pose,
45
+ run_on_click=False,
46
+ )
47
+ #fn=process,
48
+ #)
49
+
50
+
51
+ run_button.click(fn=model.process_controlnet_pose,
52
+ inputs=inputs,
53
+ outputs=result,)
54
+
55
+ return demo
56
+
57
+
58
+ def on_video_path_update(evt: gr.EventData):
59
+ return f'Pose Sequence: **{evt._data}**'
60
+
61
+ def pose_gallery_callback(evt: gr.SelectData):
62
+ return f"Motion {evt.index+1}"
app_text_to_video.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from model import Model
3
+
4
+ examples = [
5
+ "an astronaut waving the arm on the moon",
6
+ "a sloth surfing on a wakeboard",
7
+ "an astronaut walking on a street",
8
+ "a cute cat walking on grass",
9
+ "a horse is galloping on a street",
10
+ "an astronaut is skiing down the hill",
11
+ "a gorilla walking alone down the street"
12
+ "a gorilla dancing on times square",
13
+ "A panda dancing dancing like crazy on Times Square",
14
+ ]
15
+
16
+
17
+ def create_demo(model: Model):
18
+
19
+ with gr.Blocks() as demo:
20
+ with gr.Row():
21
+ gr.Markdown('## Text2Video-Zero: Video Generation')
22
+
23
+ with gr.Row():
24
+ with gr.Column():
25
+ prompt = gr.Textbox(label='Prompt')
26
+ run_button = gr.Button(label='Run')
27
+ with gr.Column():
28
+ result = gr.Video(label="Generated Video")
29
+ inputs = [
30
+ prompt,
31
+ ]
32
+
33
+ gr.Examples(examples=examples,
34
+ inputs=inputs,
35
+ outputs=result,
36
+ cache_examples=False,
37
+ #cache_examples=os.getenv('SYSTEM') == 'spaces')
38
+ run_on_click=False,
39
+ )
40
+
41
+ run_button.click(fn=model.process_text2video,
42
+ inputs=inputs,
43
+ outputs=result,)
44
+ return demo
config.py ADDED
@@ -0,0 +1 @@
 
 
1
+ save_memory = False
gradio_utils.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # App Canny utils
2
+ def edge_path_to_video_path(edge_path):
3
+ video_path = edge_path
4
+
5
+ vid_name = edge_path.split("/")[-1]
6
+ if vid_name == "butterfly.mp4":
7
+ video_path = "__assets__/canny_videos_mp4/butterfly.mp4"
8
+ elif vid_name == "deer.mp4":
9
+ video_path = "__assets__/canny_videos_mp4/deer.mp4"
10
+ elif vid_name == "fox.mp4":
11
+ video_path = "__assets__/canny_videos_mp4/fox.mp4"
12
+ elif vid_name == "girl_dancing.mp4":
13
+ video_path = "__assets__/canny_videos_mp4/girl_dancing.mp4"
14
+ elif vid_name == "girl_turning.mp4":
15
+ video_path = "__assets__/canny_videos_mp4/girl_turning.mp4"
16
+ elif vid_name == "halloween.mp4":
17
+ video_path = "__assets__/canny_videos_mp4/halloween.mp4"
18
+ elif vid_name == "santa.mp4":
19
+ video_path = "__assets__/canny_videos_mp4/santa.mp4"
20
+ return video_path
21
+
22
+
23
+ # App Pose utils
24
+ def motion_to_video_path(motion):
25
+ videos = [
26
+ "__assets__/poses_skeleton_gifs/dance1_corr.mp4",
27
+ "__assets__/poses_skeleton_gifs/dance2_corr.mp4",
28
+ "__assets__/poses_skeleton_gifs/dance3_corr.mp4",
29
+ "__assets__/poses_skeleton_gifs/dance4_corr.mp4",
30
+ "__assets__/poses_skeleton_gifs/dance5_corr.mp4"
31
+ ]
32
+ id = int(motion.split(" ")[1]) - 1
33
+ return videos[id]
34
+
35
+
36
+ # App Canny Dreambooth utils
37
+ def get_video_from_canny_selection(canny_selection):
38
+ if canny_selection == "woman1":
39
+ input_video_path = "__assets__/db_files/woman1.mp4"
40
+
41
+ elif canny_selection == "woman2":
42
+ input_video_path = "__assets__/db_files/woman2.mp4"
43
+
44
+ elif canny_selection == "man1":
45
+ input_video_path = "__assets__/db_files/man1.mp4"
46
+
47
+ elif canny_selection == "woman3":
48
+ input_video_path = "__assets__/db_files/woman3.mp4"
49
+ else:
50
+ raise Exception
51
+
52
+ return input_video_path
53
+
54
+
55
+ def get_model_from_db_selection(db_selection):
56
+ if db_selection == "Anime DB":
57
+ input_video_path = 'PAIR/controlnet-canny-anime'
58
+ elif db_selection == "Avatar DB":
59
+ input_video_path = 'PAIR/controlnet-canny-avatar'
60
+ elif db_selection == "GTA-5 DB":
61
+ input_video_path = 'PAIR/controlnet-canny-gta5'
62
+ elif db_selection == "Arcane DB":
63
+ input_video_path = 'PAIR/controlnet-canny-arcane'
64
+ else:
65
+ raise Exception
66
+ return input_video_path
67
+
68
+
69
+ def get_db_name_from_id(id):
70
+ db_names = ["Anime DB", "Arcane DB", "GTA-5 DB", "Avatar DB"]
71
+ return db_names[id]
72
+
73
+
74
+ def get_canny_name_from_id(id):
75
+ canny_names = ["woman1", "woman2", "man1", "woman3"]
76
+ return canny_names[id]
77
+
model.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ import gc
3
+ import numpy as np
4
+
5
+ import torch
6
+ import decord
7
+ from diffusers import StableDiffusionInstructPix2PixPipeline, StableDiffusionControlNetPipeline, ControlNetModel, UNet2DConditionModel
8
+ from diffusers.schedulers import EulerAncestralDiscreteScheduler, DDIMScheduler
9
+ from text_to_video.text_to_video_pipeline import TextToVideoPipeline
10
+
11
+ import utils
12
+ import gradio_utils
13
+
14
+ decord.bridge.set_bridge('torch')
15
+
16
+
17
+ class ModelType(Enum):
18
+ Pix2Pix_Video = 1,
19
+ Text2Video = 2,
20
+ ControlNetCanny = 3,
21
+ ControlNetCannyDB = 4,
22
+ ControlNetPose = 5,
23
+
24
+
25
+ class Model:
26
+ def __init__(self, device, dtype, **kwargs):
27
+ self.device = device
28
+ self.dtype = dtype
29
+ self.generator = torch.Generator(device=device)
30
+ self.pipe_dict = {
31
+ ModelType.Pix2Pix_Video: StableDiffusionInstructPix2PixPipeline,
32
+ ModelType.Text2Video: TextToVideoPipeline,
33
+ ModelType.ControlNetCanny: StableDiffusionControlNetPipeline,
34
+ ModelType.ControlNetCannyDB: StableDiffusionControlNetPipeline,
35
+ ModelType.ControlNetPose: StableDiffusionControlNetPipeline,
36
+ }
37
+ self.controlnet_attn_proc = utils.CrossFrameAttnProcessor(unet_chunk_size=2)
38
+ self.pix2pix_attn_proc = utils.CrossFrameAttnProcessor(unet_chunk_size=3)
39
+ self.text2video_attn_proc = utils.CrossFrameAttnProcessor(unet_chunk_size=2)
40
+
41
+ self.pipe = None
42
+ self.model_type = None
43
+
44
+ self.states = {}
45
+
46
+ def set_model(self, model_type: ModelType, model_id: str, **kwargs):
47
+ if self.pipe is not None:
48
+ del self.pipe
49
+ torch.cuda.empty_cache()
50
+ gc.collect()
51
+ safety_checker = kwargs.pop('safety_checker', None)
52
+ self.pipe = self.pipe_dict[model_type].from_pretrained(model_id, safety_checker=safety_checker, **kwargs).to(self.device).to(self.dtype)
53
+ self.model_type = model_type
54
+
55
+ def inference_chunk(self, frame_ids, **kwargs):
56
+ if self.pipe is None:
57
+ return
58
+ image = kwargs.pop('image')
59
+ prompt = np.array(kwargs.pop('prompt'))
60
+ negative_prompt = np.array(kwargs.pop('negative_prompt', ''))
61
+ latents = None
62
+ if 'latents' in kwargs:
63
+ latents = kwargs.pop('latents')[frame_ids]
64
+ return self.pipe(image=image[frame_ids],
65
+ prompt=prompt[frame_ids].tolist(),
66
+ negative_prompt=negative_prompt[frame_ids].tolist(),
67
+ latents=latents,
68
+ generator=self.generator,
69
+ **kwargs)
70
+
71
+ def inference(self, split_to_chunks=False, chunk_size=8, **kwargs):
72
+ if self.pipe is None:
73
+ return
74
+ seed = kwargs.pop('seed', 0)
75
+ kwargs.pop('generator', '')
76
+ # self.generator.manual_seed(seed)
77
+ if split_to_chunks:
78
+ assert 'image' in kwargs
79
+ assert 'prompt' in kwargs
80
+ image = kwargs.pop('image')
81
+ prompt = kwargs.pop('prompt')
82
+ negative_prompt = kwargs.pop('negative_prompt', '')
83
+ f = image.shape[0]
84
+ chunk_ids = np.arange(0, f, chunk_size - 1)
85
+ result = []
86
+ for i in range(len(chunk_ids)):
87
+ ch_start = chunk_ids[i]
88
+ ch_end = f if i == len(chunk_ids) - 1 else chunk_ids[i + 1]
89
+ frame_ids = [0] + list(range(ch_start, ch_end))
90
+ self.generator.manual_seed(seed)
91
+ print(f'Processing chunk {i + 1} / {len(chunk_ids)}')
92
+ result.append(self.inference_chunk(frame_ids=frame_ids,
93
+ image=image,
94
+ prompt=[prompt] * f,
95
+ negative_prompt=[negative_prompt] * f,
96
+ **kwargs).images[1:])
97
+ result = np.concatenate(result)
98
+ return result
99
+ else:
100
+ return self.pipe(generator=self.generator, **kwargs).videos[0]
101
+
102
+ def process_controlnet_canny(self,
103
+ video_path,
104
+ prompt,
105
+ num_inference_steps=20,
106
+ controlnet_conditioning_scale=1.0,
107
+ guidance_scale=9.0,
108
+ seed=42,
109
+ eta=0.0,
110
+ low_threshold=100,
111
+ high_threshold=200,
112
+ resolution=512):
113
+ video_path = gradio_utils.edge_path_to_video_path(video_path)
114
+ if self.model_type != ModelType.ControlNetCanny:
115
+ controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
116
+ self.set_model(ModelType.ControlNetCanny, model_id="runwayml/stable-diffusion-v1-5", controlnet=controlnet)
117
+ self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
118
+ self.pipe.unet.set_attn_processor(processor=self.controlnet_attn_proc)
119
+ self.pipe.controlnet.set_attn_processor(processor=self.controlnet_attn_proc)
120
+
121
+ # TODO: Check scheduler
122
+ added_prompt = 'best quality, extremely detailed'
123
+ negative_prompts = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
124
+
125
+ video, fps = utils.prepare_video(video_path, resolution, self.device, self.dtype, False)
126
+ control = utils.pre_process_canny(video, low_threshold, high_threshold).to(self.device).to(self.dtype)
127
+ f, _, h, w = video.shape
128
+ self.generator.manual_seed(seed)
129
+ latents = torch.randn((1, 4, h//8, w//8), dtype=self.dtype, device=self.device, generator=self.generator)
130
+ latents = latents.repeat(f, 1, 1, 1)
131
+ result = self.inference(image=control,
132
+ prompt=prompt + ', ' + added_prompt,
133
+ height=h,
134
+ width=w,
135
+ negative_prompt=negative_prompts,
136
+ num_inference_steps=num_inference_steps,
137
+ guidance_scale=guidance_scale,
138
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
139
+ eta=eta,
140
+ latents=latents,
141
+ seed=seed,
142
+ output_type='numpy',
143
+ split_to_chunks=True,
144
+ chunk_size=8,
145
+ )
146
+ return utils.create_video(result, fps)
147
+
148
+ def process_controlnet_pose(self,
149
+ video_path,
150
+ prompt,
151
+ num_inference_steps=20,
152
+ controlnet_conditioning_scale=1.0,
153
+ guidance_scale=9.0,
154
+ seed=42,
155
+ eta=0.0,
156
+ resolution=512):
157
+ video_path = gradio_utils.motion_to_video_path(video_path)
158
+ if self.model_type != ModelType.ControlNetPose:
159
+ controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-openpose")
160
+ self.set_model(ModelType.ControlNetPose, model_id="runwayml/stable-diffusion-v1-5", controlnet=controlnet)
161
+ self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
162
+ self.pipe.unet.set_attn_processor(processor=self.controlnet_attn_proc)
163
+ self.pipe.controlnet.set_attn_processor(processor=self.controlnet_attn_proc)
164
+
165
+ added_prompt = 'best quality, extremely detailed, HD, ultra-realistic, 8K, HQ, masterpiece, trending on artstation, art, smooth'
166
+ negative_prompts = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer difits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic'
167
+
168
+ video, fps = utils.prepare_video(video_path, resolution, self.device, self.dtype, False, output_fps=4)
169
+ control = utils.pre_process_pose(video, apply_pose_detect=False).to(self.device).to(self.dtype)
170
+ f, _, h, w = video.shape
171
+ self.generator.manual_seed(seed)
172
+ latents = torch.randn((1, 4, h//8, w//8), dtype=self.dtype, device=self.device, generator=self.generator)
173
+ latents = latents.repeat(f, 1, 1, 1)
174
+ result = self.inference(image=control,
175
+ prompt=prompt + ', ' + added_prompt,
176
+ height=h,
177
+ width=w,
178
+ negative_prompt=negative_prompts,
179
+ num_inference_steps=num_inference_steps,
180
+ guidance_scale=guidance_scale,
181
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
182
+ eta=eta,
183
+ latents=latents,
184
+ seed=seed,
185
+ output_type='numpy',
186
+ split_to_chunks=True,
187
+ chunk_size=8,
188
+ )
189
+ return utils.create_gif(result, fps)
190
+
191
+ def process_controlnet_canny_db(self,
192
+ db_path,
193
+ video_path,
194
+ prompt,
195
+ num_inference_steps=20,
196
+ controlnet_conditioning_scale=1.0,
197
+ guidance_scale=9.0,
198
+ seed=42,
199
+ eta=0.0,
200
+ low_threshold=100,
201
+ high_threshold=200,
202
+ resolution=512):
203
+ db_path = gradio_utils.get_model_from_db_selection(db_path)
204
+ video_path = gradio_utils.get_video_from_canny_selection(video_path)
205
+ # Load db and controlnet weights
206
+ if 'db_path' not in self.states or db_path != self.states['db_path']:
207
+ controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
208
+ self.set_model(ModelType.ControlNetCannyDB, model_id=db_path, controlnet=controlnet)
209
+ self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
210
+ self.pipe.unet.set_attn_processor(processor=self.controlnet_attn_proc)
211
+ self.pipe.controlnet.set_attn_processor(processor=self.controlnet_attn_proc)
212
+ self.states['db_path'] = db_path
213
+
214
+ added_prompt = 'best quality, extremely detailed'
215
+ negative_prompts = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
216
+
217
+ video, fps = utils.prepare_video(video_path, resolution, self.device, self.dtype, False)
218
+ control = utils.pre_process_canny(video, low_threshold, high_threshold).to(self.device).to(self.dtype)
219
+ f, _, h, w = video.shape
220
+ self.generator.manual_seed(seed)
221
+ latents = torch.randn((1, 4, h//8, w//8), dtype=self.dtype, device=self.device, generator=self.generator)
222
+ latents = latents.repeat(f, 1, 1, 1)
223
+ result = self.inference(image=control,
224
+ prompt=prompt + ', ' + added_prompt,
225
+ height=h,
226
+ width=w,
227
+ negative_prompt=negative_prompts,
228
+ num_inference_steps=num_inference_steps,
229
+ guidance_scale=guidance_scale,
230
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
231
+ eta=eta,
232
+ latents=latents,
233
+ seed=seed,
234
+ output_type='numpy',
235
+ split_to_chunks=True,
236
+ chunk_size=8,
237
+ )
238
+ return utils.create_gif(result, fps)
239
+
240
+ def process_pix2pix(self, video, prompt, resolution=512, seed=0, start_t=0, end_t=-1, out_fps=-1):
241
+ if self.model_type != ModelType.Pix2Pix_Video:
242
+ self.set_model(ModelType.Pix2Pix_Video, model_id="timbrooks/instruct-pix2pix")
243
+ self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe.scheduler.config)
244
+ self.pipe.unet.set_attn_processor(processor=self.pix2pix_attn_proc)
245
+ video, fps = utils.prepare_video(video, resolution, self.device, self.dtype, True, start_t, end_t, out_fps)
246
+ self.generator.manual_seed(seed)
247
+ result = self.inference(image=video,
248
+ prompt=prompt,
249
+ seed=seed,
250
+ output_type='numpy',
251
+ num_inference_steps=50,
252
+ image_guidance_scale=1.5,
253
+ split_to_chunks=True,
254
+ chunk_size=8,
255
+ )
256
+ return utils.create_video(result, fps)
257
+
258
+ def process_text2video(self, prompt, resolution=512, seed=24, num_frames=8, fps=4, t0=881, t1=941,
259
+ use_cf_attn=True, use_motion_field=True, use_foreground_motion_field=False,
260
+ smooth_bg=False, smooth_bg_strength=0.4, motion_field_strength=12):
261
+
262
+ if self.model_type != ModelType.Text2Video:
263
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
264
+ self.set_model(ModelType.Text2Video, model_id="runwayml/stable-diffusion-v1-5", unet=unet)
265
+ self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
266
+ self.pipe.unet.set_attn_processor(processor=self.text2video_attn_proc)
267
+ self.generator.manual_seed(seed)
268
+
269
+
270
+ added_prompt = "high quality, HD, 8K, trending on artstation, high focus, dramatic lighting"
271
+ self.generator.manual_seed(seed)
272
+
273
+ prompt = prompt.rstrip()
274
+ if len(prompt) > 0 and (prompt[-1] == "," or prompt[-1] == "."):
275
+ prompt = prompt.rstrip()[:-1]
276
+ prompt = prompt.rstrip()
277
+ prompt = prompt + ", "+added_prompt
278
+
279
+ result = self.inference(prompt=[prompt],
280
+ video_length=num_frames,
281
+ height=resolution,
282
+ width=resolution,
283
+ num_inference_steps=50,
284
+ guidance_scale=7.5,
285
+ guidance_stop_step=1.0,
286
+ t0=t0,
287
+ t1=t1,
288
+ use_foreground_motion_field=use_foreground_motion_field,
289
+ motion_field_strength=motion_field_strength,
290
+ use_motion_field=use_motion_field,
291
+ smooth_bg=smooth_bg,
292
+ smooth_bg_strength=smooth_bg_strength,
293
+ seed=seed,
294
+ output_type='numpy',
295
+ )
296
+ return utils.create_video(result, fps)
requirements.txt ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.16.0
2
+ addict==2.4.0
3
+ albumentations==1.3.0
4
+ basicsr==1.4.2
5
+ decord==0.6.0
6
+ diffusers==0.14.0
7
+ einops==0.6.0
8
+ gradio==3.23.0
9
+ kornia==0.6
10
+ imageio==2.9.0
11
+ imageio-ffmpeg==0.4.2
12
+ invisible-watermark>=0.1.5
13
+ moviepy==1.0.3
14
+ numpy==1.24.1
15
+ omegaconf==2.3.0
16
+ open_clip_torch==2.16.0
17
+ opencv_python==4.7.0.68
18
+ opencv-contrib-python==4.3.0.36
19
+ Pillow==9.4.0
20
+ pytorch_lightning==1.5.0
21
+ prettytable==3.6.0
22
+ scikit_image==0.19.3
23
+ scipy==1.10.1
24
+ tensorboardX==2.6
25
+ tqdm==4.64.1
26
+ timm==0.6.12
27
+ transformers==4.26.0
28
+ test-tube>=0.7.5
29
+ webdataset==0.2.5
30
+ yapf==0.32.0
31
+ safetensors==0.2.7
32
+ huggingface-hub==0.13.0
33
+ torch==1.13.1
34
+ torchvision==0.14.1
share.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import config
2
+ from cldm.hack import disable_verbosity, enable_sliced_attention
3
+
4
+
5
+ disable_verbosity()
6
+
7
+ if config.save_memory:
8
+ enable_sliced_attention()
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
utils.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torchvision
5
+ from torchvision.transforms import Resize
6
+ import imageio
7
+ from einops import rearrange
8
+ import cv2
9
+ from annotator.util import resize_image, HWC3
10
+ from annotator.canny import CannyDetector
11
+ from annotator.openpose import OpenposeDetector
12
+ import decord
13
+ decord.bridge.set_bridge('torch')
14
+
15
+ apply_canny = CannyDetector()
16
+ apply_openpose = OpenposeDetector()
17
+
18
+
19
+ def add_watermark(image, im_size, watermark_path="__assets__/pair_watermark.png",
20
+ wmsize=16, bbuf=5, opacity=0.9):
21
+ '''
22
+ Creates a watermark on the saved inference image.
23
+ We request that you do not remove this to properly assign credit to
24
+ Shi-Lab's work.
25
+ '''
26
+ watermark = Image.open(watermark_path).resize((wmsize, wmsize))
27
+ loc = im_size - wmsize - bbuf
28
+ image[:,:,loc:-bbuf, loc:-bbuf] = watermark
29
+ return image
30
+
31
+
32
+ def pre_process_canny(input_video, low_threshold=100, high_threshold=200):
33
+ detected_maps = []
34
+ for frame in input_video:
35
+ img = rearrange(frame, 'c h w -> h w c').cpu().numpy().astype(np.uint8)
36
+ detected_map = apply_canny(img, low_threshold, high_threshold)
37
+ detected_map = HWC3(detected_map)
38
+ detected_maps.append(detected_map[None])
39
+ detected_maps = np.concatenate(detected_maps)
40
+ control = torch.from_numpy(detected_maps.copy()).float() / 255.0
41
+ return rearrange(control, 'f h w c -> f c h w')
42
+
43
+
44
+ def pre_process_pose(input_video, apply_pose_detect: bool = True):
45
+ detected_maps = []
46
+ for frame in input_video:
47
+ img = rearrange(frame, 'c h w -> h w c').cpu().numpy().astype(np.uint8)
48
+ img = HWC3(img)
49
+ if apply_pose_detect:
50
+ detected_map, _ = apply_openpose(img)
51
+ else:
52
+ detected_map = img
53
+ detected_map = HWC3(detected_map)
54
+ H, W, C = img.shape
55
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
56
+ detected_maps.append(detected_map[None])
57
+ detected_maps = np.concatenate(detected_maps)
58
+ control = torch.from_numpy(detected_maps.copy()).float() / 255.0
59
+ return rearrange(control, 'f h w c -> f c h w')
60
+
61
+
62
+ def create_video(frames, fps, rescale=False, path=None):
63
+ if path is None:
64
+ dir = "temporal"
65
+ os.makedirs(dir, exist_ok=True)
66
+ path = os.path.join(dir, 'movie.mp4')
67
+
68
+ outputs = []
69
+ for i, x in enumerate(frames):
70
+ x = torchvision.utils.make_grid(torch.Tensor(x), nrow=4)
71
+ if rescale:
72
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
73
+ x = (x * 255).numpy().astype(np.uint8)
74
+ x = add_watermark(x, im_size=512)
75
+ outputs.append(x)
76
+ # imageio.imsave(os.path.join(dir, os.path.splitext(name)[0] + f'_{i}.jpg'), x)
77
+
78
+ imageio.mimsave(path, outputs, fps=fps)
79
+ return path
80
+
81
+ def create_gif(frames, fps, rescale=False):
82
+ dir = "temporal"
83
+ os.makedirs(dir, exist_ok=True)
84
+ path = os.path.join(dir, 'canny_db.gif')
85
+
86
+ outputs = []
87
+ for i, x in enumerate(frames):
88
+ x = torchvision.utils.make_grid(torch.Tensor(x), nrow=4)
89
+ if rescale:
90
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
91
+ x = (x * 255).numpy().astype(np.uint8)
92
+ x = add_watermark(x, im_size=512)
93
+ outputs.append(x)
94
+ # imageio.imsave(os.path.join(dir, os.path.splitext(name)[0] + f'_{i}.jpg'), x)
95
+
96
+ imageio.mimsave(path, outputs, fps=fps)
97
+ return path
98
+
99
+ def prepare_video(video_path:str, resolution:int, device, dtype, normalize=True, start_t:float=0, end_t:float=-1, output_fps:int=-1):
100
+ vr = decord.VideoReader(video_path)
101
+ video = vr.get_batch(range(0, len(vr))).asnumpy()
102
+ initial_fps = vr.get_avg_fps()
103
+ if output_fps == -1:
104
+ output_fps = int(initial_fps)
105
+ if end_t == -1:
106
+ end_t = len(vr) / initial_fps
107
+ else:
108
+ end_t = min(len(vr) / initial_fps, end_t)
109
+ assert 0 <= start_t < end_t
110
+ assert output_fps > 0
111
+ f, h, w, c = video.shape
112
+ start_f_ind = int(start_t * initial_fps)
113
+ end_f_ind = int(end_t * initial_fps)
114
+ num_f = int((end_t - start_t) * output_fps)
115
+ sample_idx = np.linspace(start_f_ind, end_f_ind, num_f, endpoint=False).astype(int)
116
+ video = video[sample_idx]
117
+ video = rearrange(video, "f h w c -> f c h w")
118
+ video = torch.Tensor(video).to(device).to(dtype)
119
+ if h > w:
120
+ w = int(w * resolution / h)
121
+ w = w - w % 8
122
+ h = resolution - resolution % 8
123
+ video = Resize((h, w))(video)
124
+ else:
125
+ h = int(h * resolution / w)
126
+ h = h - h % 8
127
+ w = resolution - resolution % 8
128
+ video = Resize((h, w))(video)
129
+ if normalize:
130
+ video = video / 127.5 - 1.0
131
+ return video, output_fps
132
+
133
+
134
+ def post_process_gif(list_of_results, image_resolution):
135
+ output_file = "/tmp/ddxk.gif"
136
+ imageio.mimsave(output_file, list_of_results, fps=4)
137
+ return output_file
138
+
139
+
140
+ class CrossFrameAttnProcessor:
141
+ def __init__(self, unet_chunk_size=2):
142
+ self.unet_chunk_size = unet_chunk_size
143
+
144
+ def __call__(
145
+ self,
146
+ attn,
147
+ hidden_states,
148
+ encoder_hidden_states=None,
149
+ attention_mask=None):
150
+ batch_size, sequence_length, _ = hidden_states.shape
151
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
152
+ query = attn.to_q(hidden_states)
153
+
154
+ is_cross_attention = encoder_hidden_states is not None
155
+ if encoder_hidden_states is None:
156
+ encoder_hidden_states = hidden_states
157
+ elif attn.cross_attention_norm:
158
+ encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
159
+ key = attn.to_k(encoder_hidden_states)
160
+ value = attn.to_v(encoder_hidden_states)
161
+ # Sparse Attention
162
+ if not is_cross_attention:
163
+ video_length = key.size()[0] // self.unet_chunk_size
164
+ # former_frame_index = torch.arange(video_length) - 1
165
+ # former_frame_index[0] = 0
166
+ former_frame_index = [0] * video_length
167
+ key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
168
+ key = key[:, former_frame_index]
169
+ key = rearrange(key, "b f d c -> (b f) d c")
170
+ value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
171
+ value = value[:, former_frame_index]
172
+ value = rearrange(value, "b f d c -> (b f) d c")
173
+
174
+ query = attn.head_to_batch_dim(query)
175
+ key = attn.head_to_batch_dim(key)
176
+ value = attn.head_to_batch_dim(value)
177
+
178
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
179
+ hidden_states = torch.bmm(attention_probs, value)
180
+ hidden_states = attn.batch_to_head_dim(hidden_states)
181
+
182
+ # linear proj
183
+ hidden_states = attn.to_out[0](hidden_states)
184
+ # dropout
185
+ hidden_states = attn.to_out[1](hidden_states)
186
+
187
+ return hidden_states