g-ronimo committed
Commit d6748e0
1 Parent(s): dc093bf

Upload 3 files

Files changed (3)
  1. gifs_filter.py +68 -0
  2. invert_utils.py +89 -0
  3. text2vid_modded_full.py +612 -0
gifs_filter.py ADDED
@@ -0,0 +1,68 @@
+ # filter GIFs by CLIP feature similarity to an input image
+ from PIL import Image, ImageSequence
+ import requests
+ from tqdm import tqdm
+ import numpy as np
+ import torch
+ from transformers import CLIPProcessor, CLIPModel
+
+
+ def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
+     # sample `clip_len` evenly spaced frame indices from a random window of the sequence
+     converted_len = int(clip_len * frame_sample_rate)
+     end_idx = np.random.randint(converted_len, seg_len)
+     start_idx = end_idx - converted_len
+     indices = np.linspace(start_idx, end_idx, num=clip_len)
+     indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
+     return indices
+
+
+ def load_frames(image: Image.Image, mode='RGBA'):
+     # decode every frame of an animated image into an array of shape (num_frames, H, W, C)
+     return np.array([
+         np.array(frame.convert(mode))
+         for frame in ImageSequence.Iterator(image)
+     ])
+
+
+ img_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ img_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+
+ def filter(gifs, input_image):
+     # keep the GIFs whose average per-frame CLIP similarity to `input_image` exceeds the threshold
+     max_cosine = 0.9
+     max_gif = []
+
+     for gif in tqdm(gifs, total=len(gifs)):
+         with Image.open(gif) as im:
+             frames = load_frames(im)
+
+         frames = frames[:, :, :, :3]                     # drop the alpha channel
+         frames = np.transpose(frames, (0, 3, 1, 2))[1:]  # to (N, C, H, W), skipping the first frame
+
+         image = Image.open(input_image)
+
+         inputs = img_processor(images=frames, return_tensors="pt", padding=False)
+         inputs_base = img_processor(images=image, return_tensors="pt", padding=False)
+
+         with torch.no_grad():
+             feat_img_base = img_model.get_image_features(pixel_values=inputs_base["pixel_values"])
+             feat_img_vid = img_model.get_image_features(pixel_values=inputs["pixel_values"])
+
+         cos_avg = 0
+         for i in range(len(feat_img_vid)):
+             cosine_similarity = torch.nn.functional.cosine_similarity(
+                 feat_img_base,
+                 feat_img_vid[i].unsqueeze(0),  # compare each frame, not only the first one
+                 dim=1)
+             cos_avg += cosine_similarity.item()
+
+         cos_avg /= len(feat_img_vid)
+         print("Current cosine similarity: ", cos_avg)
+         print("Max cosine similarity: ", max_cosine)
+         if cos_avg > max_cosine:
+             # max_cosine = cos_avg
+             max_gif.append(gif)
+     return max_gif
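A minimal usage sketch for the filter above, assuming the module is importable as `gifs_filter` and the GIF/reference paths exist (both are placeholders, not part of this commit): it keeps only the candidate GIFs whose frames stay CLIP-similar to the reference image.

```python
# Hypothetical example; paths and the module name are placeholders.
from gifs_filter import filter as filter_gifs

candidates = ["outputs/candidate_0.gif", "outputs/candidate_1.gif"]
selected = filter_gifs(candidates, input_image="inputs/reference.png")
print(f"{len(selected)}/{len(candidates)} GIFs passed the 0.9 average-similarity threshold")
```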
invert_utils.py ADDED
@@ -0,0 +1,89 @@
+ import os
+ import imageio
+ import numpy as np
+ from typing import Union
+
+ import torch
+ import torchvision
+
+ from tqdm import tqdm
+ from einops import rearrange
+
+
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
+     # tile the batch into a grid per timestep and write the frames out as an animated file
+     videos = rearrange(videos, "b c t h w -> t b c h w")
+     outputs = []
+     for x in videos:
+         x = torchvision.utils.make_grid(x, nrow=n_rows)
+         x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
+         if rescale:
+             x = (x + 1.0) / 2.0  # -1,1 -> 0,1
+         x = (x * 255).numpy().astype(np.uint8)
+         outputs.append(x)
+
+     os.makedirs(os.path.dirname(path), exist_ok=True)
+     imageio.mimsave(path, outputs, fps=fps)
+
+
+ # DDIM Inversion
+ @torch.no_grad()
+ def init_prompt(prompt, pipeline):
+     # build the [unconditional, conditional] text-embedding context for the inversion loop
+     uncond_input = pipeline.tokenizer(
+         [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
+         return_tensors="pt"
+     )
+     uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
+     text_input = pipeline.tokenizer(
+         [prompt],
+         padding="max_length",
+         max_length=pipeline.tokenizer.model_max_length,
+         truncation=True,
+         return_tensors="pt",
+     )
+     text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
+     context = torch.cat([uncond_embeddings, text_embeddings])
+
+     return context
+
+
+ def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
+               sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
+     # one inverted DDIM step: move the sample one scheduler step toward higher noise
+     timestep, next_timestep = min(
+         timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
+     alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
+     alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
+     beta_prod_t = 1 - alpha_prod_t
+     next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
+     next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
+     next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
+     return next_sample
+
+
+ def get_noise_pred_single(latents, t, context, unet):
+     noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
+     return noise_pred
+
+
+ @torch.no_grad()
+ def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
+     context = init_prompt(prompt, pipeline)
+     uncond_embeddings, cond_embeddings = context.chunk(2)
+     all_latent = [latent]
+     latent = latent.clone().detach()
+     for i in tqdm(range(num_inv_steps)):
+         t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
+         noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
+         noise_pred_unc = get_noise_pred_single(latent, t, uncond_embeddings, pipeline.unet)
+         # combine the two predictions with a fixed scale of 9.0; note the direction here is
+         # (uncond - cond), so the estimate is pushed away from the conditional branch
+         noise_pred = noise_pred_unc + 9.0 * (noise_pred_unc - noise_pred)
+         latent = next_step(noise_pred, t, latent, ddim_scheduler)
+         all_latent.append(latent)
+     return all_latent
+
+
+ @torch.no_grad()
+ def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
+     ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
+     return ddim_latents
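A sketch of how these inversion utilities might be used, assuming the `damo-vilab/text-to-video-ms-1.7b` checkpoint and a random stand-in latent (a real run would VAE-encode the conditioning image instead). The DDIM scheduler needs `set_timesteps` called first, since `next_step` reads `num_inference_steps` and `timesteps` from it.

```python
# Hypothetical example; the stand-in latent replaces a VAE-encoded image/video clip.
import torch
from diffusers import DDIMScheduler, TextToVideoSDPipeline
from invert_utils import ddim_inversion

pipe = TextToVideoSDPipeline.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16
).to("cuda")
ddim_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
ddim_scheduler.set_timesteps(50)  # required before ddim_inversion()

video_latent = torch.randn(1, 4, 1, 32, 32, dtype=torch.float16, device="cuda")  # (B, C, T, H, W)
inv_latents = ddim_inversion(pipe, ddim_scheduler, video_latent, num_inv_steps=50, prompt="")[-1]
print(inv_latents.shape)  # same shape as the input latent, now at the highest noise level
```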
text2vid_modded_full.py ADDED
@@ -0,0 +1,612 @@
+ from typing import Any, Callable, Dict, List, Optional, Union
+
+ import numpy as np
+ import torch
+ from transformers import CLIPTextModel, CLIPTokenizer
+
+ from diffusers.image_processor import VaeImageProcessor
+ from diffusers.models import AutoencoderKL, UNet3DConditionModel
+ from diffusers.schedulers import KarrasDiffusionSchedulers
+ from diffusers.utils import (
+     logging,
+     replace_example_docstring)
+ from diffusers.pipelines.text_to_video_synthesis import TextToVideoSDPipelineOutput
+
+ TAU_2 = 15  # last denoising step (inclusive) at which self-attention from the inverted-latent pass is injected
+ TAU_1 = 10  # number of initial denoising steps at which the latents are optimized toward the anchor frame
+
+
+ def init_attention_params(unet, num_frames, lambda_=None, bs=None):
+     # store the per-run constants on every attention module so the patched
+     # get_attention_scores() below can read them
+     for name, module in unet.named_modules():
+         module_name = type(module).__name__
+         if module_name == "Attention":
+             module.LAMBDA = lambda_
+             module.bs = bs
+             module.num_frames = num_frames
+             module.last_attn_slice_weights = 1
+
+
+ def init_attention_func(unet):
+     # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
+     # Updated source code: https://github.com/huggingface/diffusers/blob/50296739878f3e17b2d25d45ef626318b44440b9/src/diffusers/models/attention_processor.py#L571
+     def get_attention_scores(self, query, key, attention_mask=None):
+         r"""
+         Compute the attention scores.
+
+         Args:
+             query (`torch.Tensor`): The query tensor.
+             key (`torch.Tensor`): The key tensor.
+             attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
+
+         Returns:
+             `torch.Tensor`: The attention probabilities/scores.
+         """
+         q_old = query.clone()
+         k_old = key.clone()
+
+         if self.use_last_attn_slice:
+             if self.last_attn_slice is not None:
+                 query_list = self.last_attn_slice[0]
+                 key_list = self.last_attn_slice[1]
+
+                 if query.shape[1] == self.num_frames and query.shape == key.shape:
+                     # temporal attention: build an alternative key whose first (anchor) frame
+                     # comes from the saved inverted-latent pass
+                     key1 = key.clone()
+                     key1[:, :1, :key_list.shape[2]] = key_list[:, :1]
+
+                 if q_old.shape == k_old.shape and q_old.shape[1] != self.num_frames:
+                     # spatial self-attention: overwrite the queries of the conditional pass
+                     # with the queries saved from the inverted-latent pass
+                     batch_dim = query_list.shape[0] // self.bs
+                     all_dim = query.shape[0] // self.bs
+                     for i in range(self.bs):
+                         query[i * all_dim:(i * all_dim) + batch_dim, :query_list.shape[1], :query_list.shape[2]] = query_list[i * batch_dim:(i + 1) * batch_dim]
+
+         dtype = query.dtype
+         if self.upcast_attention:
+             query = query.float()
+             key = key.float()
+
+         if attention_mask is None:
+             baddbmm_input = torch.empty(
+                 query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+             )
+             beta = 0
+         else:
+             baddbmm_input = attention_mask
+             beta = 1
+
+         attention_scores = torch.baddbmm(
+             baddbmm_input,
+             query,
+             key.transpose(-1, -2),
+             beta=beta,
+             alpha=self.scale,
+         )
+
+         if query.shape[1] == self.num_frames and query.shape == key.shape and self.use_last_attn_slice:
+             attention_scores1 = torch.baddbmm(
+                 baddbmm_input,
+                 query,
+                 key1.transpose(-1, -2),
+                 beta=beta,
+                 alpha=self.scale,
+             )
+             # re-weight each frame's attention to the anchor frame, growing linearly with the frame index
+             dynamic_lambda = torch.tensor([1 + self.LAMBDA * (i / 50) for i in range(self.num_frames)]).to(dtype=dtype, device=query.device)
+             attention_scores[:, :self.num_frames, 0] = attention_scores1[:, :self.num_frames, 0] * dynamic_lambda
+
+         del baddbmm_input
+
+         if self.upcast_softmax:
+             attention_scores = attention_scores.float()
+
+         attention_probs = attention_scores.softmax(dim=-1)
+
+         if self.use_last_attn_slice:
+             self.use_last_attn_slice = False
+
+         if self.save_last_attn_slice:
+             self.last_attn_slice = [
+                 query,
+                 key,
+             ]
+             self.save_last_attn_slice = False
+
+         del attention_scores
+         attention_probs = attention_probs.to(dtype)
+
+         return attention_probs
+
+     for _, module in unet.named_modules():
+         module_name = type(module).__name__
+         if module_name == "Attention":
+             module.last_attn_slice = None
+             module.use_last_attn_slice = False
+             module.save_last_attn_slice = False
+             module.LAMBDA = 0
+             module.get_attention_scores = get_attention_scores.__get__(module, type(module))
+             module.bs = 0
+             module.num_frames = None
+
+     return unet
+
+
+ def use_last_self_attention(unet, use=True):
+     for name, module in unet.named_modules():
+         module_name = type(module).__name__
+         if module_name == "Attention" and "attn1" in name:
+             module.use_last_attn_slice = use
+
+
+ def save_last_self_attention(unet, save=True):
+     for name, module in unet.named_modules():
+         module_name = type(module).__name__
+         if module_name == "Attention" and "attn1" in name:
+             module.save_last_attn_slice = save
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+ EXAMPLE_DOC_STRING = """
+     Examples:
+         ```py
+         >>> import torch
+         >>> from diffusers import TextToVideoSDPipeline
+         >>> from diffusers.utils import export_to_video
+
+         >>> pipe = TextToVideoSDPipeline.from_pretrained(
+         ...     "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
+         ... )
+         >>> pipe.enable_model_cpu_offload()
+
+         >>> prompt = "Spiderman is surfing"
+         >>> video_frames = pipe(prompt).frames[0]
+         >>> video_path = export_to_video(video_frames)
+         >>> video_path
+         ```
+ """
+
+
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
+ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
+     batch_size, channels, num_frames, height, width = video.shape
+     outputs = []
+     for batch_idx in range(batch_size):
+         batch_vid = video[batch_idx].permute(1, 0, 2, 3)
+         batch_output = processor.postprocess(batch_vid, output_type)
+         outputs.append(batch_output)
+
+     if output_type == "np":
+         outputs = np.stack(outputs)
+     elif output_type == "pt":
+         outputs = torch.stack(outputs)
+     elif not output_type == "pil":
+         raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
+
+     return outputs
+
+
+ from diffusers import TextToVideoSDPipeline
+
+
+ class TextToVideoSDPipelineModded(TextToVideoSDPipeline):
+     def __init__(
+         self,
+         vae: AutoencoderKL,
+         text_encoder: CLIPTextModel,
+         tokenizer: CLIPTokenizer,
+         unet: UNet3DConditionModel,
+         scheduler: KarrasDiffusionSchedulers,
+     ):
+         super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
+
+     def call_network(self,
+         negative_prompt_embeds,
+         prompt_embeds,
+         latents,
+         inv_latents,
+         t,
+         i,
+         null_embeds,
+         cross_attention_kwargs,
+         extra_step_kwargs,
+         do_classifier_free_guidance,
+         guidance_scale,
+     ):
+         inv_latent_model_input = inv_latents
+         inv_latent_model_input = self.scheduler.scale_model_input(inv_latent_model_input, t)
+
+         latent_model_input = latents
+         latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+         if do_classifier_free_guidance:
+             noise_pred_uncond = self.unet(
+                 latent_model_input,
+                 t,
+                 encoder_hidden_states=negative_prompt_embeds,
+                 cross_attention_kwargs=cross_attention_kwargs,
+                 return_dict=False,
+             )[0]
+
+             noise_null_pred_uncond = self.unet(
+                 inv_latent_model_input,
+                 t,
+                 encoder_hidden_states=negative_prompt_embeds,
+                 cross_attention_kwargs=cross_attention_kwargs,
+                 return_dict=False,
+             )[0]
+
+         if i <= TAU_2:
+             # denoise the inverted latents with the null prompt and save their self-attention
+             # so the conditional pass below can reuse it
+             save_last_self_attention(self.unet)
+
+             noise_null_pred = self.unet(
+                 inv_latent_model_input,
+                 t,
+                 encoder_hidden_states=null_embeds,
+                 cross_attention_kwargs=cross_attention_kwargs,
+                 return_dict=False,
+             )[0]
+
+             if do_classifier_free_guidance:
+                 noise_null_pred = noise_null_pred_uncond + guidance_scale * (noise_null_pred - noise_null_pred_uncond)
+
+             bsz, channel, frames, width, height = inv_latents.shape
+
+             inv_latents = inv_latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, height, width)
+             noise_null_pred = noise_null_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, height, width)
+             inv_latents = self.scheduler.step(noise_null_pred, t, inv_latents, **extra_step_kwargs).prev_sample
+             inv_latents = inv_latents[None, :].reshape((bsz, frames, -1) + inv_latents.shape[2:]).permute(0, 2, 1, 3, 4)
+
+             use_last_self_attention(self.unet)
+         else:
+             noise_null_pred = None
+
+         noise_pred = self.unet(
+             latent_model_input,
+             t,
+             encoder_hidden_states=prompt_embeds,  # conditional (prompt) prediction
+             cross_attention_kwargs=cross_attention_kwargs,
+             return_dict=False,
+         )[0]
+
+         use_last_self_attention(self.unet, False)
+
+         if do_classifier_free_guidance:
+             noise_pred_text = noise_pred
+             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+         # reshape latents
+         bsz, channel, frames, width, height = latents.shape
+         latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
+         noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
+
+         # compute the previous noisy sample x_t -> x_t-1
+         latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+         # reshape latents back
+         latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4)
+
+         return {
+             "latents": latents,
+             "inv_latents": inv_latents,
+             "noise_pred": noise_pred,
+             "noise_null_pred": noise_null_pred,
+         }
+
+     def optimize_latents(self, latents, inv_latents, t, i, null_embeds, cross_attention_kwargs, prompt_embeds):
+         # briefly optimize the non-anchor frames so the prompt-conditioned noise prediction of the
+         # anchor frame matches the null-prompt prediction of the inverted latent
+         inv_scaled = self.scheduler.scale_model_input(inv_latents, t)
+
+         noise_null_pred = self.unet(
+             inv_scaled[:, :, 0:1, :, :],
+             t,
+             encoder_hidden_states=null_embeds,
+             cross_attention_kwargs=cross_attention_kwargs,
+             return_dict=False,
+         )[0]
+
+         with torch.enable_grad():
+             latent_train = latents[:, :, 1:, :, :].clone().detach().requires_grad_(True)
+             optimizer = torch.optim.Adam([latent_train], lr=1e-3)
+
+             for j in range(10):
+                 latent_in = torch.cat([inv_latents[:, :, 0:1, :, :].detach(), latent_train], dim=2)
+                 latent_input_unet = self.scheduler.scale_model_input(latent_in, t)
+
+                 noise_pred = self.unet(
+                     latent_input_unet,
+                     t,
+                     encoder_hidden_states=prompt_embeds,  # prompt-conditioned prediction
+                     cross_attention_kwargs=cross_attention_kwargs,
+                     return_dict=False,
+                 )[0]
+
+                 loss = torch.nn.functional.mse_loss(noise_pred[:, :, 0, :, :], noise_null_pred[:, :, 0, :, :])
+
+                 loss.backward()
+                 optimizer.step()
+                 optimizer.zero_grad()
+
+                 print("Iteration {} Subiteration {} Loss {} ".format(i, j, loss.item()))
+             latents = latent_in.detach()
+         return latents
+
+     @torch.no_grad()
+     @replace_example_docstring(EXAMPLE_DOC_STRING)
+     def __call__(
+         self,
+         prompt: Union[str, List[str]] = None,
+         height: Optional[int] = None,
+         width: Optional[int] = None,
+         num_frames: int = 16,
+         num_inference_steps: int = 50,
+         guidance_scale: float = 9.0,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         eta: float = 0.0,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         inv_latents: Optional[torch.FloatTensor] = None,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+         output_type: Optional[str] = "np",
+         return_dict: bool = True,
+         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+         callback_steps: int = 1,
+         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+         clip_skip: Optional[int] = None,
+         lambda_=0.5,
+     ):
+         r"""
+         The call function to the pipeline for generation.
+
+         Args:
+             prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+             height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+                 The height in pixels of the generated video.
+             width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+                 The width in pixels of the generated video.
+             num_frames (`int`, *optional*, defaults to 16):
+                 The number of video frames that are generated. Defaults to 16 frames, which at 8 frames per second
+                 amounts to 2 seconds of video.
+             num_inference_steps (`int`, *optional*, defaults to 50):
+                 The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+                 expense of slower inference.
+             guidance_scale (`float`, *optional*, defaults to 9.0):
+                 A higher guidance scale value encourages the model to generate images closely linked to the text
+                 `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+             negative_prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+                 pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+             eta (`float`, *optional*, defaults to 0.0):
+                 Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+                 to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+                 generation deterministic.
+             latents (`torch.FloatTensor`, *optional*):
+                 Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
+                 generation. Can be used to tweak the same generation with different prompts. If not provided, a
+                 latents tensor is generated by sampling using the supplied random `generator`. Latents should be of
+                 shape `(batch_size, num_channel, num_frames, height, width)`.
+             inv_latents (`torch.FloatTensor`, *optional*):
+                 DDIM-inverted latents of the conditioning image. Their first frame is used as the anchor frame of the
+                 generated video, and the batch size is inferred from them.
+             prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+                 provided, text embeddings are generated from the `prompt` input argument.
+             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+                 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+             output_type (`str`, *optional*, defaults to `"np"`):
+                 The output format of the generated video. Choose between `torch.FloatTensor` or `np.array`.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
+                 of a plain tuple.
+             callback (`Callable`, *optional*):
+                 A function that is called every `callback_steps` steps during inference. The function is called with
+                 the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+             callback_steps (`int`, *optional*, defaults to 1):
+                 The frequency at which the `callback` function is called. If not specified, the callback is called at
+                 every step.
+             cross_attention_kwargs (`dict`, *optional*):
+                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+                 [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+             clip_skip (`int`, *optional*):
+                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                 the output of the pre-final layer will be used for computing the prompt embeddings.
+             lambda_ (`float`, *optional*, defaults to 0.5):
+                 Strength of the cross-frame attention re-weighting toward the anchor frame.
+         Examples:
+
+         Returns:
+             [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
+                 If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
+                 returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
+         """
+         # 0. Default height and width to unet
+         height = height or self.unet.config.sample_size * self.vae_scale_factor
+         width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+         num_images_per_prompt = 1
+
+         # 1. Check inputs. Raise error if not correct
+         self.check_inputs(
+             prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+         )
+
+         # 2. Define call parameters (the batch size follows the inverted latents)
+         batch_size = inv_latents.shape[0]
+         device = self._execution_device
+         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+         # corresponds to doing no classifier free guidance.
+         do_classifier_free_guidance = guidance_scale > 1.0
+
+         # 3. Encode input prompt
+         text_encoder_lora_scale = (
+             cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+         )
+         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+             [prompt] * batch_size,
+             device,
+             num_images_per_prompt,
+             do_classifier_free_guidance,
+             [negative_prompt] * batch_size if negative_prompt is not None else None,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             lora_scale=text_encoder_lora_scale,
+             clip_skip=clip_skip,
+         )
+         null_embeds, negative_prompt_embeds = self.encode_prompt(
+             [""] * batch_size,
+             device,
+             num_images_per_prompt,
+             do_classifier_free_guidance,
+             [negative_prompt] * batch_size if negative_prompt is not None else None,
+             prompt_embeds=None,
+             negative_prompt_embeds=negative_prompt_embeds,
+             lora_scale=text_encoder_lora_scale,
+             clip_skip=clip_skip,
+         )
+
+         # 4. Prepare timesteps
+         self.scheduler.set_timesteps(num_inference_steps, device=device)
+         timesteps = self.scheduler.timesteps
+
+         # 5. Prepare latent variables
+         num_channels_latents = self.unet.config.in_channels
+         latents = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             num_frames,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             latents,
+         )
+         inv_latents = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             num_frames,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             inv_latents,
+         )
+
+         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+         # 7. Denoising loop
+         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+         init_attention_func(self.unet)
+         print("Setup for Current Run")
+         print("----------------------")
+         print("Prompt ", prompt)
+         print("Batch size ", batch_size)
+         print("Num frames ", latents.shape[2])
+         print("Lambda ", lambda_)
+
+         init_attention_params(self.unet, num_frames=latents.shape[2], lambda_=lambda_, bs=batch_size)
+
+         iters_to_alter = [i for i in range(0, TAU_1)]
+
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             # copy the inverted latent into the first frame of the noise tensor
+             mask_in = torch.zeros(latents.shape).to(dtype=latents.dtype, device=latents.device)
+             mask_in[:, :, 0, :, :] = 1
+             assert latents.shape[0] == inv_latents.shape[0], "Latents and inverse latents should have the same batch size but got {} and {}".format(latents.shape[0], inv_latents.shape[0])
+             inv_latents = inv_latents.repeat(1, 1, num_frames, 1, 1)
+
+             latents = inv_latents * mask_in + latents * (1 - mask_in)
+
+             for i, t in enumerate(timesteps):
+                 # shrink the inverted-latent stack as denoising progresses
+                 curr_copy = max(1, num_frames - i)
+                 inv_latents = inv_latents[:, :, :curr_copy, :, :]
+                 if i in iters_to_alter:
+                     latents = self.optimize_latents(latents, inv_latents, t, i, null_embeds, cross_attention_kwargs, prompt_embeds)
+
+                 output_dict = self.call_network(
+                     negative_prompt_embeds,
+                     prompt_embeds,
+                     latents,
+                     inv_latents,
+                     t,
+                     i,
+                     null_embeds,
+                     cross_attention_kwargs,
+                     extra_step_kwargs,
+                     do_classifier_free_guidance,
+                     guidance_scale,
+                 )
+                 latents = output_dict["latents"]
+                 inv_latents = output_dict["inv_latents"]
+
+                 # call the callback, if provided
+                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                     progress_bar.update()
+                     if callback is not None and i % callback_steps == 0:
+                         step_idx = i // getattr(self.scheduler, "order", 1)
+                         callback(step_idx, t, latents)
+
+         # 8. Post processing
+         if output_type == "latent":
+             video = latents
+         else:
+             video_tensor = self.decode_latents(latents)
+             video = tensor2vid(video_tensor, self.image_processor, output_type)
+
+         # 9. Offload all models
+         self.maybe_free_model_hooks()
+
+         if not return_dict:
+             return (video,)
+
+         return TextToVideoSDPipelineOutput(frames=video)
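An end-to-end sketch of how the three uploaded files appear intended to fit together, assuming the `damo-vilab/text-to-video-ms-1.7b` checkpoint, a placeholder prompt, and a random stand-in image latent (in practice the latent would come from VAE-encoding the conditioning image): DDIM-invert the image latent with `invert_utils`, then pass the result to the modded pipeline as `inv_latents`.

```python
# Hypothetical example combining the three files in this commit; names marked as
# placeholders are assumptions, not part of the upload.
import torch
from diffusers import DDIMScheduler, TextToVideoSDPipeline
from diffusers.utils import export_to_video
from invert_utils import ddim_inversion
from text2vid_modded_full import TextToVideoSDPipelineModded

base = TextToVideoSDPipeline.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16
)
pipe = TextToVideoSDPipelineModded(
    vae=base.vae, text_encoder=base.text_encoder, tokenizer=base.tokenizer,
    unet=base.unet, scheduler=base.scheduler,
).to("cuda")

# Invert a (stand-in) single-frame latent of the conditioning image.
ddim_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
ddim_scheduler.set_timesteps(50)
latent_hw = pipe.unet.config.sample_size  # keep the spatial size consistent with the pipeline defaults
image_latent = torch.randn(1, 4, 1, latent_hw, latent_hw, dtype=torch.float16, device="cuda")
inv_latents = ddim_inversion(pipe, ddim_scheduler, image_latent, num_inv_steps=50, prompt="")[-1]

# Generate a clip anchored on the inverted first frame.
result = pipe(prompt="a dog running on the beach", num_frames=16, num_inference_steps=50,
              guidance_scale=9.0, inv_latents=inv_latents, lambda_=0.5)
export_to_video(result.frames[0], "output.mp4")
```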