Sapir commited on
Commit
4bb89c5
1 Parent(s): 4535a03

Image to video script: make determinist by random seed.

Browse files
Files changed (1) hide show
  1. xora/examples/image_to_video.py +23 -10
xora/examples/image_to_video.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch
2
  from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
3
  from xora.models.transformers.transformer3d import Transformer3DModel
@@ -14,6 +15,8 @@ import os
14
  import numpy as np
15
  import cv2
16
  from PIL import Image
 
 
17
 
18
  def load_vae(vae_dir):
19
  vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
@@ -65,9 +68,8 @@ def load_video_to_tensor_with_resize(video_path, target_height=512, target_width
65
  frame_resized = center_crop_and_resize(frame_rgb, target_height, target_width)
66
  frames.append(frame_resized)
67
  cap.release()
68
- video_np = np.array(frames)
69
  video_tensor = torch.tensor(video_np).permute(3, 0, 1, 2).float()
70
- video_tensor = (video_tensor / 127.5) - 1.0
71
  return video_tensor
72
 
73
  def load_image_to_tensor_with_resize(image_path, target_height=512, target_width=768):
@@ -154,9 +156,13 @@ def main():
154
  'media_items': media_items,
155
  }
156
 
157
- generator = torch.Generator(device="cpu").manual_seed(args.seed)
 
 
 
 
 
158
 
159
- # Run the pipeline
160
  images = pipeline(
161
  num_inference_steps=args.num_inference_steps,
162
  num_images_per_prompt=args.num_images_per_prompt,
@@ -173,20 +179,27 @@ def main():
173
  vae_per_channel_normalize=True,
174
  conditioning_method=ConditioningMethod.FIRST_FRAME
175
  ).images
176
-
177
  # Save output video
 
 
 
 
 
 
 
 
178
  for i in range(images.shape[0]):
179
  video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
180
  video_np = (video_np * 255).astype(np.uint8)
181
  fps = args.frame_rate
182
  height, width = video_np.shape[1:3]
183
- filename = lambda base, ext, dir='.': next(
184
- os.path.join(dir, f"{base}_{i}{ext}") for i in range(1000) if
185
- not os.path.exists(os.path.join(dir, f"{base}_{i}{ext}")))
186
- out = cv2.VideoWriter(filename(f"video_output_{i}", ".mp4", "."), cv2.VideoWriter_fourcc(*'mp4v'), fps,
187
- (width, height))
188
  for frame in video_np[..., ::-1]:
189
  out.write(frame)
 
190
  out.release()
191
 
192
 
 
1
+ import time
2
  import torch
3
  from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
4
  from xora.models.transformers.transformer3d import Transformer3DModel
 
15
  import numpy as np
16
  import cv2
17
  from PIL import Image
18
+ from tqdm import tqdm
19
+ import random
20
 
21
  def load_vae(vae_dir):
22
  vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
 
68
  frame_resized = center_crop_and_resize(frame_rgb, target_height, target_width)
69
  frames.append(frame_resized)
70
  cap.release()
71
+ video_np = (np.array(frames) / 127.5) - 1.0
72
  video_tensor = torch.tensor(video_np).permute(3, 0, 1, 2).float()
 
73
  return video_tensor
74
 
75
  def load_image_to_tensor_with_resize(image_path, target_height=512, target_width=768):
 
156
  'media_items': media_items,
157
  }
158
 
159
+ start_time = time.time()
160
+ random.seed(args.seed)
161
+ np.random.seed(args.seed)
162
+ torch.manual_seed(args.seed)
163
+ torch.cuda.manual_seed(args.seed)
164
+ generator = torch.Generator(device="cuda").manual_seed(args.seed)
165
 
 
166
  images = pipeline(
167
  num_inference_steps=args.num_inference_steps,
168
  num_images_per_prompt=args.num_images_per_prompt,
 
179
  vae_per_channel_normalize=True,
180
  conditioning_method=ConditioningMethod.FIRST_FRAME
181
  ).images
 
182
  # Save output video
183
+ def get_unique_filename(base, ext, dir='.', index_range=1000):
184
+ for i in range(index_range):
185
+ filename = os.path.join(dir, f"{base}_{i}{ext}")
186
+ if not os.path.exists(filename):
187
+ return filename
188
+ raise FileExistsError(f"Could not find a unique filename after {index_range} attempts.")
189
+
190
+
191
  for i in range(images.shape[0]):
192
  video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
193
  video_np = (video_np * 255).astype(np.uint8)
194
  fps = args.frame_rate
195
  height, width = video_np.shape[1:3]
196
+ output_filename = get_unique_filename(f"video_output_{i}", ".mp4", ".")
197
+
198
+ out = cv2.VideoWriter(output_filename, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
199
+
 
200
  for frame in video_np[..., ::-1]:
201
  out.write(frame)
202
+
203
  out.release()
204
 
205