Sapir commited on
Commit
325137b
·
1 Parent(s): 41b1cab

Lint: added ruff.

Browse files
.github/workflows/pylint.yml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Ruff
2
+
3
+ on: [push]
4
+
5
+ jobs:
6
+ build:
7
+ runs-on: ubuntu-latest
8
+ strategy:
9
+ matrix:
10
+ python-version: ["3.10"]
11
+ steps:
12
+ - name: Checkout repository and submodules
13
+ uses: actions/checkout@v3
14
+ - name: Set up Python ${{ matrix.python-version }}
15
+ uses: actions/setup-python@v3
16
+ with:
17
+ python-version: ${{ matrix.python-version }}
18
+ - name: Install dependencies
19
+ run: |
20
+ python -m pip install --upgrade pip
21
+ pip install ruff==0.2.2 black==24.2.0
22
+ - name: Analyzing the code with ruff
23
+ run: |
24
+ ruff $(git ls-files '*.py')
25
+ - name: Verify that no Black changes are required
26
+ run: |
27
+ black --check $(git ls-files '*.py')
.gitignore CHANGED
@@ -159,4 +159,4 @@ cython_debug/
159
  # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
  # and can be added to the global gitignore or merged into this file. For a more nuclear
161
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
- #.idea/
 
159
  # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
  # and can be added to the global gitignore or merged into this file. For a more nuclear
161
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ .idea/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/astral-sh/ruff-pre-commit
3
+ # Ruff version.
4
+ rev: v0.2.2
5
+ hooks:
6
+ # Run the linter.
7
+ - id: ruff
8
+ args: [--fix] # Automatically fix issues if possible.
9
+ types: [python] # Ensure it only runs on .py files.
10
+
11
+ - repo: https://github.com/psf/black
12
+ rev: 24.2.0 # Specify the version of Black you want
13
+ hooks:
14
+ - id: black
15
+ name: Black code formatter
16
+ language_version: python3 # Use the Python version you're targeting (e.g., 3.10)
scripts/to_safetensors.py CHANGED
@@ -1,6 +1,6 @@
1
  import argparse
2
  from pathlib import Path
3
- from typing import Any, Dict
4
  import safetensors.torch
5
  import torch
6
  import json
@@ -8,12 +8,14 @@ import shutil
8
 
9
 
10
  def load_text_encoder(index_path: Path) -> Dict:
11
- with open(index_path, 'r') as f:
12
  index: Dict = json.load(f)
13
 
14
  loaded_tensors = {}
15
  for part_file in set(index.get("weight_map", {}).values()):
16
- tensors = safetensors.torch.load_file(index_path.parent / part_file, device='cpu')
 
 
17
  for tensor_name in tensors:
18
  loaded_tensors[tensor_name] = tensors[tensor_name]
19
 
@@ -30,23 +32,30 @@ def convert_vae(vae_path: Path, add_prefix=True) -> Dict:
30
  state_dict = torch.load(vae_path / "autoencoder.pth", weights_only=True)
31
  stats_path = vae_path / "per_channel_statistics.json"
32
  if stats_path.exists():
33
- with open(stats_path, 'r') as f:
34
  data = json.load(f)
35
  transposed_data = list(zip(*data["data"]))
36
  data_dict = {
37
- f"{'vae.' if add_prefix else ''}per_channel_statistics.{col}": torch.tensor(vals)
 
 
38
  for col, vals in zip(data["columns"], transposed_data)
39
  }
40
  else:
41
  data_dict = {}
42
 
43
- result = {("vae." if add_prefix else "") + key: value for key, value in state_dict.items()}
 
 
44
  result.update(data_dict)
45
  return result
46
 
47
 
48
  def convert_encoder(encoder: Dict) -> Dict:
49
- return {"text_encoders.t5xxl.transformer." + key: value for key, value in encoder.items()}
 
 
 
50
 
51
 
52
  def save_config(config_src: str, config_dst: str):
@@ -60,50 +69,75 @@ def load_vae_config(vae_path: Path) -> str:
60
  return str(config_path)
61
 
62
 
63
- def main(unet_path: str, vae_path: str, out_path: str, mode: str,
64
- unet_config_path: str = None, scheduler_config_path: str = None) -> None:
65
- unet = convert_unet(torch.load(unet_path, weights_only=True), add_prefix=(mode == 'single'))
 
 
 
 
 
 
 
 
66
 
67
  # Load VAE from directory and config
68
- vae = convert_vae(Path(vae_path), add_prefix=(mode == 'single'))
69
  vae_config_path = load_vae_config(Path(vae_path))
70
 
71
- if mode == 'single':
72
  result = {**unet, **vae}
73
  safetensors.torch.save_file(result, out_path)
74
- elif mode == 'separate':
75
  # Create directories for unet, vae, and scheduler
76
- unet_dir = Path(out_path) / 'unet'
77
- vae_dir = Path(out_path) / 'vae'
78
- scheduler_dir = Path(out_path) / 'scheduler'
79
 
80
  unet_dir.mkdir(parents=True, exist_ok=True)
81
  vae_dir.mkdir(parents=True, exist_ok=True)
82
  scheduler_dir.mkdir(parents=True, exist_ok=True)
83
 
84
  # Save unet and vae safetensors with the name diffusion_pytorch_model.safetensors
85
- safetensors.torch.save_file(unet, unet_dir / 'diffusion_pytorch_model.safetensors')
86
- safetensors.torch.save_file(vae, vae_dir / 'diffusion_pytorch_model.safetensors')
 
 
 
 
87
 
88
  # Save config files for unet, vae, and scheduler
89
  if unet_config_path:
90
- save_config(unet_config_path, unet_dir / 'config.json')
91
  if vae_config_path:
92
- save_config(vae_config_path, vae_dir / 'config.json')
93
  if scheduler_config_path:
94
- save_config(scheduler_config_path, scheduler_dir / 'scheduler_config.json')
95
 
96
 
97
- if __name__ == '__main__':
98
  parser = argparse.ArgumentParser()
99
- parser.add_argument('--unet_path', '-u', type=str, default='unet/ema-002.pt')
100
- parser.add_argument('--vae_path', '-v', type=str, default='vae/')
101
- parser.add_argument('--out_path', '-o', type=str, default='xora.safetensors')
102
- parser.add_argument('--mode', '-m', type=str, choices=['single', 'separate'], default='single',
103
- help="Choose 'single' for the original behavior, 'separate' to save unet and vae separately.")
104
- parser.add_argument('--unet_config_path', type=str, help="Path to the UNet config file (for separate mode)")
105
- parser.add_argument('--scheduler_config_path', type=str,
106
- help="Path to the Scheduler config file (for separate mode)")
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  args = parser.parse_args()
109
  main(**args.__dict__)
 
1
  import argparse
2
  from pathlib import Path
3
+ from typing import Dict
4
  import safetensors.torch
5
  import torch
6
  import json
 
8
 
9
 
10
  def load_text_encoder(index_path: Path) -> Dict:
11
+ with open(index_path, "r") as f:
12
  index: Dict = json.load(f)
13
 
14
  loaded_tensors = {}
15
  for part_file in set(index.get("weight_map", {}).values()):
16
+ tensors = safetensors.torch.load_file(
17
+ index_path.parent / part_file, device="cpu"
18
+ )
19
  for tensor_name in tensors:
20
  loaded_tensors[tensor_name] = tensors[tensor_name]
21
 
 
32
  state_dict = torch.load(vae_path / "autoencoder.pth", weights_only=True)
33
  stats_path = vae_path / "per_channel_statistics.json"
34
  if stats_path.exists():
35
+ with open(stats_path, "r") as f:
36
  data = json.load(f)
37
  transposed_data = list(zip(*data["data"]))
38
  data_dict = {
39
+ f"{'vae.' if add_prefix else ''}per_channel_statistics.{col}": torch.tensor(
40
+ vals
41
+ )
42
  for col, vals in zip(data["columns"], transposed_data)
43
  }
44
  else:
45
  data_dict = {}
46
 
47
+ result = {
48
+ ("vae." if add_prefix else "") + key: value for key, value in state_dict.items()
49
+ }
50
  result.update(data_dict)
51
  return result
52
 
53
 
54
  def convert_encoder(encoder: Dict) -> Dict:
55
+ return {
56
+ "text_encoders.t5xxl.transformer." + key: value
57
+ for key, value in encoder.items()
58
+ }
59
 
60
 
61
  def save_config(config_src: str, config_dst: str):
 
69
  return str(config_path)
70
 
71
 
72
+ def main(
73
+ unet_path: str,
74
+ vae_path: str,
75
+ out_path: str,
76
+ mode: str,
77
+ unet_config_path: str = None,
78
+ scheduler_config_path: str = None,
79
+ ) -> None:
80
+ unet = convert_unet(
81
+ torch.load(unet_path, weights_only=True), add_prefix=(mode == "single")
82
+ )
83
 
84
  # Load VAE from directory and config
85
+ vae = convert_vae(Path(vae_path), add_prefix=(mode == "single"))
86
  vae_config_path = load_vae_config(Path(vae_path))
87
 
88
+ if mode == "single":
89
  result = {**unet, **vae}
90
  safetensors.torch.save_file(result, out_path)
91
+ elif mode == "separate":
92
  # Create directories for unet, vae, and scheduler
93
+ unet_dir = Path(out_path) / "unet"
94
+ vae_dir = Path(out_path) / "vae"
95
+ scheduler_dir = Path(out_path) / "scheduler"
96
 
97
  unet_dir.mkdir(parents=True, exist_ok=True)
98
  vae_dir.mkdir(parents=True, exist_ok=True)
99
  scheduler_dir.mkdir(parents=True, exist_ok=True)
100
 
101
  # Save unet and vae safetensors with the name diffusion_pytorch_model.safetensors
102
+ safetensors.torch.save_file(
103
+ unet, unet_dir / "diffusion_pytorch_model.safetensors"
104
+ )
105
+ safetensors.torch.save_file(
106
+ vae, vae_dir / "diffusion_pytorch_model.safetensors"
107
+ )
108
 
109
  # Save config files for unet, vae, and scheduler
110
  if unet_config_path:
111
+ save_config(unet_config_path, unet_dir / "config.json")
112
  if vae_config_path:
113
+ save_config(vae_config_path, vae_dir / "config.json")
114
  if scheduler_config_path:
115
+ save_config(scheduler_config_path, scheduler_dir / "scheduler_config.json")
116
 
117
 
118
+ if __name__ == "__main__":
119
  parser = argparse.ArgumentParser()
120
+ parser.add_argument("--unet_path", "-u", type=str, default="unet/ema-002.pt")
121
+ parser.add_argument("--vae_path", "-v", type=str, default="vae/")
122
+ parser.add_argument("--out_path", "-o", type=str, default="xora.safetensors")
123
+ parser.add_argument(
124
+ "--mode",
125
+ "-m",
126
+ type=str,
127
+ choices=["single", "separate"],
128
+ default="single",
129
+ help="Choose 'single' for the original behavior, 'separate' to save unet and vae separately.",
130
+ )
131
+ parser.add_argument(
132
+ "--unet_config_path",
133
+ type=str,
134
+ help="Path to the UNet config file (for separate mode)",
135
+ )
136
+ parser.add_argument(
137
+ "--scheduler_config_path",
138
+ type=str,
139
+ help="Path to the Scheduler config file (for separate mode)",
140
+ )
141
 
142
  args = parser.parse_args()
143
  main(**args.__dict__)
setup.py CHANGED
@@ -1,7 +1,9 @@
1
  from setuptools import setup, find_packages
 
 
2
  def parse_requirements(filename):
3
  """Load requirements from a pip requirements file."""
4
- with open(filename, 'r') as file:
5
  return file.read().splitlines()
6
 
7
 
@@ -13,11 +15,13 @@ setup(
13
  author_email="sapir@lightricks.com", # Your email
14
  url="https://github.com/LightricksResearch/xora-core", # URL for the project (GitHub, etc.)
15
  packages=find_packages(), # Automatically find all packages inside `xora`
16
- install_requires=parse_requirements('requirements.txt'), # Install dependencies from requirements.txt
 
 
17
  classifiers=[
18
- 'Programming Language :: Python :: 3',
19
- 'License :: OSI Approved :: MIT License',
20
- 'Operating System :: OS Independent',
21
  ],
22
- python_requires='>=3.10', # Specify Python version compatibility
23
- )
 
1
  from setuptools import setup, find_packages
2
+
3
+
4
  def parse_requirements(filename):
5
  """Load requirements from a pip requirements file."""
6
+ with open(filename, "r") as file:
7
  return file.read().splitlines()
8
 
9
 
 
15
  author_email="sapir@lightricks.com", # Your email
16
  url="https://github.com/LightricksResearch/xora-core", # URL for the project (GitHub, etc.)
17
  packages=find_packages(), # Automatically find all packages inside `xora`
18
+ install_requires=parse_requirements(
19
+ "requirements.txt"
20
+ ), # Install dependencies from requirements.txt
21
  classifiers=[
22
+ "Programming Language :: Python :: 3",
23
+ "License :: OSI Approved :: MIT License",
24
+ "Operating System :: OS Independent",
25
  ],
26
+ python_requires=">=3.10", # Specify Python version compatibility
27
+ )
xora/__init__.py CHANGED
@@ -1 +0,0 @@
1
- from .pipelines import *
 
 
xora/examples/image_to_video.py CHANGED
@@ -1,4 +1,3 @@
1
- import time
2
  import torch
3
  from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
4
  from xora.models.transformers.transformer3d import Transformer3DModel
@@ -15,19 +14,20 @@ import os
15
  import numpy as np
16
  import cv2
17
  from PIL import Image
18
- from tqdm import tqdm
19
  import random
20
 
 
21
  def load_vae(vae_dir):
22
  vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
23
  vae_config_path = vae_dir / "config.json"
24
- with open(vae_config_path, 'r') as f:
25
  vae_config = json.load(f)
26
  vae = CausalVideoAutoencoder.from_config(vae_config)
27
  vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
28
  vae.load_state_dict(vae_state_dict)
29
  return vae.cuda().to(torch.bfloat16)
30
 
 
31
  def load_unet(unet_dir):
32
  unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
33
  unet_config_path = unet_dir / "config.json"
@@ -37,11 +37,13 @@ def load_unet(unet_dir):
37
  transformer.load_state_dict(unet_state_dict, strict=True)
38
  return transformer.cuda()
39
 
 
40
  def load_scheduler(scheduler_dir):
41
  scheduler_config_path = scheduler_dir / "scheduler_config.json"
42
  scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
43
  return RectifiedFlowScheduler.from_config(scheduler_config)
44
 
 
45
  def center_crop_and_resize(frame, target_height, target_width):
46
  h, w, _ = frame.shape
47
  aspect_ratio_target = target_width / target_height
@@ -49,14 +51,15 @@ def center_crop_and_resize(frame, target_height, target_width):
49
  if aspect_ratio_frame > aspect_ratio_target:
50
  new_width = int(h * aspect_ratio_target)
51
  x_start = (w - new_width) // 2
52
- frame_cropped = frame[:, x_start:x_start + new_width]
53
  else:
54
  new_height = int(w / aspect_ratio_target)
55
  y_start = (h - new_height) // 2
56
- frame_cropped = frame[y_start:y_start + new_height, :]
57
  frame_resized = cv2.resize(frame_cropped, (target_width, target_height))
58
  return frame_resized
59
 
 
60
  def load_video_to_tensor_with_resize(video_path, target_height=512, target_width=768):
61
  cap = cv2.VideoCapture(video_path)
62
  frames = []
@@ -72,6 +75,7 @@ def load_video_to_tensor_with_resize(video_path, target_height=512, target_width
72
  video_tensor = torch.tensor(video_np).permute(3, 0, 1, 2).float()
73
  return video_tensor
74
 
 
75
  def load_image_to_tensor_with_resize(image_path, target_height=512, target_width=768):
76
  image = Image.open(image_path).convert("RGB")
77
  image_np = np.array(image)
@@ -81,51 +85,90 @@ def load_image_to_tensor_with_resize(image_path, target_height=512, target_width
81
  # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
82
  return frame_tensor.unsqueeze(0).unsqueeze(2)
83
 
 
84
  def main():
85
- parser = argparse.ArgumentParser(description='Load models from separate directories and run the pipeline.')
 
 
86
 
87
  # Directories
88
- parser.add_argument('--ckpt_dir', type=str, required=True,
89
- help='Path to the directory containing unet, vae, and scheduler subdirectories')
90
- parser.add_argument('--video_path', type=str,
91
- help='Path to the input video file (first frame used)')
92
- parser.add_argument('--image_path', type=str,
93
- help='Path to the input image file')
94
- parser.add_argument('--seed', type=int, default="171198")
 
 
 
 
95
 
96
  # Pipeline parameters
97
- parser.add_argument('--num_inference_steps', type=int, default=40, help='Number of inference steps')
98
- parser.add_argument('--num_images_per_prompt', type=int, default=1, help='Number of images per prompt')
99
- parser.add_argument('--guidance_scale', type=float, default=3, help='Guidance scale for the pipeline')
100
- parser.add_argument('--height', type=int, default=512, help='Height of the output video frames')
101
- parser.add_argument('--width', type=int, default=768, help='Width of the output video frames')
102
- parser.add_argument('--num_frames', type=int, default=121, help='Number of frames to generate in the output video')
103
- parser.add_argument('--frame_rate', type=int, default=25, help='Frame rate for the output video')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  # Prompts
106
- parser.add_argument('--prompt', type=str,
107
- default='A man wearing a black leather jacket and blue jeans is riding a Harley Davidson motorcycle down a paved road. The man has short brown hair and is wearing a black helmet. The motorcycle is a dark red color with a large front fairing. The road is surrounded by green grass and trees. There is a gas station on the left side of the road with a red and white sign that says "Oil" and "Diner".',
108
- help='Text prompt to guide generation')
109
- parser.add_argument('--negative_prompt', type=str,
110
- default='worst quality, inconsistent motion, blurry, jittery, distorted',
111
- help='Negative prompt for undesired features')
 
 
 
 
 
 
112
 
113
  args = parser.parse_args()
114
 
115
  # Paths for the separate mode directories
116
  ckpt_dir = Path(args.ckpt_dir)
117
- unet_dir = ckpt_dir / 'unet'
118
- vae_dir = ckpt_dir / 'vae'
119
- scheduler_dir = ckpt_dir / 'scheduler'
120
 
121
  # Load models
122
  vae = load_vae(vae_dir)
123
  unet = load_unet(unet_dir)
124
  scheduler = load_scheduler(scheduler_dir)
125
  patchifier = SymmetricPatchifier(patch_size=1)
126
- text_encoder = T5EncoderModel.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder").to(
127
- "cuda")
128
- tokenizer = T5Tokenizer.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer")
 
 
 
129
 
130
  # Use submodels for the pipeline
131
  submodel_dict = {
@@ -141,22 +184,25 @@ def main():
141
 
142
  # Load media (video or image)
143
  if args.video_path:
144
- media_items = load_video_to_tensor_with_resize(args.video_path, args.height, args.width).unsqueeze(0)
 
 
145
  elif args.image_path:
146
- media_items = load_image_to_tensor_with_resize(args.image_path, args.height, args.width)
 
 
147
  else:
148
  raise ValueError("Either --video_path or --image_path must be provided.")
149
 
150
  # Prepare input for the pipeline
151
  sample = {
152
  "prompt": args.prompt,
153
- 'prompt_attention_mask': None,
154
- 'negative_prompt': args.negative_prompt,
155
- 'negative_prompt_attention_mask': None,
156
- 'media_items': media_items,
157
  }
158
 
159
- start_time = time.time()
160
  random.seed(args.seed)
161
  np.random.seed(args.seed)
162
  torch.manual_seed(args.seed)
@@ -177,16 +223,18 @@ def main():
177
  **sample,
178
  is_video=True,
179
  vae_per_channel_normalize=True,
180
- conditioning_method=ConditioningMethod.FIRST_FRAME
181
  ).images
 
182
  # Save output video
183
- def get_unique_filename(base, ext, dir='.', index_range=1000):
184
  for i in range(index_range):
185
  filename = os.path.join(dir, f"{base}_{i}{ext}")
186
  if not os.path.exists(filename):
187
  return filename
188
- raise FileExistsError(f"Could not find a unique filename after {index_range} attempts.")
189
-
 
190
 
191
  for i in range(images.shape[0]):
192
  video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
@@ -195,7 +243,9 @@ def main():
195
  height, width = video_np.shape[1:3]
196
  output_filename = get_unique_filename(f"video_output_{i}", ".mp4", ".")
197
 
198
- out = cv2.VideoWriter(output_filename, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
 
 
199
 
200
  for frame in video_np[..., ::-1]:
201
  out.write(frame)
 
 
1
  import torch
2
  from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
3
  from xora.models.transformers.transformer3d import Transformer3DModel
 
14
  import numpy as np
15
  import cv2
16
  from PIL import Image
 
17
  import random
18
 
19
+
20
  def load_vae(vae_dir):
21
  vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
22
  vae_config_path = vae_dir / "config.json"
23
+ with open(vae_config_path, "r") as f:
24
  vae_config = json.load(f)
25
  vae = CausalVideoAutoencoder.from_config(vae_config)
26
  vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
27
  vae.load_state_dict(vae_state_dict)
28
  return vae.cuda().to(torch.bfloat16)
29
 
30
+
31
  def load_unet(unet_dir):
32
  unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
33
  unet_config_path = unet_dir / "config.json"
 
37
  transformer.load_state_dict(unet_state_dict, strict=True)
38
  return transformer.cuda()
39
 
40
+
41
  def load_scheduler(scheduler_dir):
42
  scheduler_config_path = scheduler_dir / "scheduler_config.json"
43
  scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
44
  return RectifiedFlowScheduler.from_config(scheduler_config)
45
 
46
+
47
  def center_crop_and_resize(frame, target_height, target_width):
48
  h, w, _ = frame.shape
49
  aspect_ratio_target = target_width / target_height
 
51
  if aspect_ratio_frame > aspect_ratio_target:
52
  new_width = int(h * aspect_ratio_target)
53
  x_start = (w - new_width) // 2
54
+ frame_cropped = frame[:, x_start : x_start + new_width]
55
  else:
56
  new_height = int(w / aspect_ratio_target)
57
  y_start = (h - new_height) // 2
58
+ frame_cropped = frame[y_start : y_start + new_height, :]
59
  frame_resized = cv2.resize(frame_cropped, (target_width, target_height))
60
  return frame_resized
61
 
62
+
63
  def load_video_to_tensor_with_resize(video_path, target_height=512, target_width=768):
64
  cap = cv2.VideoCapture(video_path)
65
  frames = []
 
75
  video_tensor = torch.tensor(video_np).permute(3, 0, 1, 2).float()
76
  return video_tensor
77
 
78
+
79
  def load_image_to_tensor_with_resize(image_path, target_height=512, target_width=768):
80
  image = Image.open(image_path).convert("RGB")
81
  image_np = np.array(image)
 
85
  # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
86
  return frame_tensor.unsqueeze(0).unsqueeze(2)
87
 
88
+
89
  def main():
90
+ parser = argparse.ArgumentParser(
91
+ description="Load models from separate directories and run the pipeline."
92
+ )
93
 
94
  # Directories
95
+ parser.add_argument(
96
+ "--ckpt_dir",
97
+ type=str,
98
+ required=True,
99
+ help="Path to the directory containing unet, vae, and scheduler subdirectories",
100
+ )
101
+ parser.add_argument(
102
+ "--video_path", type=str, help="Path to the input video file (first frame used)"
103
+ )
104
+ parser.add_argument("--image_path", type=str, help="Path to the input image file")
105
+ parser.add_argument("--seed", type=int, default="171198")
106
 
107
  # Pipeline parameters
108
+ parser.add_argument(
109
+ "--num_inference_steps", type=int, default=40, help="Number of inference steps"
110
+ )
111
+ parser.add_argument(
112
+ "--num_images_per_prompt",
113
+ type=int,
114
+ default=1,
115
+ help="Number of images per prompt",
116
+ )
117
+ parser.add_argument(
118
+ "--guidance_scale",
119
+ type=float,
120
+ default=3,
121
+ help="Guidance scale for the pipeline",
122
+ )
123
+ parser.add_argument(
124
+ "--height", type=int, default=512, help="Height of the output video frames"
125
+ )
126
+ parser.add_argument(
127
+ "--width", type=int, default=768, help="Width of the output video frames"
128
+ )
129
+ parser.add_argument(
130
+ "--num_frames",
131
+ type=int,
132
+ default=121,
133
+ help="Number of frames to generate in the output video",
134
+ )
135
+ parser.add_argument(
136
+ "--frame_rate", type=int, default=25, help="Frame rate for the output video"
137
+ )
138
 
139
  # Prompts
140
+ parser.add_argument(
141
+ "--prompt",
142
+ type=str,
143
+ default='A man wearing a black leather jacket and blue jeans is riding a Harley Davidson motorcycle down a paved road. The man has short brown hair and is wearing a black helmet. The motorcycle is a dark red color with a large front fairing. The road is surrounded by green grass and trees. There is a gas station on the left side of the road with a red and white sign that says "Oil" and "Diner".',
144
+ help="Text prompt to guide generation",
145
+ )
146
+ parser.add_argument(
147
+ "--negative_prompt",
148
+ type=str,
149
+ default="worst quality, inconsistent motion, blurry, jittery, distorted",
150
+ help="Negative prompt for undesired features",
151
+ )
152
 
153
  args = parser.parse_args()
154
 
155
  # Paths for the separate mode directories
156
  ckpt_dir = Path(args.ckpt_dir)
157
+ unet_dir = ckpt_dir / "unet"
158
+ vae_dir = ckpt_dir / "vae"
159
+ scheduler_dir = ckpt_dir / "scheduler"
160
 
161
  # Load models
162
  vae = load_vae(vae_dir)
163
  unet = load_unet(unet_dir)
164
  scheduler = load_scheduler(scheduler_dir)
165
  patchifier = SymmetricPatchifier(patch_size=1)
166
+ text_encoder = T5EncoderModel.from_pretrained(
167
+ "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
168
+ ).to("cuda")
169
+ tokenizer = T5Tokenizer.from_pretrained(
170
+ "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
171
+ )
172
 
173
  # Use submodels for the pipeline
174
  submodel_dict = {
 
184
 
185
  # Load media (video or image)
186
  if args.video_path:
187
+ media_items = load_video_to_tensor_with_resize(
188
+ args.video_path, args.height, args.width
189
+ ).unsqueeze(0)
190
  elif args.image_path:
191
+ media_items = load_image_to_tensor_with_resize(
192
+ args.image_path, args.height, args.width
193
+ )
194
  else:
195
  raise ValueError("Either --video_path or --image_path must be provided.")
196
 
197
  # Prepare input for the pipeline
198
  sample = {
199
  "prompt": args.prompt,
200
+ "prompt_attention_mask": None,
201
+ "negative_prompt": args.negative_prompt,
202
+ "negative_prompt_attention_mask": None,
203
+ "media_items": media_items,
204
  }
205
 
 
206
  random.seed(args.seed)
207
  np.random.seed(args.seed)
208
  torch.manual_seed(args.seed)
 
223
  **sample,
224
  is_video=True,
225
  vae_per_channel_normalize=True,
226
+ conditioning_method=ConditioningMethod.FIRST_FRAME,
227
  ).images
228
+
229
  # Save output video
230
+ def get_unique_filename(base, ext, dir=".", index_range=1000):
231
  for i in range(index_range):
232
  filename = os.path.join(dir, f"{base}_{i}{ext}")
233
  if not os.path.exists(filename):
234
  return filename
235
+ raise FileExistsError(
236
+ f"Could not find a unique filename after {index_range} attempts."
237
+ )
238
 
239
  for i in range(images.shape[0]):
240
  video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
 
243
  height, width = video_np.shape[1:3]
244
  output_filename = get_unique_filename(f"video_output_{i}", ".mp4", ".")
245
 
246
+ out = cv2.VideoWriter(
247
+ output_filename, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
248
+ )
249
 
250
  for frame in video_np[..., ::-1]:
251
  out.write(frame)
xora/examples/text_to_video.py CHANGED
@@ -10,16 +10,18 @@ import safetensors.torch
10
  import json
11
  import argparse
12
 
 
13
  def load_vae(vae_dir):
14
  vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
15
  vae_config_path = vae_dir / "config.json"
16
- with open(vae_config_path, 'r') as f:
17
  vae_config = json.load(f)
18
  vae = CausalVideoAutoencoder.from_config(vae_config)
19
  vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
20
  vae.load_state_dict(vae_state_dict)
21
  return vae.cuda().to(torch.bfloat16)
22
 
 
23
  def load_unet(unet_dir):
24
  unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
25
  unet_config_path = unet_dir / "config.json"
@@ -29,22 +31,31 @@ def load_unet(unet_dir):
29
  transformer.load_state_dict(unet_state_dict, strict=True)
30
  return transformer.cuda()
31
 
 
32
  def load_scheduler(scheduler_dir):
33
  scheduler_config_path = scheduler_dir / "scheduler_config.json"
34
  scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
35
  return RectifiedFlowScheduler.from_config(scheduler_config)
36
 
 
37
  def main():
38
  # Parse command line arguments
39
- parser = argparse.ArgumentParser(description='Load models from separate directories')
40
- parser.add_argument('--separate_dir', type=str, required=True, help='Path to the directory containing unet, vae, and scheduler subdirectories')
 
 
 
 
 
 
 
41
  args = parser.parse_args()
42
 
43
  # Paths for the separate mode directories
44
  separate_dir = Path(args.separate_dir)
45
- unet_dir = separate_dir / 'unet'
46
- vae_dir = separate_dir / 'vae'
47
- scheduler_dir = separate_dir / 'scheduler'
48
 
49
  # Load models
50
  vae = load_vae(vae_dir)
@@ -54,8 +65,12 @@ def main():
54
  # Patchifier (remains the same)
55
  patchifier = SymmetricPatchifier(patch_size=1)
56
 
57
- text_encoder = T5EncoderModel.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder").to("cuda")
58
- tokenizer = T5Tokenizer.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer")
 
 
 
 
59
 
60
  # Use submodels for the pipeline
61
  submodel_dict = {
@@ -79,14 +94,14 @@ def main():
79
  frame_rate = 25
80
  sample = {
81
  "prompt": "A middle-aged man with glasses and a salt-and-pepper beard is driving a car and talking, gesturing with his right hand. "
82
- "The man is wearing a dark blue zip-up jacket and a light blue collared shirt. He is sitting in the driver's seat of a car with a black interior. The car is moving on a road with trees and bushes on either side. The man has a serious expression on his face and is looking straight ahead.",
83
- 'prompt_attention_mask': None, # Adjust attention masks as needed
84
- 'negative_prompt': "Ugly deformed",
85
- 'negative_prompt_attention_mask': None
86
  }
87
 
88
  # Generate images (video frames)
89
- images = pipeline(
90
  num_inference_steps=num_inference_steps,
91
  num_images_per_prompt=num_images_per_prompt,
92
  guidance_scale=guidance_scale,
@@ -104,5 +119,6 @@ def main():
104
 
105
  print("Generated images (video frames).")
106
 
 
107
  if __name__ == "__main__":
108
  main()
 
10
  import json
11
  import argparse
12
 
13
+
14
  def load_vae(vae_dir):
15
  vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
16
  vae_config_path = vae_dir / "config.json"
17
+ with open(vae_config_path, "r") as f:
18
  vae_config = json.load(f)
19
  vae = CausalVideoAutoencoder.from_config(vae_config)
20
  vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
21
  vae.load_state_dict(vae_state_dict)
22
  return vae.cuda().to(torch.bfloat16)
23
 
24
+
25
  def load_unet(unet_dir):
26
  unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
27
  unet_config_path = unet_dir / "config.json"
 
31
  transformer.load_state_dict(unet_state_dict, strict=True)
32
  return transformer.cuda()
33
 
34
+
35
  def load_scheduler(scheduler_dir):
36
  scheduler_config_path = scheduler_dir / "scheduler_config.json"
37
  scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
38
  return RectifiedFlowScheduler.from_config(scheduler_config)
39
 
40
+
41
  def main():
42
  # Parse command line arguments
43
+ parser = argparse.ArgumentParser(
44
+ description="Load models from separate directories"
45
+ )
46
+ parser.add_argument(
47
+ "--separate_dir",
48
+ type=str,
49
+ required=True,
50
+ help="Path to the directory containing unet, vae, and scheduler subdirectories",
51
+ )
52
  args = parser.parse_args()
53
 
54
  # Paths for the separate mode directories
55
  separate_dir = Path(args.separate_dir)
56
+ unet_dir = separate_dir / "unet"
57
+ vae_dir = separate_dir / "vae"
58
+ scheduler_dir = separate_dir / "scheduler"
59
 
60
  # Load models
61
  vae = load_vae(vae_dir)
 
65
  # Patchifier (remains the same)
66
  patchifier = SymmetricPatchifier(patch_size=1)
67
 
68
+ text_encoder = T5EncoderModel.from_pretrained(
69
+ "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
70
+ ).to("cuda")
71
+ tokenizer = T5Tokenizer.from_pretrained(
72
+ "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
73
+ )
74
 
75
  # Use submodels for the pipeline
76
  submodel_dict = {
 
94
  frame_rate = 25
95
  sample = {
96
  "prompt": "A middle-aged man with glasses and a salt-and-pepper beard is driving a car and talking, gesturing with his right hand. "
97
+ "The man is wearing a dark blue zip-up jacket and a light blue collared shirt. He is sitting in the driver's seat of a car with a black interior. The car is moving on a road with trees and bushes on either side. The man has a serious expression on his face and is looking straight ahead.",
98
+ "prompt_attention_mask": None, # Adjust attention masks as needed
99
+ "negative_prompt": "Ugly deformed",
100
+ "negative_prompt_attention_mask": None,
101
  }
102
 
103
  # Generate images (video frames)
104
+ _ = pipeline(
105
  num_inference_steps=num_inference_steps,
106
  num_images_per_prompt=num_images_per_prompt,
107
  guidance_scale=guidance_scale,
 
119
 
120
  print("Generated images (video frames).")
121
 
122
+
123
  if __name__ == "__main__":
124
  main()
xora/models/autoencoders/causal_conv3d.py CHANGED
@@ -40,11 +40,17 @@ class CausalConv3d(nn.Module):
40
 
41
  def forward(self, x, causal: bool = True):
42
  if causal:
43
- first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1))
 
 
44
  x = torch.concatenate((first_frame_pad, x), dim=2)
45
  else:
46
- first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1))
47
- last_frame_pad = x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1))
 
 
 
 
48
  x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
49
  x = self.conv(x)
50
  return x
 
40
 
41
  def forward(self, x, causal: bool = True):
42
  if causal:
43
+ first_frame_pad = x[:, :, :1, :, :].repeat(
44
+ (1, 1, self.time_kernel_size - 1, 1, 1)
45
+ )
46
  x = torch.concatenate((first_frame_pad, x), dim=2)
47
  else:
48
+ first_frame_pad = x[:, :, :1, :, :].repeat(
49
+ (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
50
+ )
51
+ last_frame_pad = x[:, :, -1:, :, :].repeat(
52
+ (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
53
+ )
54
  x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
55
  x = self.conv(x)
56
  return x
xora/models/autoencoders/causal_video_autoencoder.py CHANGED
@@ -16,9 +16,15 @@ from xora.models.autoencoders.vae import AutoencoderKLWrapper
16
 
17
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
18
 
 
19
  class CausalVideoAutoencoder(AutoencoderKLWrapper):
20
  @classmethod
21
- def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *args, **kwargs):
 
 
 
 
 
22
  config_local_path = pretrained_model_name_or_path / "config.json"
23
  config = cls.load_config(config_local_path, **kwargs)
24
  video_vae = cls.from_config(config)
@@ -28,29 +34,41 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
28
  ckpt_state_dict = torch.load(model_local_path, map_location=torch.device("cpu"))
29
  video_vae.load_state_dict(ckpt_state_dict)
30
 
31
- statistics_local_path = pretrained_model_name_or_path / "per_channel_statistics.json"
 
 
32
  if statistics_local_path.exists():
33
  with open(statistics_local_path, "r") as file:
34
  data = json.load(file)
35
  transposed_data = list(zip(*data["data"]))
36
- data_dict = {col: torch.tensor(vals) for col, vals in zip(data["columns"], transposed_data)}
 
 
 
37
  video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
38
  video_vae.register_buffer(
39
- "mean_of_means", data_dict.get("mean-of-means", torch.zeros_like(data_dict["std-of-means"]))
 
 
 
40
  )
41
 
42
  return video_vae
43
 
44
  @staticmethod
45
  def from_config(config):
46
- assert config["_class_name"] == "CausalVideoAutoencoder", "config must have _class_name=CausalVideoAutoencoder"
 
 
47
  if isinstance(config["dims"], list):
48
  config["dims"] = tuple(config["dims"])
49
 
50
  assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"
51
 
52
  double_z = config.get("double_z", True)
53
- latent_log_var = config.get("latent_log_var", "per_channel" if double_z else "none")
 
 
54
  use_quant_conv = config.get("use_quant_conv", True)
55
 
56
  if use_quant_conv and latent_log_var == "uniform":
@@ -91,7 +109,8 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
91
  _class_name="CausalVideoAutoencoder",
92
  dims=self.dims,
93
  in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2,
94
- out_channels=self.decoder.conv_out.out_channels // self.decoder.patch_size**2,
 
95
  latent_channels=self.decoder.conv_in.in_channels,
96
  blocks=self.encoder.blocks_desc,
97
  scaling_factor=1.0,
@@ -112,13 +131,26 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
112
  @property
113
  def spatial_downscale_factor(self):
114
  return (
115
- 2 ** len([block for block in self.encoder.blocks_desc if block[0] in ["compress_space", "compress_all"]])
 
 
 
 
 
 
 
116
  * self.encoder.patch_size
117
  )
118
 
119
  @property
120
  def temporal_downscale_factor(self):
121
- return 2 ** len([block for block in self.encoder.blocks_desc if block[0] in ["compress_time", "compress_all"]])
 
 
 
 
 
 
122
 
123
  def to_json_string(self) -> str:
124
  import json
@@ -146,7 +178,9 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
146
  key = key.replace(k, v)
147
 
148
  if "norm" in key and key not in model_keys:
149
- logger.info(f"Removing key {key} from state_dict as it is not present in the model")
 
 
150
  continue
151
 
152
  converted_state_dict[key] = value
@@ -293,7 +327,9 @@ class Encoder(nn.Module):
293
 
294
  # out
295
  if norm_layer == "group_norm":
296
- self.conv_norm_out = nn.GroupNorm(num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6)
 
 
297
  elif norm_layer == "pixel_norm":
298
  self.conv_norm_out = PixelNorm()
299
  elif norm_layer == "layer_norm":
@@ -308,7 +344,9 @@ class Encoder(nn.Module):
308
  conv_out_channels += 1
309
  elif latent_log_var != "none":
310
  raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
311
- self.conv_out = make_conv_nd(dims, output_channel, conv_out_channels, 3, padding=1, causal=True)
 
 
312
 
313
  self.gradient_checkpointing = False
314
 
@@ -337,11 +375,15 @@ class Encoder(nn.Module):
337
 
338
  if num_dims == 4:
339
  # For shape (B, C, H, W)
340
- repeated_last_channel = last_channel.repeat(1, sample.shape[1] - 2, 1, 1)
 
 
341
  sample = torch.cat([sample, repeated_last_channel], dim=1)
342
  elif num_dims == 5:
343
  # For shape (B, C, F, H, W)
344
- repeated_last_channel = last_channel.repeat(1, sample.shape[1] - 2, 1, 1, 1)
 
 
345
  sample = torch.cat([sample, repeated_last_channel], dim=1)
346
  else:
347
  raise ValueError(f"Invalid input shape: {sample.shape}")
@@ -430,25 +472,35 @@ class Decoder(nn.Module):
430
  norm_layer=norm_layer,
431
  )
432
  elif block_name == "compress_time":
433
- block = DepthToSpaceUpsample(dims=dims, in_channels=input_channel, stride=(2, 1, 1))
 
 
434
  elif block_name == "compress_space":
435
- block = DepthToSpaceUpsample(dims=dims, in_channels=input_channel, stride=(1, 2, 2))
 
 
436
  elif block_name == "compress_all":
437
- block = DepthToSpaceUpsample(dims=dims, in_channels=input_channel, stride=(2, 2, 2))
 
 
438
  else:
439
  raise ValueError(f"unknown layer: {block_name}")
440
 
441
  self.up_blocks.append(block)
442
 
443
  if norm_layer == "group_norm":
444
- self.conv_norm_out = nn.GroupNorm(num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6)
 
 
445
  elif norm_layer == "pixel_norm":
446
  self.conv_norm_out = PixelNorm()
447
  elif norm_layer == "layer_norm":
448
  self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
449
 
450
  self.conv_act = nn.SiLU()
451
- self.conv_out = make_conv_nd(dims, output_channel, out_channels, 3, padding=1, causal=True)
 
 
452
 
453
  self.gradient_checkpointing = False
454
 
@@ -509,7 +561,9 @@ class UNetMidBlock3D(nn.Module):
509
  norm_layer: str = "group_norm",
510
  ):
511
  super().__init__()
512
- resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
 
 
513
 
514
  self.res_blocks = nn.ModuleList(
515
  [
@@ -526,7 +580,9 @@ class UNetMidBlock3D(nn.Module):
526
  ]
527
  )
528
 
529
- def forward(self, hidden_states: torch.FloatTensor, causal: bool = True) -> torch.FloatTensor:
 
 
530
  for resnet in self.res_blocks:
531
  hidden_states = resnet(hidden_states, causal=causal)
532
 
@@ -604,7 +660,9 @@ class ResnetBlock3D(nn.Module):
604
  self.use_conv_shortcut = conv_shortcut
605
 
606
  if norm_layer == "group_norm":
607
- self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
 
 
608
  elif norm_layer == "pixel_norm":
609
  self.norm1 = PixelNorm()
610
  elif norm_layer == "layer_norm":
@@ -612,10 +670,20 @@ class ResnetBlock3D(nn.Module):
612
 
613
  self.non_linearity = nn.SiLU()
614
 
615
- self.conv1 = make_conv_nd(dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1, causal=True)
 
 
 
 
 
 
 
 
616
 
617
  if norm_layer == "group_norm":
618
- self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
 
 
619
  elif norm_layer == "pixel_norm":
620
  self.norm2 = PixelNorm()
621
  elif norm_layer == "layer_norm":
@@ -623,16 +691,28 @@ class ResnetBlock3D(nn.Module):
623
 
624
  self.dropout = torch.nn.Dropout(dropout)
625
 
626
- self.conv2 = make_conv_nd(dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1, causal=True)
 
 
 
 
 
 
 
 
627
 
628
  self.conv_shortcut = (
629
- make_linear_nd(dims=dims, in_channels=in_channels, out_channels=out_channels)
 
 
630
  if in_channels != out_channels
631
  else nn.Identity()
632
  )
633
 
634
  self.norm3 = (
635
- LayerNorm(in_channels, eps=eps, elementwise_affine=True) if in_channels != out_channels else nn.Identity()
 
 
636
  )
637
 
638
  def forward(
@@ -669,9 +749,17 @@ def patchify(x, patch_size_hw, patch_size_t=1):
669
  if patch_size_hw == 1 and patch_size_t == 1:
670
  return x
671
  if x.dim() == 4:
672
- x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw)
 
 
673
  elif x.dim() == 5:
674
- x = rearrange(x, "b c (f p) (h q) (w r) -> b (c p r q) f h w", p=patch_size_t, q=patch_size_hw, r=patch_size_hw)
 
 
 
 
 
 
675
  else:
676
  raise ValueError(f"Invalid input shape: {x.shape}")
677
 
@@ -683,9 +771,17 @@ def unpatchify(x, patch_size_hw, patch_size_t=1):
683
  return x
684
 
685
  if x.dim() == 4:
686
- x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw)
 
 
687
  elif x.dim() == 5:
688
- x = rearrange(x, "b (c p r q) f h w -> b c (f p) (h q) (w r)", p=patch_size_t, q=patch_size_hw, r=patch_size_hw)
 
 
 
 
 
 
689
 
690
  return x
691
 
@@ -755,14 +851,18 @@ def demo_video_autoencoder_forward_backward():
755
  print(f"input shape={input_videos.shape}")
756
  print(f"latent shape={latent.shape}")
757
 
758
- reconstructed_videos = video_autoencoder.decode(latent, target_shape=input_videos.shape).sample
 
 
759
 
760
  print(f"reconstructed shape={reconstructed_videos.shape}")
761
 
762
  # Validate that single image gets treated the same way as first frame
763
  input_image = input_videos[:, :, :1, :, :]
764
  image_latent = video_autoencoder.encode(input_image).latent_dist.mode()
765
- reconstructed_image = video_autoencoder.decode(image_latent, target_shape=image_latent.shape).sample
 
 
766
 
767
  first_frame_latent = latent[:, :, :1, :, :]
768
 
 
16
 
17
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
18
 
19
+
20
  class CausalVideoAutoencoder(AutoencoderKLWrapper):
21
  @classmethod
22
+ def from_pretrained(
23
+ cls,
24
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
25
+ *args,
26
+ **kwargs,
27
+ ):
28
  config_local_path = pretrained_model_name_or_path / "config.json"
29
  config = cls.load_config(config_local_path, **kwargs)
30
  video_vae = cls.from_config(config)
 
34
  ckpt_state_dict = torch.load(model_local_path, map_location=torch.device("cpu"))
35
  video_vae.load_state_dict(ckpt_state_dict)
36
 
37
+ statistics_local_path = (
38
+ pretrained_model_name_or_path / "per_channel_statistics.json"
39
+ )
40
  if statistics_local_path.exists():
41
  with open(statistics_local_path, "r") as file:
42
  data = json.load(file)
43
  transposed_data = list(zip(*data["data"]))
44
+ data_dict = {
45
+ col: torch.tensor(vals)
46
+ for col, vals in zip(data["columns"], transposed_data)
47
+ }
48
  video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
49
  video_vae.register_buffer(
50
+ "mean_of_means",
51
+ data_dict.get(
52
+ "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
53
+ ),
54
  )
55
 
56
  return video_vae
57
 
58
  @staticmethod
59
  def from_config(config):
60
+ assert (
61
+ config["_class_name"] == "CausalVideoAutoencoder"
62
+ ), "config must have _class_name=CausalVideoAutoencoder"
63
  if isinstance(config["dims"], list):
64
  config["dims"] = tuple(config["dims"])
65
 
66
  assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"
67
 
68
  double_z = config.get("double_z", True)
69
+ latent_log_var = config.get(
70
+ "latent_log_var", "per_channel" if double_z else "none"
71
+ )
72
  use_quant_conv = config.get("use_quant_conv", True)
73
 
74
  if use_quant_conv and latent_log_var == "uniform":
 
109
  _class_name="CausalVideoAutoencoder",
110
  dims=self.dims,
111
  in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2,
112
+ out_channels=self.decoder.conv_out.out_channels
113
+ // self.decoder.patch_size**2,
114
  latent_channels=self.decoder.conv_in.in_channels,
115
  blocks=self.encoder.blocks_desc,
116
  scaling_factor=1.0,
 
131
  @property
132
  def spatial_downscale_factor(self):
133
  return (
134
+ 2
135
+ ** len(
136
+ [
137
+ block
138
+ for block in self.encoder.blocks_desc
139
+ if block[0] in ["compress_space", "compress_all"]
140
+ ]
141
+ )
142
  * self.encoder.patch_size
143
  )
144
 
145
  @property
146
  def temporal_downscale_factor(self):
147
+ return 2 ** len(
148
+ [
149
+ block
150
+ for block in self.encoder.blocks_desc
151
+ if block[0] in ["compress_time", "compress_all"]
152
+ ]
153
+ )
154
 
155
  def to_json_string(self) -> str:
156
  import json
 
178
  key = key.replace(k, v)
179
 
180
  if "norm" in key and key not in model_keys:
181
+ logger.info(
182
+ f"Removing key {key} from state_dict as it is not present in the model"
183
+ )
184
  continue
185
 
186
  converted_state_dict[key] = value
 
327
 
328
  # out
329
  if norm_layer == "group_norm":
330
+ self.conv_norm_out = nn.GroupNorm(
331
+ num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
332
+ )
333
  elif norm_layer == "pixel_norm":
334
  self.conv_norm_out = PixelNorm()
335
  elif norm_layer == "layer_norm":
 
344
  conv_out_channels += 1
345
  elif latent_log_var != "none":
346
  raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
347
+ self.conv_out = make_conv_nd(
348
+ dims, output_channel, conv_out_channels, 3, padding=1, causal=True
349
+ )
350
 
351
  self.gradient_checkpointing = False
352
 
 
375
 
376
  if num_dims == 4:
377
  # For shape (B, C, H, W)
378
+ repeated_last_channel = last_channel.repeat(
379
+ 1, sample.shape[1] - 2, 1, 1
380
+ )
381
  sample = torch.cat([sample, repeated_last_channel], dim=1)
382
  elif num_dims == 5:
383
  # For shape (B, C, F, H, W)
384
+ repeated_last_channel = last_channel.repeat(
385
+ 1, sample.shape[1] - 2, 1, 1, 1
386
+ )
387
  sample = torch.cat([sample, repeated_last_channel], dim=1)
388
  else:
389
  raise ValueError(f"Invalid input shape: {sample.shape}")
 
472
  norm_layer=norm_layer,
473
  )
474
  elif block_name == "compress_time":
475
+ block = DepthToSpaceUpsample(
476
+ dims=dims, in_channels=input_channel, stride=(2, 1, 1)
477
+ )
478
  elif block_name == "compress_space":
479
+ block = DepthToSpaceUpsample(
480
+ dims=dims, in_channels=input_channel, stride=(1, 2, 2)
481
+ )
482
  elif block_name == "compress_all":
483
+ block = DepthToSpaceUpsample(
484
+ dims=dims, in_channels=input_channel, stride=(2, 2, 2)
485
+ )
486
  else:
487
  raise ValueError(f"unknown layer: {block_name}")
488
 
489
  self.up_blocks.append(block)
490
 
491
  if norm_layer == "group_norm":
492
+ self.conv_norm_out = nn.GroupNorm(
493
+ num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
494
+ )
495
  elif norm_layer == "pixel_norm":
496
  self.conv_norm_out = PixelNorm()
497
  elif norm_layer == "layer_norm":
498
  self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
499
 
500
  self.conv_act = nn.SiLU()
501
+ self.conv_out = make_conv_nd(
502
+ dims, output_channel, out_channels, 3, padding=1, causal=True
503
+ )
504
 
505
  self.gradient_checkpointing = False
506
 
 
561
  norm_layer: str = "group_norm",
562
  ):
563
  super().__init__()
564
+ resnet_groups = (
565
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
566
+ )
567
 
568
  self.res_blocks = nn.ModuleList(
569
  [
 
580
  ]
581
  )
582
 
583
+ def forward(
584
+ self, hidden_states: torch.FloatTensor, causal: bool = True
585
+ ) -> torch.FloatTensor:
586
  for resnet in self.res_blocks:
587
  hidden_states = resnet(hidden_states, causal=causal)
588
 
 
660
  self.use_conv_shortcut = conv_shortcut
661
 
662
  if norm_layer == "group_norm":
663
+ self.norm1 = nn.GroupNorm(
664
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
665
+ )
666
  elif norm_layer == "pixel_norm":
667
  self.norm1 = PixelNorm()
668
  elif norm_layer == "layer_norm":
 
670
 
671
  self.non_linearity = nn.SiLU()
672
 
673
+ self.conv1 = make_conv_nd(
674
+ dims,
675
+ in_channels,
676
+ out_channels,
677
+ kernel_size=3,
678
+ stride=1,
679
+ padding=1,
680
+ causal=True,
681
+ )
682
 
683
  if norm_layer == "group_norm":
684
+ self.norm2 = nn.GroupNorm(
685
+ num_groups=groups, num_channels=out_channels, eps=eps, affine=True
686
+ )
687
  elif norm_layer == "pixel_norm":
688
  self.norm2 = PixelNorm()
689
  elif norm_layer == "layer_norm":
 
691
 
692
  self.dropout = torch.nn.Dropout(dropout)
693
 
694
+ self.conv2 = make_conv_nd(
695
+ dims,
696
+ out_channels,
697
+ out_channels,
698
+ kernel_size=3,
699
+ stride=1,
700
+ padding=1,
701
+ causal=True,
702
+ )
703
 
704
  self.conv_shortcut = (
705
+ make_linear_nd(
706
+ dims=dims, in_channels=in_channels, out_channels=out_channels
707
+ )
708
  if in_channels != out_channels
709
  else nn.Identity()
710
  )
711
 
712
  self.norm3 = (
713
+ LayerNorm(in_channels, eps=eps, elementwise_affine=True)
714
+ if in_channels != out_channels
715
+ else nn.Identity()
716
  )
717
 
718
  def forward(
 
749
  if patch_size_hw == 1 and patch_size_t == 1:
750
  return x
751
  if x.dim() == 4:
752
+ x = rearrange(
753
+ x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
754
+ )
755
  elif x.dim() == 5:
756
+ x = rearrange(
757
+ x,
758
+ "b c (f p) (h q) (w r) -> b (c p r q) f h w",
759
+ p=patch_size_t,
760
+ q=patch_size_hw,
761
+ r=patch_size_hw,
762
+ )
763
  else:
764
  raise ValueError(f"Invalid input shape: {x.shape}")
765
 
 
771
  return x
772
 
773
  if x.dim() == 4:
774
+ x = rearrange(
775
+ x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
776
+ )
777
  elif x.dim() == 5:
778
+ x = rearrange(
779
+ x,
780
+ "b (c p r q) f h w -> b c (f p) (h q) (w r)",
781
+ p=patch_size_t,
782
+ q=patch_size_hw,
783
+ r=patch_size_hw,
784
+ )
785
 
786
  return x
787
 
 
851
  print(f"input shape={input_videos.shape}")
852
  print(f"latent shape={latent.shape}")
853
 
854
+ reconstructed_videos = video_autoencoder.decode(
855
+ latent, target_shape=input_videos.shape
856
+ ).sample
857
 
858
  print(f"reconstructed shape={reconstructed_videos.shape}")
859
 
860
  # Validate that single image gets treated the same way as first frame
861
  input_image = input_videos[:, :, :1, :, :]
862
  image_latent = video_autoencoder.encode(input_image).latent_dist.mode()
863
+ reconstructed_image = video_autoencoder.decode(
864
+ image_latent, target_shape=image_latent.shape
865
+ ).sample
866
 
867
  first_frame_latent = latent[:, :, :1, :, :]
868
 
xora/models/autoencoders/conv_nd_factory.py CHANGED
@@ -71,8 +71,12 @@ def make_linear_nd(
71
  bias=True,
72
  ):
73
  if dims == 2:
74
- return torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias)
 
 
75
  elif dims == 3 or dims == (2, 1):
76
- return torch.nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias)
 
 
77
  else:
78
  raise ValueError(f"unsupported dimensions: {dims}")
 
71
  bias=True,
72
  ):
73
  if dims == 2:
74
+ return torch.nn.Conv2d(
75
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
76
+ )
77
  elif dims == 3 or dims == (2, 1):
78
+ return torch.nn.Conv3d(
79
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
80
+ )
81
  else:
82
  raise ValueError(f"unsupported dimensions: {dims}")
xora/models/autoencoders/dual_conv3d.py CHANGED
@@ -27,7 +27,9 @@ class DualConv3d(nn.Module):
27
  if isinstance(kernel_size, int):
28
  kernel_size = (kernel_size, kernel_size, kernel_size)
29
  if kernel_size == (1, 1, 1):
30
- raise ValueError("kernel_size must be greater than 1. Use make_linear_nd instead.")
 
 
31
  if isinstance(stride, int):
32
  stride = (stride, stride, stride)
33
  if isinstance(padding, int):
@@ -40,11 +42,19 @@ class DualConv3d(nn.Module):
40
  self.bias = bias
41
 
42
  # Define the size of the channels after the first convolution
43
- intermediate_channels = out_channels if in_channels < out_channels else in_channels
 
 
44
 
45
  # Define parameters for the first convolution
46
  self.weight1 = nn.Parameter(
47
- torch.Tensor(intermediate_channels, in_channels // groups, 1, kernel_size[1], kernel_size[2])
 
 
 
 
 
 
48
  )
49
  self.stride1 = (1, stride[1], stride[2])
50
  self.padding1 = (0, padding[1], padding[2])
@@ -55,7 +65,11 @@ class DualConv3d(nn.Module):
55
  self.register_parameter("bias1", None)
56
 
57
  # Define parameters for the second convolution
58
- self.weight2 = nn.Parameter(torch.Tensor(out_channels, intermediate_channels // groups, kernel_size[0], 1, 1))
 
 
 
 
59
  self.stride2 = (stride[0], 1, 1)
60
  self.padding2 = (padding[0], 0, 0)
61
  self.dilation2 = (dilation[0], 1, 1)
@@ -86,13 +100,29 @@ class DualConv3d(nn.Module):
86
 
87
  def forward_with_3d(self, x, skip_time_conv):
88
  # First convolution
89
- x = F.conv3d(x, self.weight1, self.bias1, self.stride1, self.padding1, self.dilation1, self.groups)
 
 
 
 
 
 
 
 
90
 
91
  if skip_time_conv:
92
  return x
93
 
94
  # Second convolution
95
- x = F.conv3d(x, self.weight2, self.bias2, self.stride2, self.padding2, self.dilation2, self.groups)
 
 
 
 
 
 
 
 
96
 
97
  return x
98
 
 
27
  if isinstance(kernel_size, int):
28
  kernel_size = (kernel_size, kernel_size, kernel_size)
29
  if kernel_size == (1, 1, 1):
30
+ raise ValueError(
31
+ "kernel_size must be greater than 1. Use make_linear_nd instead."
32
+ )
33
  if isinstance(stride, int):
34
  stride = (stride, stride, stride)
35
  if isinstance(padding, int):
 
42
  self.bias = bias
43
 
44
  # Define the size of the channels after the first convolution
45
+ intermediate_channels = (
46
+ out_channels if in_channels < out_channels else in_channels
47
+ )
48
 
49
  # Define parameters for the first convolution
50
  self.weight1 = nn.Parameter(
51
+ torch.Tensor(
52
+ intermediate_channels,
53
+ in_channels // groups,
54
+ 1,
55
+ kernel_size[1],
56
+ kernel_size[2],
57
+ )
58
  )
59
  self.stride1 = (1, stride[1], stride[2])
60
  self.padding1 = (0, padding[1], padding[2])
 
65
  self.register_parameter("bias1", None)
66
 
67
  # Define parameters for the second convolution
68
+ self.weight2 = nn.Parameter(
69
+ torch.Tensor(
70
+ out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
71
+ )
72
+ )
73
  self.stride2 = (stride[0], 1, 1)
74
  self.padding2 = (padding[0], 0, 0)
75
  self.dilation2 = (dilation[0], 1, 1)
 
100
 
101
  def forward_with_3d(self, x, skip_time_conv):
102
  # First convolution
103
+ x = F.conv3d(
104
+ x,
105
+ self.weight1,
106
+ self.bias1,
107
+ self.stride1,
108
+ self.padding1,
109
+ self.dilation1,
110
+ self.groups,
111
+ )
112
 
113
  if skip_time_conv:
114
  return x
115
 
116
  # Second convolution
117
+ x = F.conv3d(
118
+ x,
119
+ self.weight2,
120
+ self.bias2,
121
+ self.stride2,
122
+ self.padding2,
123
+ self.dilation2,
124
+ self.groups,
125
+ )
126
 
127
  return x
128
 
xora/models/autoencoders/vae.py CHANGED
@@ -4,7 +4,10 @@ import torch
4
  import math
5
  import torch.nn as nn
6
  from diffusers import ConfigMixin, ModelMixin
7
- from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
 
 
 
8
  from diffusers.models.modeling_outputs import AutoencoderKLOutput
9
  from xora.models.autoencoders.conv_nd_factory import make_conv_nd
10
 
@@ -43,8 +46,12 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
43
  quant_dims = 2 if dims == 2 else 3
44
  self.decoder = decoder
45
  if use_quant_conv:
46
- self.quant_conv = make_conv_nd(quant_dims, 2 * latent_channels, 2 * latent_channels, 1)
47
- self.post_quant_conv = make_conv_nd(quant_dims, latent_channels, latent_channels, 1)
 
 
 
 
48
  else:
49
  self.quant_conv = nn.Identity()
50
  self.post_quant_conv = nn.Identity()
@@ -104,7 +111,13 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
104
  for i in range(0, x.shape[3], overlap_size):
105
  row = []
106
  for j in range(0, x.shape[4], overlap_size):
107
- tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
 
 
 
 
 
 
108
  tile = self.encoder(tile)
109
  tile = self.quant_conv(tile)
110
  row.append(tile)
@@ -125,42 +138,58 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
125
  moments = torch.cat(result_rows, dim=3)
126
  return moments
127
 
128
- def blend_z(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
 
 
129
  blend_extent = min(a.shape[2], b.shape[2], blend_extent)
130
  for z in range(blend_extent):
131
- b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * (1 - z / blend_extent) + b[:, :, z, :, :] * (
132
- z / blend_extent
133
- )
134
  return b
135
 
136
- def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
 
 
137
  blend_extent = min(a.shape[3], b.shape[3], blend_extent)
138
  for y in range(blend_extent):
139
- b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
140
- y / blend_extent
141
- )
142
  return b
143
 
144
- def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
 
 
145
  blend_extent = min(a.shape[4], b.shape[4], blend_extent)
146
  for x in range(blend_extent):
147
- b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
148
- x / blend_extent
149
- )
150
  return b
151
 
152
  def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape):
153
  overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
154
  blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
155
  row_limit = self.tile_sample_min_size - blend_extent
156
- tile_target_shape = (*target_shape[:3], self.tile_sample_min_size, self.tile_sample_min_size)
 
 
 
 
157
  # Split z into overlapping 64x64 tiles and decode them separately.
158
  # The tiles have an overlap to avoid seams between tiles.
159
  rows = []
160
  for i in range(0, z.shape[3], overlap_size):
161
  row = []
162
  for j in range(0, z.shape[4], overlap_size):
163
- tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
 
 
 
 
 
 
164
  tile = self.post_quant_conv(tile)
165
  decoded = self.decoder(tile, target_shape=tile_target_shape)
166
  row.append(decoded)
@@ -181,20 +210,34 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
181
  dec = torch.cat(result_rows, dim=3)
182
  return dec
183
 
184
- def encode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
 
 
185
  if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
186
  num_splits = z.shape[2] // self.z_sample_size
187
  sizes = [self.z_sample_size] * num_splits
188
- sizes = sizes + [z.shape[2] - sum(sizes)] if z.shape[2] - sum(sizes) > 0 else sizes
 
 
 
 
189
  tiles = z.split(sizes, dim=2)
190
  moments_tiles = [
191
- self._hw_tiled_encode(z_tile, return_dict) if self.use_hw_tiling else self._encode(z_tile)
 
 
 
 
192
  for z_tile in tiles
193
  ]
194
  moments = torch.cat(moments_tiles, dim=2)
195
 
196
  else:
197
- moments = self._hw_tiled_encode(z, return_dict) if self.use_hw_tiling else self._encode(z)
 
 
 
 
198
 
199
  posterior = DiagonalGaussianDistribution(moments)
200
  if not return_dict:
@@ -207,7 +250,9 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
207
  moments = self.quant_conv(h)
208
  return moments
209
 
210
- def _decode(self, z: torch.FloatTensor, target_shape=None) -> Union[DecoderOutput, torch.FloatTensor]:
 
 
211
  z = self.post_quant_conv(z)
212
  dec = self.decoder(z, target_shape=target_shape)
213
  return dec
@@ -219,7 +264,12 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
219
  if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
220
  reduction_factor = int(
221
  self.encoder.patch_size_t
222
- * 2 ** (len(self.encoder.down_blocks) - 1 - math.sqrt(self.encoder.patch_size))
 
 
 
 
 
223
  )
224
  split_size = self.z_sample_size // reduction_factor
225
  num_splits = z.shape[2] // split_size
 
4
  import math
5
  import torch.nn as nn
6
  from diffusers import ConfigMixin, ModelMixin
7
+ from diffusers.models.autoencoders.vae import (
8
+ DecoderOutput,
9
+ DiagonalGaussianDistribution,
10
+ )
11
  from diffusers.models.modeling_outputs import AutoencoderKLOutput
12
  from xora.models.autoencoders.conv_nd_factory import make_conv_nd
13
 
 
46
  quant_dims = 2 if dims == 2 else 3
47
  self.decoder = decoder
48
  if use_quant_conv:
49
+ self.quant_conv = make_conv_nd(
50
+ quant_dims, 2 * latent_channels, 2 * latent_channels, 1
51
+ )
52
+ self.post_quant_conv = make_conv_nd(
53
+ quant_dims, latent_channels, latent_channels, 1
54
+ )
55
  else:
56
  self.quant_conv = nn.Identity()
57
  self.post_quant_conv = nn.Identity()
 
111
  for i in range(0, x.shape[3], overlap_size):
112
  row = []
113
  for j in range(0, x.shape[4], overlap_size):
114
+ tile = x[
115
+ :,
116
+ :,
117
+ :,
118
+ i : i + self.tile_sample_min_size,
119
+ j : j + self.tile_sample_min_size,
120
+ ]
121
  tile = self.encoder(tile)
122
  tile = self.quant_conv(tile)
123
  row.append(tile)
 
138
  moments = torch.cat(result_rows, dim=3)
139
  return moments
140
 
141
+ def blend_z(
142
+ self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
143
+ ) -> torch.Tensor:
144
  blend_extent = min(a.shape[2], b.shape[2], blend_extent)
145
  for z in range(blend_extent):
146
+ b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * (
147
+ 1 - z / blend_extent
148
+ ) + b[:, :, z, :, :] * (z / blend_extent)
149
  return b
150
 
151
+ def blend_v(
152
+ self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
153
+ ) -> torch.Tensor:
154
  blend_extent = min(a.shape[3], b.shape[3], blend_extent)
155
  for y in range(blend_extent):
156
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (
157
+ 1 - y / blend_extent
158
+ ) + b[:, :, :, y, :] * (y / blend_extent)
159
  return b
160
 
161
+ def blend_h(
162
+ self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
163
+ ) -> torch.Tensor:
164
  blend_extent = min(a.shape[4], b.shape[4], blend_extent)
165
  for x in range(blend_extent):
166
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (
167
+ 1 - x / blend_extent
168
+ ) + b[:, :, :, :, x] * (x / blend_extent)
169
  return b
170
 
171
  def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape):
172
  overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
173
  blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
174
  row_limit = self.tile_sample_min_size - blend_extent
175
+ tile_target_shape = (
176
+ *target_shape[:3],
177
+ self.tile_sample_min_size,
178
+ self.tile_sample_min_size,
179
+ )
180
  # Split z into overlapping 64x64 tiles and decode them separately.
181
  # The tiles have an overlap to avoid seams between tiles.
182
  rows = []
183
  for i in range(0, z.shape[3], overlap_size):
184
  row = []
185
  for j in range(0, z.shape[4], overlap_size):
186
+ tile = z[
187
+ :,
188
+ :,
189
+ :,
190
+ i : i + self.tile_latent_min_size,
191
+ j : j + self.tile_latent_min_size,
192
+ ]
193
  tile = self.post_quant_conv(tile)
194
  decoded = self.decoder(tile, target_shape=tile_target_shape)
195
  row.append(decoded)
 
210
  dec = torch.cat(result_rows, dim=3)
211
  return dec
212
 
213
+ def encode(
214
+ self, z: torch.FloatTensor, return_dict: bool = True
215
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
216
  if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
217
  num_splits = z.shape[2] // self.z_sample_size
218
  sizes = [self.z_sample_size] * num_splits
219
+ sizes = (
220
+ sizes + [z.shape[2] - sum(sizes)]
221
+ if z.shape[2] - sum(sizes) > 0
222
+ else sizes
223
+ )
224
  tiles = z.split(sizes, dim=2)
225
  moments_tiles = [
226
+ (
227
+ self._hw_tiled_encode(z_tile, return_dict)
228
+ if self.use_hw_tiling
229
+ else self._encode(z_tile)
230
+ )
231
  for z_tile in tiles
232
  ]
233
  moments = torch.cat(moments_tiles, dim=2)
234
 
235
  else:
236
+ moments = (
237
+ self._hw_tiled_encode(z, return_dict)
238
+ if self.use_hw_tiling
239
+ else self._encode(z)
240
+ )
241
 
242
  posterior = DiagonalGaussianDistribution(moments)
243
  if not return_dict:
 
250
  moments = self.quant_conv(h)
251
  return moments
252
 
253
+ def _decode(
254
+ self, z: torch.FloatTensor, target_shape=None
255
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
256
  z = self.post_quant_conv(z)
257
  dec = self.decoder(z, target_shape=target_shape)
258
  return dec
 
264
  if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
265
  reduction_factor = int(
266
  self.encoder.patch_size_t
267
+ * 2
268
+ ** (
269
+ len(self.encoder.down_blocks)
270
+ - 1
271
+ - math.sqrt(self.encoder.patch_size)
272
+ )
273
  )
274
  split_size = self.z_sample_size // reduction_factor
275
  num_splits = z.shape[2] // split_size
xora/models/autoencoders/vae_encode.py CHANGED
@@ -6,12 +6,19 @@ from torch import Tensor
6
 
7
  from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
8
  from xora.models.autoencoders.video_autoencoder import Downsample3D, VideoAutoencoder
 
9
  try:
10
  import torch_xla.core.xla_model as xm
11
- except:
12
- pass
 
13
 
14
- def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae_per_channel_normalize=False) -> Tensor:
 
 
 
 
 
15
  """
16
  Encodes media items (images or videos) into latent representations using a specified VAE model.
17
  The function supports processing batches of images or video frames and can handle the processing
@@ -48,11 +55,15 @@ def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae
48
  if channels != 3:
49
  raise ValueError(f"Expects tensors with 3 channels, got {channels}.")
50
 
51
- if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
 
 
52
  media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
53
  if split_size > 1:
54
  if len(media_items) % split_size != 0:
55
- raise ValueError("Error: The batch size must be divisible by 'train.vae_bs_split")
 
 
56
  encode_bs = len(media_items) // split_size
57
  # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
58
  latents = []
@@ -67,22 +78,32 @@ def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae
67
  latents = vae.encode(media_items).latent_dist.sample()
68
 
69
  latents = normalize_latents(latents, vae, vae_per_channel_normalize)
70
- if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
 
 
71
  latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
72
  return latents
73
 
74
 
75
  def vae_decode(
76
- latents: Tensor, vae: AutoencoderKL, is_video: bool = True, split_size: int = 1, vae_per_channel_normalize=False
 
 
 
 
77
  ) -> Tensor:
78
  is_video_shaped = latents.dim() == 5
79
  batch_size = latents.shape[0]
80
 
81
- if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
 
 
82
  latents = rearrange(latents, "b c n h w -> (b n) c h w")
83
  if split_size > 1:
84
  if len(latents) % split_size != 0:
85
- raise ValueError("Error: The batch size must be divisible by 'train.vae_bs_split")
 
 
86
  encode_bs = len(latents) // split_size
87
  image_batch = [
88
  _run_decoder(latent_batch, vae, is_video, vae_per_channel_normalize)
@@ -92,12 +113,16 @@ def vae_decode(
92
  else:
93
  images = _run_decoder(latents, vae, is_video, vae_per_channel_normalize)
94
 
95
- if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
 
 
96
  images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
97
  return images
98
 
99
 
100
- def _run_decoder(latents: Tensor, vae: AutoencoderKL, is_video: bool, vae_per_channel_normalize=False) -> Tensor:
 
 
101
  if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
102
  *_, fl, hl, wl = latents.shape
103
  temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
@@ -105,7 +130,13 @@ def _run_decoder(latents: Tensor, vae: AutoencoderKL, is_video: bool, vae_per_ch
105
  image = vae.decode(
106
  un_normalize_latents(latents, vae, vae_per_channel_normalize),
107
  return_dict=False,
108
- target_shape=(1, 3, fl * temporal_scale if is_video else 1, hl * spatial_scale, wl * spatial_scale),
 
 
 
 
 
 
109
  )[0]
110
  else:
111
  image = vae.decode(
@@ -120,14 +151,26 @@ def get_vae_size_scale_factor(vae: AutoencoderKL) -> float:
120
  spatial = vae.spatial_downscale_factor
121
  temporal = vae.temporal_downscale_factor
122
  else:
123
- down_blocks = len([block for block in vae.encoder.down_blocks if isinstance(block.downsample, Downsample3D)])
 
 
 
 
 
 
124
  spatial = vae.config.patch_size * 2**down_blocks
125
- temporal = vae.config.patch_size_t * 2 ** down_blocks if isinstance(vae, VideoAutoencoder) else 1
 
 
 
 
126
 
127
  return (temporal, spatial, spatial)
128
 
129
 
130
- def normalize_latents(latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False) -> Tensor:
 
 
131
  return (
132
  (latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1))
133
  / vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
@@ -136,10 +179,12 @@ def normalize_latents(latents: Tensor, vae: AutoencoderKL, vae_per_channel_norma
136
  )
137
 
138
 
139
- def un_normalize_latents(latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False) -> Tensor:
 
 
140
  return (
141
  latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
142
  + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
143
  if vae_per_channel_normalize
144
  else latents / vae.config.scaling_factor
145
- )
 
6
 
7
  from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
8
  from xora.models.autoencoders.video_autoencoder import Downsample3D, VideoAutoencoder
9
+
10
  try:
11
  import torch_xla.core.xla_model as xm
12
+ except ImportError:
13
+ xm = None
14
+
15
 
16
+ def vae_encode(
17
+ media_items: Tensor,
18
+ vae: AutoencoderKL,
19
+ split_size: int = 1,
20
+ vae_per_channel_normalize=False,
21
+ ) -> Tensor:
22
  """
23
  Encodes media items (images or videos) into latent representations using a specified VAE model.
24
  The function supports processing batches of images or video frames and can handle the processing
 
55
  if channels != 3:
56
  raise ValueError(f"Expects tensors with 3 channels, got {channels}.")
57
 
58
+ if is_video_shaped and not isinstance(
59
+ vae, (VideoAutoencoder, CausalVideoAutoencoder)
60
+ ):
61
  media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
62
  if split_size > 1:
63
  if len(media_items) % split_size != 0:
64
+ raise ValueError(
65
+ "Error: The batch size must be divisible by 'train.vae_bs_split"
66
+ )
67
  encode_bs = len(media_items) // split_size
68
  # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
69
  latents = []
 
78
  latents = vae.encode(media_items).latent_dist.sample()
79
 
80
  latents = normalize_latents(latents, vae, vae_per_channel_normalize)
81
+ if is_video_shaped and not isinstance(
82
+ vae, (VideoAutoencoder, CausalVideoAutoencoder)
83
+ ):
84
  latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
85
  return latents
86
 
87
 
88
  def vae_decode(
89
+ latents: Tensor,
90
+ vae: AutoencoderKL,
91
+ is_video: bool = True,
92
+ split_size: int = 1,
93
+ vae_per_channel_normalize=False,
94
  ) -> Tensor:
95
  is_video_shaped = latents.dim() == 5
96
  batch_size = latents.shape[0]
97
 
98
+ if is_video_shaped and not isinstance(
99
+ vae, (VideoAutoencoder, CausalVideoAutoencoder)
100
+ ):
101
  latents = rearrange(latents, "b c n h w -> (b n) c h w")
102
  if split_size > 1:
103
  if len(latents) % split_size != 0:
104
+ raise ValueError(
105
+ "Error: The batch size must be divisible by 'train.vae_bs_split"
106
+ )
107
  encode_bs = len(latents) // split_size
108
  image_batch = [
109
  _run_decoder(latent_batch, vae, is_video, vae_per_channel_normalize)
 
113
  else:
114
  images = _run_decoder(latents, vae, is_video, vae_per_channel_normalize)
115
 
116
+ if is_video_shaped and not isinstance(
117
+ vae, (VideoAutoencoder, CausalVideoAutoencoder)
118
+ ):
119
  images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
120
  return images
121
 
122
 
123
+ def _run_decoder(
124
+ latents: Tensor, vae: AutoencoderKL, is_video: bool, vae_per_channel_normalize=False
125
+ ) -> Tensor:
126
  if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
127
  *_, fl, hl, wl = latents.shape
128
  temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
 
130
  image = vae.decode(
131
  un_normalize_latents(latents, vae, vae_per_channel_normalize),
132
  return_dict=False,
133
+ target_shape=(
134
+ 1,
135
+ 3,
136
+ fl * temporal_scale if is_video else 1,
137
+ hl * spatial_scale,
138
+ wl * spatial_scale,
139
+ ),
140
  )[0]
141
  else:
142
  image = vae.decode(
 
151
  spatial = vae.spatial_downscale_factor
152
  temporal = vae.temporal_downscale_factor
153
  else:
154
+ down_blocks = len(
155
+ [
156
+ block
157
+ for block in vae.encoder.down_blocks
158
+ if isinstance(block.downsample, Downsample3D)
159
+ ]
160
+ )
161
  spatial = vae.config.patch_size * 2**down_blocks
162
+ temporal = (
163
+ vae.config.patch_size_t * 2**down_blocks
164
+ if isinstance(vae, VideoAutoencoder)
165
+ else 1
166
+ )
167
 
168
  return (temporal, spatial, spatial)
169
 
170
 
171
+ def normalize_latents(
172
+ latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
173
+ ) -> Tensor:
174
  return (
175
  (latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1))
176
  / vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
 
179
  )
180
 
181
 
182
+ def un_normalize_latents(
183
+ latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
184
+ ) -> Tensor:
185
  return (
186
  latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
187
  + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
188
  if vae_per_channel_normalize
189
  else latents / vae.config.scaling_factor
190
+ )
xora/models/autoencoders/video_autoencoder.py CHANGED
@@ -21,7 +21,12 @@ logger = logging.get_logger(__name__)
21
 
22
  class VideoAutoencoder(AutoencoderKLWrapper):
23
  @classmethod
24
- def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *args, **kwargs):
 
 
 
 
 
25
  config_local_path = pretrained_model_name_or_path / "config.json"
26
  config = cls.load_config(config_local_path, **kwargs)
27
  video_vae = cls.from_config(config)
@@ -31,29 +36,41 @@ class VideoAutoencoder(AutoencoderKLWrapper):
31
  ckpt_state_dict = torch.load(model_local_path)
32
  video_vae.load_state_dict(ckpt_state_dict)
33
 
34
- statistics_local_path = pretrained_model_name_or_path / "per_channel_statistics.json"
 
 
35
  if statistics_local_path.exists():
36
  with open(statistics_local_path, "r") as file:
37
  data = json.load(file)
38
  transposed_data = list(zip(*data["data"]))
39
- data_dict = {col: torch.tensor(vals) for col, vals in zip(data["columns"], transposed_data)}
 
 
 
40
  video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
41
  video_vae.register_buffer(
42
- "mean_of_means", data_dict.get("mean-of-means", torch.zeros_like(data_dict["std-of-means"]))
 
 
 
43
  )
44
 
45
  return video_vae
46
 
47
  @staticmethod
48
  def from_config(config):
49
- assert config["_class_name"] == "VideoAutoencoder", "config must have _class_name=VideoAutoencoder"
 
 
50
  if isinstance(config["dims"], list):
51
  config["dims"] = tuple(config["dims"])
52
 
53
  assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"
54
 
55
  double_z = config.get("double_z", True)
56
- latent_log_var = config.get("latent_log_var", "per_channel" if double_z else "none")
 
 
57
  use_quant_conv = config.get("use_quant_conv", True)
58
 
59
  if use_quant_conv and latent_log_var == "uniform":
@@ -96,8 +113,10 @@ class VideoAutoencoder(AutoencoderKLWrapper):
96
  return SimpleNamespace(
97
  _class_name="VideoAutoencoder",
98
  dims=self.dims,
99
- in_channels=self.encoder.conv_in.in_channels // (self.encoder.patch_size_t * self.encoder.patch_size**2),
100
- out_channels=self.decoder.conv_out.out_channels // (self.decoder.patch_size_t * self.decoder.patch_size**2),
 
 
101
  latent_channels=self.decoder.conv_in.in_channels,
102
  block_out_channels=[
103
  self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels
@@ -143,7 +162,9 @@ class VideoAutoencoder(AutoencoderKLWrapper):
143
  key = key.replace(k, v)
144
 
145
  if "norm" in key and key not in model_keys:
146
- logger.info(f"Removing key {key} from state_dict as it is not present in the model")
 
 
147
  continue
148
 
149
  converted_state_dict[key] = value
@@ -253,7 +274,11 @@ class Encoder(nn.Module):
253
 
254
  # out
255
  if norm_layer == "group_norm":
256
- self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
 
 
 
 
257
  elif norm_layer == "pixel_norm":
258
  self.conv_norm_out = PixelNorm()
259
  self.conv_act = nn.SiLU()
@@ -265,14 +290,23 @@ class Encoder(nn.Module):
265
  conv_out_channels += 1
266
  elif latent_log_var != "none":
267
  raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
268
- self.conv_out = make_conv_nd(dims, block_out_channels[-1], conv_out_channels, 3, padding=1)
 
 
269
 
270
  self.gradient_checkpointing = False
271
 
272
  @property
273
  def downscale_factor(self):
274
  return (
275
- 2 ** len([block for block in self.down_blocks if isinstance(block.downsample, Downsample3D)])
 
 
 
 
 
 
 
276
  * self.patch_size
277
  )
278
 
@@ -299,7 +333,9 @@ class Encoder(nn.Module):
299
  )
300
 
301
  for down_block in self.down_blocks:
302
- sample = checkpoint_fn(down_block)(sample, downsample_in_time=downsample_in_time)
 
 
303
 
304
  sample = checkpoint_fn(self.mid_block)(sample)
305
 
@@ -314,11 +350,15 @@ class Encoder(nn.Module):
314
 
315
  if num_dims == 4:
316
  # For shape (B, C, H, W)
317
- repeated_last_channel = last_channel.repeat(1, sample.shape[1] - 2, 1, 1)
 
 
318
  sample = torch.cat([sample, repeated_last_channel], dim=1)
319
  elif num_dims == 5:
320
  # For shape (B, C, F, H, W)
321
- repeated_last_channel = last_channel.repeat(1, sample.shape[1] - 2, 1, 1, 1)
 
 
322
  sample = torch.cat([sample, repeated_last_channel], dim=1)
323
  else:
324
  raise ValueError(f"Invalid input shape: {sample.shape}")
@@ -405,7 +445,8 @@ class Decoder(nn.Module):
405
  num_layers=self.layers_per_block + 1,
406
  in_channels=prev_output_channel,
407
  out_channels=output_channel,
408
- add_upsample=not is_final_block and 2 ** (len(block_out_channels) - i - 1) > patch_size,
 
409
  resnet_eps=1e-6,
410
  resnet_groups=norm_num_groups,
411
  norm_layer=norm_layer,
@@ -413,12 +454,16 @@ class Decoder(nn.Module):
413
  self.up_blocks.append(up_block)
414
 
415
  if norm_layer == "group_norm":
416
- self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
 
 
417
  elif norm_layer == "pixel_norm":
418
  self.conv_norm_out = PixelNorm()
419
 
420
  self.conv_act = nn.SiLU()
421
- self.conv_out = make_conv_nd(dims, block_out_channels[0], out_channels, 3, padding=1)
 
 
422
 
423
  self.gradient_checkpointing = False
424
 
@@ -494,15 +539,24 @@ class DownEncoderBlock3D(nn.Module):
494
  self.res_blocks = nn.ModuleList(res_blocks)
495
 
496
  if add_downsample:
497
- self.downsample = Downsample3D(dims, out_channels, out_channels=out_channels, padding=downsample_padding)
 
 
 
 
 
498
  else:
499
  self.downsample = Identity()
500
 
501
- def forward(self, hidden_states: torch.FloatTensor, downsample_in_time) -> torch.FloatTensor:
 
 
502
  for resnet in self.res_blocks:
503
  hidden_states = resnet(hidden_states)
504
 
505
- hidden_states = self.downsample(hidden_states, downsample_in_time=downsample_in_time)
 
 
506
 
507
  return hidden_states
508
 
@@ -536,7 +590,9 @@ class UNetMidBlock3D(nn.Module):
536
  norm_layer: str = "group_norm",
537
  ):
538
  super().__init__()
539
- resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
 
 
540
 
541
  self.res_blocks = nn.ModuleList(
542
  [
@@ -595,13 +651,17 @@ class UpDecoderBlock3D(nn.Module):
595
  self.res_blocks = nn.ModuleList(res_blocks)
596
 
597
  if add_upsample:
598
- self.upsample = Upsample3D(dims=dims, channels=out_channels, out_channels=out_channels)
 
 
599
  else:
600
  self.upsample = Identity()
601
 
602
  self.resolution_idx = resolution_idx
603
 
604
- def forward(self, hidden_states: torch.FloatTensor, upsample_in_time=True) -> torch.FloatTensor:
 
 
605
  for resnet in self.res_blocks:
606
  hidden_states = resnet(hidden_states)
607
 
@@ -641,25 +701,35 @@ class ResnetBlock3D(nn.Module):
641
  self.use_conv_shortcut = conv_shortcut
642
 
643
  if norm_layer == "group_norm":
644
- self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
 
 
645
  elif norm_layer == "pixel_norm":
646
  self.norm1 = PixelNorm()
647
 
648
  self.non_linearity = nn.SiLU()
649
 
650
- self.conv1 = make_conv_nd(dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1)
 
 
651
 
652
  if norm_layer == "group_norm":
653
- self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
 
 
654
  elif norm_layer == "pixel_norm":
655
  self.norm2 = PixelNorm()
656
 
657
  self.dropout = torch.nn.Dropout(dropout)
658
 
659
- self.conv2 = make_conv_nd(dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1)
 
 
660
 
661
  self.conv_shortcut = (
662
- make_linear_nd(dims=dims, in_channels=in_channels, out_channels=out_channels)
 
 
663
  if in_channels != out_channels
664
  else nn.Identity()
665
  )
@@ -692,7 +762,14 @@ class ResnetBlock3D(nn.Module):
692
 
693
 
694
  class Downsample3D(nn.Module):
695
- def __init__(self, dims, in_channels: int, out_channels: int, kernel_size: int = 3, padding: int = 1):
 
 
 
 
 
 
 
696
  super().__init__()
697
  stride: int = 2
698
  self.padding = padding
@@ -735,18 +812,24 @@ class Upsample3D(nn.Module):
735
  self.dims = dims
736
  self.channels = channels
737
  self.out_channels = out_channels or channels
738
- self.conv = make_conv_nd(dims, channels, out_channels, kernel_size=3, padding=1, bias=True)
 
 
739
 
740
  def forward(self, x, upsample_in_time):
741
  if self.dims == 2:
742
- x = functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest")
 
 
743
  else:
744
  time_scale_factor = 2 if upsample_in_time else 1
745
  # print("before:", x.shape)
746
  b, c, d, h, w = x.shape
747
  x = rearrange(x, "b c d h w -> (b d) c h w")
748
  # height and width interpolate
749
- x = functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest")
 
 
750
  _, _, h, w = x.shape
751
 
752
  if not upsample_in_time and self.dims == (2, 1):
@@ -760,7 +843,9 @@ class Upsample3D(nn.Module):
760
  new_d = x.shape[-1] * time_scale_factor
761
  x = functional.interpolate(x, (1, new_d), mode="nearest")
762
  # (b h w) c 1 new_d
763
- x = rearrange(x, "(b h w) c 1 new_d -> b c new_d h w", b=b, h=h, w=w, new_d=new_d)
 
 
764
  # b c d h w
765
 
766
  # x = functional.interpolate(
@@ -775,13 +860,25 @@ def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
775
  if patch_size_hw == 1 and patch_size_t == 1:
776
  return x
777
  if x.dim() == 4:
778
- x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw)
 
 
779
  elif x.dim() == 5:
780
- x = rearrange(x, "b c (f p) (h q) (w r) -> b (c p r q) f h w", p=patch_size_t, q=patch_size_hw, r=patch_size_hw)
 
 
 
 
 
 
781
  else:
782
  raise ValueError(f"Invalid input shape: {x.shape}")
783
 
784
- if (x.dim() == 5) and (patch_size_hw > patch_size_t) and (patch_size_t > 1 or add_channel_padding):
 
 
 
 
785
  channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1]
786
  padding_zeros = torch.zeros(
787
  x.shape[0],
@@ -801,14 +898,26 @@ def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
801
  if patch_size_hw == 1 and patch_size_t == 1:
802
  return x
803
 
804
- if (x.dim() == 5) and (patch_size_hw > patch_size_t) and (patch_size_t > 1 or add_channel_padding):
 
 
 
 
805
  channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw))
806
  x = x[:, :channels_to_keep, :, :, :]
807
 
808
  if x.dim() == 4:
809
- x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw)
 
 
810
  elif x.dim() == 5:
811
- x = rearrange(x, "b (c p r q) f h w -> b c (f p) (h q) (w r)", p=patch_size_t, q=patch_size_hw, r=patch_size_hw)
 
 
 
 
 
 
812
 
813
  return x
814
 
@@ -818,11 +927,19 @@ def create_video_autoencoder_config(
818
  ):
819
  config = {
820
  "_class_name": "VideoAutoencoder",
821
- "dims": (2, 1), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
 
 
 
822
  "in_channels": 3, # Number of input color channels (e.g., RGB)
823
  "out_channels": 3, # Number of output color channels
824
  "latent_channels": latent_channels, # Number of channels in the latent space representation
825
- "block_out_channels": [128, 256, 512, 512], # Number of output channels of each encoder / decoder inner block
 
 
 
 
 
826
  "patch_size": 1,
827
  }
828
 
@@ -834,11 +951,15 @@ def create_video_autoencoder_pathify4x4x4_config(
834
  ):
835
  config = {
836
  "_class_name": "VideoAutoencoder",
837
- "dims": (2, 1), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
 
 
 
838
  "in_channels": 3, # Number of input color channels (e.g., RGB)
839
  "out_channels": 3, # Number of output color channels
840
  "latent_channels": latent_channels, # Number of channels in the latent space representation
841
- "block_out_channels": [512] * 4, # Number of output channels of each encoder / decoder inner block
 
842
  "patch_size": 4,
843
  "latent_log_var": "uniform",
844
  }
@@ -855,7 +976,8 @@ def create_video_autoencoder_pathify4x4_config(
855
  "in_channels": 3, # Number of input color channels (e.g., RGB)
856
  "out_channels": 3, # Number of output color channels
857
  "latent_channels": latent_channels, # Number of channels in the latent space representation
858
- "block_out_channels": [512] * 4, # Number of output channels of each encoder / decoder inner block
 
859
  "patch_size": 4,
860
  "norm_layer": "pixel_norm",
861
  }
@@ -894,7 +1016,9 @@ def demo_video_autoencoder_forward_backward():
894
  latent = video_autoencoder.encode(input_videos).latent_dist.mode()
895
  print(f"input shape={input_videos.shape}")
896
  print(f"latent shape={latent.shape}")
897
- reconstructed_videos = video_autoencoder.decode(latent, target_shape=input_videos.shape).sample
 
 
898
 
899
  print(f"reconstructed shape={reconstructed_videos.shape}")
900
 
 
21
 
22
  class VideoAutoencoder(AutoencoderKLWrapper):
23
  @classmethod
24
+ def from_pretrained(
25
+ cls,
26
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
27
+ *args,
28
+ **kwargs,
29
+ ):
30
  config_local_path = pretrained_model_name_or_path / "config.json"
31
  config = cls.load_config(config_local_path, **kwargs)
32
  video_vae = cls.from_config(config)
 
36
  ckpt_state_dict = torch.load(model_local_path)
37
  video_vae.load_state_dict(ckpt_state_dict)
38
 
39
+ statistics_local_path = (
40
+ pretrained_model_name_or_path / "per_channel_statistics.json"
41
+ )
42
  if statistics_local_path.exists():
43
  with open(statistics_local_path, "r") as file:
44
  data = json.load(file)
45
  transposed_data = list(zip(*data["data"]))
46
+ data_dict = {
47
+ col: torch.tensor(vals)
48
+ for col, vals in zip(data["columns"], transposed_data)
49
+ }
50
  video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
51
  video_vae.register_buffer(
52
+ "mean_of_means",
53
+ data_dict.get(
54
+ "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
55
+ ),
56
  )
57
 
58
  return video_vae
59
 
60
  @staticmethod
61
  def from_config(config):
62
+ assert (
63
+ config["_class_name"] == "VideoAutoencoder"
64
+ ), "config must have _class_name=VideoAutoencoder"
65
  if isinstance(config["dims"], list):
66
  config["dims"] = tuple(config["dims"])
67
 
68
  assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"
69
 
70
  double_z = config.get("double_z", True)
71
+ latent_log_var = config.get(
72
+ "latent_log_var", "per_channel" if double_z else "none"
73
+ )
74
  use_quant_conv = config.get("use_quant_conv", True)
75
 
76
  if use_quant_conv and latent_log_var == "uniform":
 
113
  return SimpleNamespace(
114
  _class_name="VideoAutoencoder",
115
  dims=self.dims,
116
+ in_channels=self.encoder.conv_in.in_channels
117
+ // (self.encoder.patch_size_t * self.encoder.patch_size**2),
118
+ out_channels=self.decoder.conv_out.out_channels
119
+ // (self.decoder.patch_size_t * self.decoder.patch_size**2),
120
  latent_channels=self.decoder.conv_in.in_channels,
121
  block_out_channels=[
122
  self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels
 
162
  key = key.replace(k, v)
163
 
164
  if "norm" in key and key not in model_keys:
165
+ logger.info(
166
+ f"Removing key {key} from state_dict as it is not present in the model"
167
+ )
168
  continue
169
 
170
  converted_state_dict[key] = value
 
274
 
275
  # out
276
  if norm_layer == "group_norm":
277
+ self.conv_norm_out = nn.GroupNorm(
278
+ num_channels=block_out_channels[-1],
279
+ num_groups=norm_num_groups,
280
+ eps=1e-6,
281
+ )
282
  elif norm_layer == "pixel_norm":
283
  self.conv_norm_out = PixelNorm()
284
  self.conv_act = nn.SiLU()
 
290
  conv_out_channels += 1
291
  elif latent_log_var != "none":
292
  raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
293
+ self.conv_out = make_conv_nd(
294
+ dims, block_out_channels[-1], conv_out_channels, 3, padding=1
295
+ )
296
 
297
  self.gradient_checkpointing = False
298
 
299
  @property
300
  def downscale_factor(self):
301
  return (
302
+ 2
303
+ ** len(
304
+ [
305
+ block
306
+ for block in self.down_blocks
307
+ if isinstance(block.downsample, Downsample3D)
308
+ ]
309
+ )
310
  * self.patch_size
311
  )
312
 
 
333
  )
334
 
335
  for down_block in self.down_blocks:
336
+ sample = checkpoint_fn(down_block)(
337
+ sample, downsample_in_time=downsample_in_time
338
+ )
339
 
340
  sample = checkpoint_fn(self.mid_block)(sample)
341
 
 
350
 
351
  if num_dims == 4:
352
  # For shape (B, C, H, W)
353
+ repeated_last_channel = last_channel.repeat(
354
+ 1, sample.shape[1] - 2, 1, 1
355
+ )
356
  sample = torch.cat([sample, repeated_last_channel], dim=1)
357
  elif num_dims == 5:
358
  # For shape (B, C, F, H, W)
359
+ repeated_last_channel = last_channel.repeat(
360
+ 1, sample.shape[1] - 2, 1, 1, 1
361
+ )
362
  sample = torch.cat([sample, repeated_last_channel], dim=1)
363
  else:
364
  raise ValueError(f"Invalid input shape: {sample.shape}")
 
445
  num_layers=self.layers_per_block + 1,
446
  in_channels=prev_output_channel,
447
  out_channels=output_channel,
448
+ add_upsample=not is_final_block
449
+ and 2 ** (len(block_out_channels) - i - 1) > patch_size,
450
  resnet_eps=1e-6,
451
  resnet_groups=norm_num_groups,
452
  norm_layer=norm_layer,
 
454
  self.up_blocks.append(up_block)
455
 
456
  if norm_layer == "group_norm":
457
+ self.conv_norm_out = nn.GroupNorm(
458
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
459
+ )
460
  elif norm_layer == "pixel_norm":
461
  self.conv_norm_out = PixelNorm()
462
 
463
  self.conv_act = nn.SiLU()
464
+ self.conv_out = make_conv_nd(
465
+ dims, block_out_channels[0], out_channels, 3, padding=1
466
+ )
467
 
468
  self.gradient_checkpointing = False
469
 
 
539
  self.res_blocks = nn.ModuleList(res_blocks)
540
 
541
  if add_downsample:
542
+ self.downsample = Downsample3D(
543
+ dims,
544
+ out_channels,
545
+ out_channels=out_channels,
546
+ padding=downsample_padding,
547
+ )
548
  else:
549
  self.downsample = Identity()
550
 
551
+ def forward(
552
+ self, hidden_states: torch.FloatTensor, downsample_in_time
553
+ ) -> torch.FloatTensor:
554
  for resnet in self.res_blocks:
555
  hidden_states = resnet(hidden_states)
556
 
557
+ hidden_states = self.downsample(
558
+ hidden_states, downsample_in_time=downsample_in_time
559
+ )
560
 
561
  return hidden_states
562
 
 
590
  norm_layer: str = "group_norm",
591
  ):
592
  super().__init__()
593
+ resnet_groups = (
594
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
595
+ )
596
 
597
  self.res_blocks = nn.ModuleList(
598
  [
 
651
  self.res_blocks = nn.ModuleList(res_blocks)
652
 
653
  if add_upsample:
654
+ self.upsample = Upsample3D(
655
+ dims=dims, channels=out_channels, out_channels=out_channels
656
+ )
657
  else:
658
  self.upsample = Identity()
659
 
660
  self.resolution_idx = resolution_idx
661
 
662
+ def forward(
663
+ self, hidden_states: torch.FloatTensor, upsample_in_time=True
664
+ ) -> torch.FloatTensor:
665
  for resnet in self.res_blocks:
666
  hidden_states = resnet(hidden_states)
667
 
 
701
  self.use_conv_shortcut = conv_shortcut
702
 
703
  if norm_layer == "group_norm":
704
+ self.norm1 = torch.nn.GroupNorm(
705
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
706
+ )
707
  elif norm_layer == "pixel_norm":
708
  self.norm1 = PixelNorm()
709
 
710
  self.non_linearity = nn.SiLU()
711
 
712
+ self.conv1 = make_conv_nd(
713
+ dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1
714
+ )
715
 
716
  if norm_layer == "group_norm":
717
+ self.norm2 = torch.nn.GroupNorm(
718
+ num_groups=groups, num_channels=out_channels, eps=eps, affine=True
719
+ )
720
  elif norm_layer == "pixel_norm":
721
  self.norm2 = PixelNorm()
722
 
723
  self.dropout = torch.nn.Dropout(dropout)
724
 
725
+ self.conv2 = make_conv_nd(
726
+ dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1
727
+ )
728
 
729
  self.conv_shortcut = (
730
+ make_linear_nd(
731
+ dims=dims, in_channels=in_channels, out_channels=out_channels
732
+ )
733
  if in_channels != out_channels
734
  else nn.Identity()
735
  )
 
762
 
763
 
764
  class Downsample3D(nn.Module):
765
+ def __init__(
766
+ self,
767
+ dims,
768
+ in_channels: int,
769
+ out_channels: int,
770
+ kernel_size: int = 3,
771
+ padding: int = 1,
772
+ ):
773
  super().__init__()
774
  stride: int = 2
775
  self.padding = padding
 
812
  self.dims = dims
813
  self.channels = channels
814
  self.out_channels = out_channels or channels
815
+ self.conv = make_conv_nd(
816
+ dims, channels, out_channels, kernel_size=3, padding=1, bias=True
817
+ )
818
 
819
  def forward(self, x, upsample_in_time):
820
  if self.dims == 2:
821
+ x = functional.interpolate(
822
+ x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest"
823
+ )
824
  else:
825
  time_scale_factor = 2 if upsample_in_time else 1
826
  # print("before:", x.shape)
827
  b, c, d, h, w = x.shape
828
  x = rearrange(x, "b c d h w -> (b d) c h w")
829
  # height and width interpolate
830
+ x = functional.interpolate(
831
+ x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest"
832
+ )
833
  _, _, h, w = x.shape
834
 
835
  if not upsample_in_time and self.dims == (2, 1):
 
843
  new_d = x.shape[-1] * time_scale_factor
844
  x = functional.interpolate(x, (1, new_d), mode="nearest")
845
  # (b h w) c 1 new_d
846
+ x = rearrange(
847
+ x, "(b h w) c 1 new_d -> b c new_d h w", b=b, h=h, w=w, new_d=new_d
848
+ )
849
  # b c d h w
850
 
851
  # x = functional.interpolate(
 
860
  if patch_size_hw == 1 and patch_size_t == 1:
861
  return x
862
  if x.dim() == 4:
863
+ x = rearrange(
864
+ x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
865
+ )
866
  elif x.dim() == 5:
867
+ x = rearrange(
868
+ x,
869
+ "b c (f p) (h q) (w r) -> b (c p r q) f h w",
870
+ p=patch_size_t,
871
+ q=patch_size_hw,
872
+ r=patch_size_hw,
873
+ )
874
  else:
875
  raise ValueError(f"Invalid input shape: {x.shape}")
876
 
877
+ if (
878
+ (x.dim() == 5)
879
+ and (patch_size_hw > patch_size_t)
880
+ and (patch_size_t > 1 or add_channel_padding)
881
+ ):
882
  channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1]
883
  padding_zeros = torch.zeros(
884
  x.shape[0],
 
898
  if patch_size_hw == 1 and patch_size_t == 1:
899
  return x
900
 
901
+ if (
902
+ (x.dim() == 5)
903
+ and (patch_size_hw > patch_size_t)
904
+ and (patch_size_t > 1 or add_channel_padding)
905
+ ):
906
  channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw))
907
  x = x[:, :channels_to_keep, :, :, :]
908
 
909
  if x.dim() == 4:
910
+ x = rearrange(
911
+ x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
912
+ )
913
  elif x.dim() == 5:
914
+ x = rearrange(
915
+ x,
916
+ "b (c p r q) f h w -> b c (f p) (h q) (w r)",
917
+ p=patch_size_t,
918
+ q=patch_size_hw,
919
+ r=patch_size_hw,
920
+ )
921
 
922
  return x
923
 
 
927
  ):
928
  config = {
929
  "_class_name": "VideoAutoencoder",
930
+ "dims": (
931
+ 2,
932
+ 1,
933
+ ), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
934
  "in_channels": 3, # Number of input color channels (e.g., RGB)
935
  "out_channels": 3, # Number of output color channels
936
  "latent_channels": latent_channels, # Number of channels in the latent space representation
937
+ "block_out_channels": [
938
+ 128,
939
+ 256,
940
+ 512,
941
+ 512,
942
+ ], # Number of output channels of each encoder / decoder inner block
943
  "patch_size": 1,
944
  }
945
 
 
951
  ):
952
  config = {
953
  "_class_name": "VideoAutoencoder",
954
+ "dims": (
955
+ 2,
956
+ 1,
957
+ ), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
958
  "in_channels": 3, # Number of input color channels (e.g., RGB)
959
  "out_channels": 3, # Number of output color channels
960
  "latent_channels": latent_channels, # Number of channels in the latent space representation
961
+ "block_out_channels": [512]
962
+ * 4, # Number of output channels of each encoder / decoder inner block
963
  "patch_size": 4,
964
  "latent_log_var": "uniform",
965
  }
 
976
  "in_channels": 3, # Number of input color channels (e.g., RGB)
977
  "out_channels": 3, # Number of output color channels
978
  "latent_channels": latent_channels, # Number of channels in the latent space representation
979
+ "block_out_channels": [512]
980
+ * 4, # Number of output channels of each encoder / decoder inner block
981
  "patch_size": 4,
982
  "norm_layer": "pixel_norm",
983
  }
 
1016
  latent = video_autoencoder.encode(input_videos).latent_dist.mode()
1017
  print(f"input shape={input_videos.shape}")
1018
  print(f"latent shape={latent.shape}")
1019
+ reconstructed_videos = video_autoencoder.decode(
1020
+ latent, target_shape=input_videos.shape
1021
+ ).sample
1022
 
1023
  print(f"reconstructed shape={reconstructed_videos.shape}")
1024
 
xora/models/transformers/attention.py CHANGED
@@ -106,11 +106,15 @@ class BasicTransformerBlock(nn.Module):
106
  assert standardization_norm in ["layer_norm", "rms_norm"]
107
  assert adaptive_norm in ["single_scale_shift", "single_scale", "none"]
108
 
109
- make_norm_layer = nn.LayerNorm if standardization_norm == "layer_norm" else RMSNorm
 
 
110
 
111
  # Define 3 blocks. Each block has its own normalization layer.
112
  # 1. Self-Attn
113
- self.norm1 = make_norm_layer(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
 
 
114
 
115
  self.attn1 = Attention(
116
  query_dim=dim,
@@ -130,7 +134,9 @@ class BasicTransformerBlock(nn.Module):
130
  if cross_attention_dim is not None or double_self_attention:
131
  self.attn2 = Attention(
132
  query_dim=dim,
133
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
 
 
134
  heads=num_attention_heads,
135
  dim_head=attention_head_dim,
136
  dropout=dropout,
@@ -143,7 +149,9 @@ class BasicTransformerBlock(nn.Module):
143
  ) # is self-attn if encoder_hidden_states is none
144
 
145
  if adaptive_norm == "none":
146
- self.attn2_norm = make_norm_layer(dim, norm_eps, norm_elementwise_affine)
 
 
147
  else:
148
  self.attn2 = None
149
  self.attn2_norm = None
@@ -163,7 +171,9 @@ class BasicTransformerBlock(nn.Module):
163
  # 5. Scale-shift for PixArt-Alpha.
164
  if adaptive_norm != "none":
165
  num_ada_params = 4 if adaptive_norm == "single_scale" else 6
166
- self.scale_shift_table = nn.Parameter(torch.randn(num_ada_params, dim) / dim**0.5)
 
 
167
 
168
  # let chunk size default to None
169
  self._chunk_size = None
@@ -198,7 +208,9 @@ class BasicTransformerBlock(nn.Module):
198
  ) -> torch.FloatTensor:
199
  if cross_attention_kwargs is not None:
200
  if cross_attention_kwargs.get("scale", None) is not None:
201
- logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
 
 
202
 
203
  # Notice that normalization is always applied before the real computation in the following blocks.
204
  # 0. Self-Attention
@@ -214,7 +226,9 @@ class BasicTransformerBlock(nn.Module):
214
  batch_size, timestep.shape[1], num_ada_params, -1
215
  )
216
  if self.adaptive_norm == "single_scale_shift":
217
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
 
 
218
  norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
219
  else:
220
  scale_msa, gate_msa, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
@@ -224,15 +238,21 @@ class BasicTransformerBlock(nn.Module):
224
  else:
225
  raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}")
226
 
227
- norm_hidden_states = norm_hidden_states.squeeze(1) # TODO: Check if this is needed
 
 
228
 
229
  # 1. Prepare GLIGEN inputs
230
- cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
 
 
231
 
232
  attn_output = self.attn1(
233
  norm_hidden_states,
234
  freqs_cis=freqs_cis,
235
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
 
 
236
  attention_mask=attention_mask,
237
  **cross_attention_kwargs,
238
  )
@@ -271,7 +291,9 @@ class BasicTransformerBlock(nn.Module):
271
 
272
  if self._chunk_size is not None:
273
  # "feed_forward_chunk_size" can be used to save memory
274
- ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
 
 
275
  else:
276
  ff_output = self.ff(norm_hidden_states)
277
  if gate_mlp is not None:
@@ -371,7 +393,9 @@ class Attention(nn.Module):
371
  self.query_dim = query_dim
372
  self.use_bias = bias
373
  self.is_cross_attention = cross_attention_dim is not None
374
- self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
 
 
375
  self.upcast_attention = upcast_attention
376
  self.upcast_softmax = upcast_softmax
377
  self.rescale_output_factor = rescale_output_factor
@@ -416,12 +440,16 @@ class Attention(nn.Module):
416
  )
417
 
418
  if norm_num_groups is not None:
419
- self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
 
 
420
  else:
421
  self.group_norm = None
422
 
423
  if spatial_norm_dim is not None:
424
- self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
 
 
425
  else:
426
  self.spatial_norm = None
427
 
@@ -441,7 +469,10 @@ class Attention(nn.Module):
441
  norm_cross_num_channels = self.cross_attention_dim
442
 
443
  self.norm_cross = nn.GroupNorm(
444
- num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
 
 
 
445
  )
446
  else:
447
  raise ValueError(
@@ -499,12 +530,16 @@ class Attention(nn.Module):
499
  and isinstance(self.processor, torch.nn.Module)
500
  and not isinstance(processor, torch.nn.Module)
501
  ):
502
- logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
 
 
503
  self._modules.pop("processor")
504
 
505
  self.processor = processor
506
 
507
- def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": # noqa: F821
 
 
508
  r"""
509
  Get the attention processor in use.
510
 
@@ -542,12 +577,18 @@ class Attention(nn.Module):
542
 
543
  # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
544
  non_lora_processor_cls_name = self.processor.__class__.__name__
545
- lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
 
 
546
 
547
  hidden_size = self.inner_dim
548
 
549
  # now create a LoRA attention processor from the LoRA layers
550
- if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
 
 
 
 
551
  kwargs = {
552
  "cross_attention_dim": self.cross_attention_dim,
553
  "rank": self.to_q.lora_layer.rank,
@@ -569,7 +610,9 @@ class Attention(nn.Module):
569
  lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
570
  lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
571
  lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
572
- lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
 
 
573
  elif lora_processor_cls == LoRAAttnAddedKVProcessor:
574
  lora_processor = lora_processor_cls(
575
  hidden_size,
@@ -580,12 +623,18 @@ class Attention(nn.Module):
580
  lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
581
  lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
582
  lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
583
- lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
 
 
584
 
585
  # only save if used
586
  if self.add_k_proj.lora_layer is not None:
587
- lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
588
- lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
 
 
 
 
589
  else:
590
  lora_processor.add_k_proj_lora = None
591
  lora_processor.add_v_proj_lora = None
@@ -622,14 +671,20 @@ class Attention(nn.Module):
622
  # here we simply pass along all tensors to the selected processor class
623
  # For standard processors that are defined here, `**cross_attention_kwargs` is empty
624
 
625
- attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
626
- unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
 
 
 
 
627
  if len(unused_kwargs) > 0:
628
  logger.warning(
629
  f"cross_attention_kwargs {unused_kwargs} are not expected by"
630
  f" {self.processor.__class__.__name__} and will be ignored."
631
  )
632
- cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
 
 
633
 
634
  return self.processor(
635
  self,
@@ -654,7 +709,9 @@ class Attention(nn.Module):
654
  head_size = self.heads
655
  batch_size, seq_len, dim = tensor.shape
656
  tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
657
- tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
 
 
658
  return tensor
659
 
660
  def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
@@ -677,16 +734,23 @@ class Attention(nn.Module):
677
  extra_dim = 1
678
  else:
679
  batch_size, extra_dim, seq_len, dim = tensor.shape
680
- tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
 
 
681
  tensor = tensor.permute(0, 2, 1, 3)
682
 
683
  if out_dim == 3:
684
- tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
 
 
685
 
686
  return tensor
687
 
688
  def get_attention_scores(
689
- self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
 
 
 
690
  ) -> torch.Tensor:
691
  r"""
692
  Compute the attention scores.
@@ -706,7 +770,11 @@ class Attention(nn.Module):
706
 
707
  if attention_mask is None:
708
  baddbmm_input = torch.empty(
709
- query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
 
 
 
 
710
  )
711
  beta = 0
712
  else:
@@ -733,7 +801,11 @@ class Attention(nn.Module):
733
  return attention_probs
734
 
735
  def prepare_attention_mask(
736
- self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
 
 
 
 
737
  ) -> torch.Tensor:
738
  r"""
739
  Prepare the attention mask for the attention computation.
@@ -760,8 +832,16 @@ class Attention(nn.Module):
760
  if attention_mask.device.type == "mps":
761
  # HACK: MPS: Does not support padding by greater than dimension of input tensor.
762
  # Instead, we can manually construct the padding tensor.
763
- padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
764
- padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
 
 
 
 
 
 
 
 
765
  attention_mask = torch.cat([attention_mask, padding], dim=2)
766
  else:
767
  # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
@@ -779,7 +859,9 @@ class Attention(nn.Module):
779
 
780
  return attention_mask
781
 
782
- def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
 
 
783
  r"""
784
  Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
785
  `Attention` class.
@@ -790,7 +872,9 @@ class Attention(nn.Module):
790
  Returns:
791
  `torch.Tensor`: The normalized encoder hidden states.
792
  """
793
- assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
 
 
794
 
795
  if isinstance(self.norm_cross, nn.LayerNorm):
796
  encoder_hidden_states = self.norm_cross(encoder_hidden_states)
@@ -857,27 +941,39 @@ class AttnProcessor2_0:
857
 
858
  if input_ndim == 4:
859
  batch_size, channel, height, width = hidden_states.shape
860
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
 
 
861
 
862
  batch_size, sequence_length, _ = (
863
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
 
 
864
  )
865
 
866
  if (attention_mask is not None) and (not attn.use_tpu_flash_attention):
867
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
 
 
868
  # scaled_dot_product_attention expects attention_mask shape to be
869
  # (batch, heads, source_length, target_length)
870
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
 
 
871
 
872
  if attn.group_norm is not None:
873
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
 
874
 
875
  query = attn.to_q(hidden_states)
876
  query = attn.q_norm(query)
877
 
878
  if encoder_hidden_states is not None:
879
  if attn.norm_cross:
880
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
 
881
  key = attn.to_k(encoder_hidden_states)
882
  key = attn.k_norm(key)
883
  else: # if no context provided do self-attention
@@ -901,10 +997,14 @@ class AttnProcessor2_0:
901
 
902
  if attn.use_tpu_flash_attention: # use tpu attention offload 'flash attention'
903
  q_segment_indexes = None
904
- if attention_mask is not None: # if mask is required need to tune both segmenIds fields
 
 
905
  # attention_mask = torch.squeeze(attention_mask).to(torch.float32)
906
  attention_mask = attention_mask.to(torch.float32)
907
- q_segment_indexes = torch.ones(batch_size, query.shape[2], device=query.device, dtype=torch.float32)
 
 
908
  assert (
909
  attention_mask.shape[1] == key.shape[2]
910
  ), f"ERROR: KEY SHAPE must be same as attention mask [{key.shape[2]}, {attention_mask.shape[1]}]"
@@ -927,10 +1027,17 @@ class AttnProcessor2_0:
927
  )
928
  else:
929
  hidden_states = F.scaled_dot_product_attention(
930
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
 
 
 
 
 
931
  )
932
 
933
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
 
 
934
  hidden_states = hidden_states.to(query.dtype)
935
 
936
  # linear proj
@@ -939,7 +1046,9 @@ class AttnProcessor2_0:
939
  hidden_states = attn.to_out[1](hidden_states)
940
 
941
  if input_ndim == 4:
942
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
 
 
943
 
944
  if attn.residual_connection:
945
  hidden_states = hidden_states + residual
@@ -977,22 +1086,32 @@ class AttnProcessor:
977
 
978
  if input_ndim == 4:
979
  batch_size, channel, height, width = hidden_states.shape
980
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
 
 
981
 
982
  batch_size, sequence_length, _ = (
983
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
 
 
 
 
 
984
  )
985
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
986
 
987
  if attn.group_norm is not None:
988
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
 
989
 
990
  query = attn.to_q(hidden_states)
991
 
992
  if encoder_hidden_states is None:
993
  encoder_hidden_states = hidden_states
994
  elif attn.norm_cross:
995
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
 
996
 
997
  key = attn.to_k(encoder_hidden_states)
998
  value = attn.to_v(encoder_hidden_states)
@@ -1014,7 +1133,9 @@ class AttnProcessor:
1014
  hidden_states = attn.to_out[1](hidden_states)
1015
 
1016
  if input_ndim == 4:
1017
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
 
 
1018
 
1019
  if attn.residual_connection:
1020
  hidden_states = hidden_states + residual
 
106
  assert standardization_norm in ["layer_norm", "rms_norm"]
107
  assert adaptive_norm in ["single_scale_shift", "single_scale", "none"]
108
 
109
+ make_norm_layer = (
110
+ nn.LayerNorm if standardization_norm == "layer_norm" else RMSNorm
111
+ )
112
 
113
  # Define 3 blocks. Each block has its own normalization layer.
114
  # 1. Self-Attn
115
+ self.norm1 = make_norm_layer(
116
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
117
+ )
118
 
119
  self.attn1 = Attention(
120
  query_dim=dim,
 
134
  if cross_attention_dim is not None or double_self_attention:
135
  self.attn2 = Attention(
136
  query_dim=dim,
137
+ cross_attention_dim=(
138
+ cross_attention_dim if not double_self_attention else None
139
+ ),
140
  heads=num_attention_heads,
141
  dim_head=attention_head_dim,
142
  dropout=dropout,
 
149
  ) # is self-attn if encoder_hidden_states is none
150
 
151
  if adaptive_norm == "none":
152
+ self.attn2_norm = make_norm_layer(
153
+ dim, norm_eps, norm_elementwise_affine
154
+ )
155
  else:
156
  self.attn2 = None
157
  self.attn2_norm = None
 
171
  # 5. Scale-shift for PixArt-Alpha.
172
  if adaptive_norm != "none":
173
  num_ada_params = 4 if adaptive_norm == "single_scale" else 6
174
+ self.scale_shift_table = nn.Parameter(
175
+ torch.randn(num_ada_params, dim) / dim**0.5
176
+ )
177
 
178
  # let chunk size default to None
179
  self._chunk_size = None
 
208
  ) -> torch.FloatTensor:
209
  if cross_attention_kwargs is not None:
210
  if cross_attention_kwargs.get("scale", None) is not None:
211
+ logger.warning(
212
+ "Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored."
213
+ )
214
 
215
  # Notice that normalization is always applied before the real computation in the following blocks.
216
  # 0. Self-Attention
 
226
  batch_size, timestep.shape[1], num_ada_params, -1
227
  )
228
  if self.adaptive_norm == "single_scale_shift":
229
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
230
+ ada_values.unbind(dim=2)
231
+ )
232
  norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
233
  else:
234
  scale_msa, gate_msa, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
 
238
  else:
239
  raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}")
240
 
241
+ norm_hidden_states = norm_hidden_states.squeeze(
242
+ 1
243
+ ) # TODO: Check if this is needed
244
 
245
  # 1. Prepare GLIGEN inputs
246
+ cross_attention_kwargs = (
247
+ cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
248
+ )
249
 
250
  attn_output = self.attn1(
251
  norm_hidden_states,
252
  freqs_cis=freqs_cis,
253
+ encoder_hidden_states=(
254
+ encoder_hidden_states if self.only_cross_attention else None
255
+ ),
256
  attention_mask=attention_mask,
257
  **cross_attention_kwargs,
258
  )
 
291
 
292
  if self._chunk_size is not None:
293
  # "feed_forward_chunk_size" can be used to save memory
294
+ ff_output = _chunked_feed_forward(
295
+ self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size
296
+ )
297
  else:
298
  ff_output = self.ff(norm_hidden_states)
299
  if gate_mlp is not None:
 
393
  self.query_dim = query_dim
394
  self.use_bias = bias
395
  self.is_cross_attention = cross_attention_dim is not None
396
+ self.cross_attention_dim = (
397
+ cross_attention_dim if cross_attention_dim is not None else query_dim
398
+ )
399
  self.upcast_attention = upcast_attention
400
  self.upcast_softmax = upcast_softmax
401
  self.rescale_output_factor = rescale_output_factor
 
440
  )
441
 
442
  if norm_num_groups is not None:
443
+ self.group_norm = nn.GroupNorm(
444
+ num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
445
+ )
446
  else:
447
  self.group_norm = None
448
 
449
  if spatial_norm_dim is not None:
450
+ self.spatial_norm = SpatialNorm(
451
+ f_channels=query_dim, zq_channels=spatial_norm_dim
452
+ )
453
  else:
454
  self.spatial_norm = None
455
 
 
469
  norm_cross_num_channels = self.cross_attention_dim
470
 
471
  self.norm_cross = nn.GroupNorm(
472
+ num_channels=norm_cross_num_channels,
473
+ num_groups=cross_attention_norm_num_groups,
474
+ eps=1e-5,
475
+ affine=True,
476
  )
477
  else:
478
  raise ValueError(
 
530
  and isinstance(self.processor, torch.nn.Module)
531
  and not isinstance(processor, torch.nn.Module)
532
  ):
533
+ logger.info(
534
+ f"You are removing possibly trained weights of {self.processor} with {processor}"
535
+ )
536
  self._modules.pop("processor")
537
 
538
  self.processor = processor
539
 
540
+ def get_processor(
541
+ self, return_deprecated_lora: bool = False
542
+ ) -> "AttentionProcessor": # noqa: F821
543
  r"""
544
  Get the attention processor in use.
545
 
 
577
 
578
  # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
579
  non_lora_processor_cls_name = self.processor.__class__.__name__
580
+ lora_processor_cls = getattr(
581
+ import_module(__name__), "LoRA" + non_lora_processor_cls_name
582
+ )
583
 
584
  hidden_size = self.inner_dim
585
 
586
  # now create a LoRA attention processor from the LoRA layers
587
+ if lora_processor_cls in [
588
+ LoRAAttnProcessor,
589
+ LoRAAttnProcessor2_0,
590
+ LoRAXFormersAttnProcessor,
591
+ ]:
592
  kwargs = {
593
  "cross_attention_dim": self.cross_attention_dim,
594
  "rank": self.to_q.lora_layer.rank,
 
610
  lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
611
  lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
612
  lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
613
+ lora_processor.to_out_lora.load_state_dict(
614
+ self.to_out[0].lora_layer.state_dict()
615
+ )
616
  elif lora_processor_cls == LoRAAttnAddedKVProcessor:
617
  lora_processor = lora_processor_cls(
618
  hidden_size,
 
623
  lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
624
  lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
625
  lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
626
+ lora_processor.to_out_lora.load_state_dict(
627
+ self.to_out[0].lora_layer.state_dict()
628
+ )
629
 
630
  # only save if used
631
  if self.add_k_proj.lora_layer is not None:
632
+ lora_processor.add_k_proj_lora.load_state_dict(
633
+ self.add_k_proj.lora_layer.state_dict()
634
+ )
635
+ lora_processor.add_v_proj_lora.load_state_dict(
636
+ self.add_v_proj.lora_layer.state_dict()
637
+ )
638
  else:
639
  lora_processor.add_k_proj_lora = None
640
  lora_processor.add_v_proj_lora = None
 
671
  # here we simply pass along all tensors to the selected processor class
672
  # For standard processors that are defined here, `**cross_attention_kwargs` is empty
673
 
674
+ attn_parameters = set(
675
+ inspect.signature(self.processor.__call__).parameters.keys()
676
+ )
677
+ unused_kwargs = [
678
+ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters
679
+ ]
680
  if len(unused_kwargs) > 0:
681
  logger.warning(
682
  f"cross_attention_kwargs {unused_kwargs} are not expected by"
683
  f" {self.processor.__class__.__name__} and will be ignored."
684
  )
685
+ cross_attention_kwargs = {
686
+ k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters
687
+ }
688
 
689
  return self.processor(
690
  self,
 
709
  head_size = self.heads
710
  batch_size, seq_len, dim = tensor.shape
711
  tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
712
+ tensor = tensor.permute(0, 2, 1, 3).reshape(
713
+ batch_size // head_size, seq_len, dim * head_size
714
+ )
715
  return tensor
716
 
717
  def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
 
734
  extra_dim = 1
735
  else:
736
  batch_size, extra_dim, seq_len, dim = tensor.shape
737
+ tensor = tensor.reshape(
738
+ batch_size, seq_len * extra_dim, head_size, dim // head_size
739
+ )
740
  tensor = tensor.permute(0, 2, 1, 3)
741
 
742
  if out_dim == 3:
743
+ tensor = tensor.reshape(
744
+ batch_size * head_size, seq_len * extra_dim, dim // head_size
745
+ )
746
 
747
  return tensor
748
 
749
  def get_attention_scores(
750
+ self,
751
+ query: torch.Tensor,
752
+ key: torch.Tensor,
753
+ attention_mask: torch.Tensor = None,
754
  ) -> torch.Tensor:
755
  r"""
756
  Compute the attention scores.
 
770
 
771
  if attention_mask is None:
772
  baddbmm_input = torch.empty(
773
+ query.shape[0],
774
+ query.shape[1],
775
+ key.shape[1],
776
+ dtype=query.dtype,
777
+ device=query.device,
778
  )
779
  beta = 0
780
  else:
 
801
  return attention_probs
802
 
803
  def prepare_attention_mask(
804
+ self,
805
+ attention_mask: torch.Tensor,
806
+ target_length: int,
807
+ batch_size: int,
808
+ out_dim: int = 3,
809
  ) -> torch.Tensor:
810
  r"""
811
  Prepare the attention mask for the attention computation.
 
832
  if attention_mask.device.type == "mps":
833
  # HACK: MPS: Does not support padding by greater than dimension of input tensor.
834
  # Instead, we can manually construct the padding tensor.
835
+ padding_shape = (
836
+ attention_mask.shape[0],
837
+ attention_mask.shape[1],
838
+ target_length,
839
+ )
840
+ padding = torch.zeros(
841
+ padding_shape,
842
+ dtype=attention_mask.dtype,
843
+ device=attention_mask.device,
844
+ )
845
  attention_mask = torch.cat([attention_mask, padding], dim=2)
846
  else:
847
  # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
 
859
 
860
  return attention_mask
861
 
862
+ def norm_encoder_hidden_states(
863
+ self, encoder_hidden_states: torch.Tensor
864
+ ) -> torch.Tensor:
865
  r"""
866
  Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
867
  `Attention` class.
 
872
  Returns:
873
  `torch.Tensor`: The normalized encoder hidden states.
874
  """
875
+ assert (
876
+ self.norm_cross is not None
877
+ ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
878
 
879
  if isinstance(self.norm_cross, nn.LayerNorm):
880
  encoder_hidden_states = self.norm_cross(encoder_hidden_states)
 
941
 
942
  if input_ndim == 4:
943
  batch_size, channel, height, width = hidden_states.shape
944
+ hidden_states = hidden_states.view(
945
+ batch_size, channel, height * width
946
+ ).transpose(1, 2)
947
 
948
  batch_size, sequence_length, _ = (
949
+ hidden_states.shape
950
+ if encoder_hidden_states is None
951
+ else encoder_hidden_states.shape
952
  )
953
 
954
  if (attention_mask is not None) and (not attn.use_tpu_flash_attention):
955
+ attention_mask = attn.prepare_attention_mask(
956
+ attention_mask, sequence_length, batch_size
957
+ )
958
  # scaled_dot_product_attention expects attention_mask shape to be
959
  # (batch, heads, source_length, target_length)
960
+ attention_mask = attention_mask.view(
961
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
962
+ )
963
 
964
  if attn.group_norm is not None:
965
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
966
+ 1, 2
967
+ )
968
 
969
  query = attn.to_q(hidden_states)
970
  query = attn.q_norm(query)
971
 
972
  if encoder_hidden_states is not None:
973
  if attn.norm_cross:
974
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
975
+ encoder_hidden_states
976
+ )
977
  key = attn.to_k(encoder_hidden_states)
978
  key = attn.k_norm(key)
979
  else: # if no context provided do self-attention
 
997
 
998
  if attn.use_tpu_flash_attention: # use tpu attention offload 'flash attention'
999
  q_segment_indexes = None
1000
+ if (
1001
+ attention_mask is not None
1002
+ ): # if mask is required need to tune both segmenIds fields
1003
  # attention_mask = torch.squeeze(attention_mask).to(torch.float32)
1004
  attention_mask = attention_mask.to(torch.float32)
1005
+ q_segment_indexes = torch.ones(
1006
+ batch_size, query.shape[2], device=query.device, dtype=torch.float32
1007
+ )
1008
  assert (
1009
  attention_mask.shape[1] == key.shape[2]
1010
  ), f"ERROR: KEY SHAPE must be same as attention mask [{key.shape[2]}, {attention_mask.shape[1]}]"
 
1027
  )
1028
  else:
1029
  hidden_states = F.scaled_dot_product_attention(
1030
+ query,
1031
+ key,
1032
+ value,
1033
+ attn_mask=attention_mask,
1034
+ dropout_p=0.0,
1035
+ is_causal=False,
1036
  )
1037
 
1038
+ hidden_states = hidden_states.transpose(1, 2).reshape(
1039
+ batch_size, -1, attn.heads * head_dim
1040
+ )
1041
  hidden_states = hidden_states.to(query.dtype)
1042
 
1043
  # linear proj
 
1046
  hidden_states = attn.to_out[1](hidden_states)
1047
 
1048
  if input_ndim == 4:
1049
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
1050
+ batch_size, channel, height, width
1051
+ )
1052
 
1053
  if attn.residual_connection:
1054
  hidden_states = hidden_states + residual
 
1086
 
1087
  if input_ndim == 4:
1088
  batch_size, channel, height, width = hidden_states.shape
1089
+ hidden_states = hidden_states.view(
1090
+ batch_size, channel, height * width
1091
+ ).transpose(1, 2)
1092
 
1093
  batch_size, sequence_length, _ = (
1094
+ hidden_states.shape
1095
+ if encoder_hidden_states is None
1096
+ else encoder_hidden_states.shape
1097
+ )
1098
+ attention_mask = attn.prepare_attention_mask(
1099
+ attention_mask, sequence_length, batch_size
1100
  )
 
1101
 
1102
  if attn.group_norm is not None:
1103
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
1104
+ 1, 2
1105
+ )
1106
 
1107
  query = attn.to_q(hidden_states)
1108
 
1109
  if encoder_hidden_states is None:
1110
  encoder_hidden_states = hidden_states
1111
  elif attn.norm_cross:
1112
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
1113
+ encoder_hidden_states
1114
+ )
1115
 
1116
  key = attn.to_k(encoder_hidden_states)
1117
  value = attn.to_v(encoder_hidden_states)
 
1133
  hidden_states = attn.to_out[1](hidden_states)
1134
 
1135
  if input_ndim == 4:
1136
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
1137
+ batch_size, channel, height, width
1138
+ )
1139
 
1140
  if attn.residual_connection:
1141
  hidden_states = hidden_states + residual
xora/models/transformers/embeddings.py CHANGED
@@ -26,7 +26,9 @@ def get_timestep_embedding(
26
  assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
27
 
28
  half_dim = embedding_dim // 2
29
- exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
 
 
30
  exponent = exponent / (half_dim - downscale_freq_shift)
31
 
32
  emb = torch.exp(exponent)
@@ -113,7 +115,9 @@ class SinusoidalPositionalEmbedding(nn.Module):
113
  def __init__(self, embed_dim: int, max_seq_length: int = 32):
114
  super().__init__()
115
  position = torch.arange(max_seq_length).unsqueeze(1)
116
- div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
 
 
117
  pe = torch.zeros(1, max_seq_length, embed_dim)
118
  pe[0, :, 0::2] = torch.sin(position * div_term)
119
  pe[0, :, 1::2] = torch.cos(position * div_term)
 
26
  assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
27
 
28
  half_dim = embedding_dim // 2
29
+ exponent = -math.log(max_period) * torch.arange(
30
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
31
+ )
32
  exponent = exponent / (half_dim - downscale_freq_shift)
33
 
34
  emb = torch.exp(exponent)
 
115
  def __init__(self, embed_dim: int, max_seq_length: int = 32):
116
  super().__init__()
117
  position = torch.arange(max_seq_length).unsqueeze(1)
118
+ div_term = torch.exp(
119
+ torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)
120
+ )
121
  pe = torch.zeros(1, max_seq_length, embed_dim)
122
  pe[0, :, 0::2] = torch.sin(position * div_term)
123
  pe[0, :, 1::2] = torch.cos(position * div_term)
xora/models/transformers/symmetric_patchifier.py CHANGED
@@ -15,12 +15,19 @@ class Patchifier(ConfigMixin, ABC):
15
  self._patch_size = (1, patch_size, patch_size)
16
 
17
  @abstractmethod
18
- def patchify(self, latents: Tensor, frame_rates: Tensor, scale_grid: bool) -> Tuple[Tensor, Tensor]:
 
 
19
  pass
20
 
21
  @abstractmethod
22
  def unpatchify(
23
- self, latents: Tensor, output_height: int, output_width: int, output_num_frames: int, out_channels: int
 
 
 
 
 
24
  ) -> Tuple[Tensor, Tensor]:
25
  pass
26
 
@@ -28,7 +35,9 @@ class Patchifier(ConfigMixin, ABC):
28
  def patch_size(self):
29
  return self._patch_size
30
 
31
- def get_grid(self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device):
 
 
32
  f = orig_num_frames // self._patch_size[0]
33
  h = orig_height // self._patch_size[1]
34
  w = orig_width // self._patch_size[2]
@@ -64,6 +73,7 @@ def pixart_alpha_patchify(
64
  )
65
  return latents
66
 
 
67
  class SymmetricPatchifier(Patchifier):
68
  def patchify(
69
  self,
@@ -72,7 +82,12 @@ class SymmetricPatchifier(Patchifier):
72
  return pixart_alpha_patchify(latents, self._patch_size)
73
 
74
  def unpatchify(
75
- self, latents: Tensor, output_height: int, output_width: int, output_num_frames: int, out_channels: int
 
 
 
 
 
76
  ) -> Tuple[Tensor, Tensor]:
77
  output_height = output_height // self._patch_size[1]
78
  output_width = output_width // self._patch_size[2]
 
15
  self._patch_size = (1, patch_size, patch_size)
16
 
17
  @abstractmethod
18
+ def patchify(
19
+ self, latents: Tensor, frame_rates: Tensor, scale_grid: bool
20
+ ) -> Tuple[Tensor, Tensor]:
21
  pass
22
 
23
  @abstractmethod
24
  def unpatchify(
25
+ self,
26
+ latents: Tensor,
27
+ output_height: int,
28
+ output_width: int,
29
+ output_num_frames: int,
30
+ out_channels: int,
31
  ) -> Tuple[Tensor, Tensor]:
32
  pass
33
 
 
35
  def patch_size(self):
36
  return self._patch_size
37
 
38
+ def get_grid(
39
+ self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
40
+ ):
41
  f = orig_num_frames // self._patch_size[0]
42
  h = orig_height // self._patch_size[1]
43
  w = orig_width // self._patch_size[2]
 
73
  )
74
  return latents
75
 
76
+
77
  class SymmetricPatchifier(Patchifier):
78
  def patchify(
79
  self,
 
82
  return pixart_alpha_patchify(latents, self._patch_size)
83
 
84
  def unpatchify(
85
+ self,
86
+ latents: Tensor,
87
+ output_height: int,
88
+ output_width: int,
89
+ output_num_frames: int,
90
+ out_channels: int,
91
  ) -> Tuple[Tensor, Tensor]:
92
  output_height = output_height // self._patch_size[1]
93
  output_width = output_width // self._patch_size[2]
xora/models/transformers/transformer3d.py CHANGED
@@ -17,6 +17,7 @@ from xora.models.transformers.embeddings import get_3d_sincos_pos_embed
17
 
18
  logger = logging.get_logger(__name__)
19
 
 
20
  @dataclass
21
  class Transformer3DModelOutput(BaseOutput):
22
  """
@@ -68,7 +69,9 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
68
  timestep_scale_multiplier: Optional[float] = None,
69
  ):
70
  super().__init__()
71
- self.use_tpu_flash_attention = use_tpu_flash_attention # FIXME: push config down to the attention modules
 
 
72
  self.use_linear_projection = use_linear_projection
73
  self.num_attention_heads = num_attention_heads
74
  self.attention_head_dim = attention_head_dim
@@ -86,7 +89,9 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
86
  self.timestep_scale_multiplier = timestep_scale_multiplier
87
 
88
  if self.positional_embedding_type == "absolute":
89
- embed_dim_3d = math.ceil((inner_dim / 2) * 3) if project_to_2d_pos else inner_dim
 
 
90
  if self.project_to_2d_pos:
91
  self.to_2d_proj = torch.nn.Linear(embed_dim_3d, inner_dim, bias=False)
92
  self._init_to_2d_proj_weights(self.to_2d_proj)
@@ -131,18 +136,24 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
131
  # 4. Define output layers
132
  self.out_channels = in_channels if out_channels is None else out_channels
133
  self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
134
- self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
 
 
135
  self.proj_out = nn.Linear(inner_dim, self.out_channels)
136
 
137
  # 5. PixArt-Alpha blocks.
138
- self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=False)
 
 
139
  if adaptive_norm == "single_scale":
140
  # Use 4 channels instead of the 6 for the PixArt-Alpha scale + shift ada norm.
141
  self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True)
142
 
143
  self.caption_projection = None
144
  if caption_channels is not None:
145
- self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
 
 
146
 
147
  self.gradient_checkpointing = False
148
 
@@ -169,16 +180,32 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
169
  self.apply(_basic_init)
170
 
171
  # Initialize timestep embedding MLP:
172
- nn.init.normal_(self.adaln_single.emb.timestep_embedder.linear_1.weight, std=embedding_std)
173
- nn.init.normal_(self.adaln_single.emb.timestep_embedder.linear_2.weight, std=embedding_std)
 
 
 
 
174
  nn.init.normal_(self.adaln_single.linear.weight, std=embedding_std)
175
 
176
  if hasattr(self.adaln_single.emb, "resolution_embedder"):
177
- nn.init.normal_(self.adaln_single.emb.resolution_embedder.linear_1.weight, std=embedding_std)
178
- nn.init.normal_(self.adaln_single.emb.resolution_embedder.linear_2.weight, std=embedding_std)
 
 
 
 
 
 
179
  if hasattr(self.adaln_single.emb, "aspect_ratio_embedder"):
180
- nn.init.normal_(self.adaln_single.emb.aspect_ratio_embedder.linear_1.weight, std=embedding_std)
181
- nn.init.normal_(self.adaln_single.emb.aspect_ratio_embedder.linear_2.weight, std=embedding_std)
 
 
 
 
 
 
182
 
183
  # Initialize caption embedding MLP:
184
  nn.init.normal_(self.caption_projection.linear_1.weight, std=embedding_std)
@@ -220,7 +247,11 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
220
 
221
  def get_fractional_positions(self, indices_grid):
222
  fractional_positions = torch.stack(
223
- [indices_grid[:, i] / self.positional_embedding_max_pos[i] for i in range(3)], dim=-1
 
 
 
 
224
  )
225
  return fractional_positions
226
 
@@ -236,7 +267,13 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
236
  device = fractional_positions.device
237
  if spacing == "exp":
238
  indices = theta ** (
239
- torch.linspace(math.log(start, theta), math.log(end, theta), dim // 6, device=device, dtype=dtype)
 
 
 
 
 
 
240
  )
241
  indices = indices.to(dtype=dtype)
242
  elif spacing == "exp_2":
@@ -245,14 +282,24 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
245
  elif spacing == "linear":
246
  indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype)
247
  elif spacing == "sqrt":
248
- indices = torch.linspace(start**2, end**2, dim // 6, device=device, dtype=dtype).sqrt()
 
 
249
 
250
  indices = indices * math.pi / 2
251
 
252
  if spacing == "exp_2":
253
- freqs = (indices * fractional_positions.unsqueeze(-1)).transpose(-1, -2).flatten(2)
 
 
 
 
254
  else:
255
- freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
 
 
 
 
256
 
257
  cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
258
  sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
@@ -336,7 +383,9 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
336
 
337
  # convert encoder_attention_mask to a bias the same way we do for attention_mask
338
  if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
339
- encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
 
 
340
  encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
341
 
342
  # 1. Input
@@ -346,7 +395,9 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
346
  timestep = self.timestep_scale_multiplier * timestep
347
 
348
  if self.positional_embedding_type == "absolute":
349
- pos_embed_3d = self.get_absolute_pos_embed(indices_grid).to(hidden_states.device)
 
 
350
  if self.project_to_2d_pos:
351
  pos_embed = self.to_2d_proj(pos_embed_3d)
352
  hidden_states = (hidden_states + pos_embed).to(hidden_states.dtype)
@@ -363,13 +414,17 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
363
  )
364
  # Second dimension is 1 or number of tokens (if timestep_per_token)
365
  timestep = timestep.view(batch_size, -1, timestep.shape[-1])
366
- embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
 
 
367
 
368
  # 2. Blocks
369
  if self.caption_projection is not None:
370
  batch_size = hidden_states.shape[0]
371
  encoder_hidden_states = self.caption_projection(encoder_hidden_states)
372
- encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
 
 
373
 
374
  for block in self.transformer_blocks:
375
  if self.training and self.gradient_checkpointing:
@@ -383,7 +438,9 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
383
 
384
  return custom_forward
385
 
386
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
 
 
387
  hidden_states = torch.utils.checkpoint.checkpoint(
388
  create_custom_forward(block),
389
  hidden_states,
@@ -409,7 +466,9 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
409
  )
410
 
411
  # 3. Output
412
- scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
 
 
413
  shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
414
  hidden_states = self.norm_out(hidden_states)
415
  # Modulation
@@ -422,7 +481,11 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
422
 
423
  def get_absolute_pos_embed(self, grid):
424
  grid_np = grid[0].cpu().numpy()
425
- embed_dim_3d = math.ceil((self.inner_dim / 2) * 3) if self.project_to_2d_pos else self.inner_dim
 
 
 
 
426
  pos_embed = get_3d_sincos_pos_embed( # (f h w)
427
  embed_dim_3d,
428
  grid_np,
 
17
 
18
  logger = logging.get_logger(__name__)
19
 
20
+
21
  @dataclass
22
  class Transformer3DModelOutput(BaseOutput):
23
  """
 
69
  timestep_scale_multiplier: Optional[float] = None,
70
  ):
71
  super().__init__()
72
+ self.use_tpu_flash_attention = (
73
+ use_tpu_flash_attention # FIXME: push config down to the attention modules
74
+ )
75
  self.use_linear_projection = use_linear_projection
76
  self.num_attention_heads = num_attention_heads
77
  self.attention_head_dim = attention_head_dim
 
89
  self.timestep_scale_multiplier = timestep_scale_multiplier
90
 
91
  if self.positional_embedding_type == "absolute":
92
+ embed_dim_3d = (
93
+ math.ceil((inner_dim / 2) * 3) if project_to_2d_pos else inner_dim
94
+ )
95
  if self.project_to_2d_pos:
96
  self.to_2d_proj = torch.nn.Linear(embed_dim_3d, inner_dim, bias=False)
97
  self._init_to_2d_proj_weights(self.to_2d_proj)
 
136
  # 4. Define output layers
137
  self.out_channels = in_channels if out_channels is None else out_channels
138
  self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
139
+ self.scale_shift_table = nn.Parameter(
140
+ torch.randn(2, inner_dim) / inner_dim**0.5
141
+ )
142
  self.proj_out = nn.Linear(inner_dim, self.out_channels)
143
 
144
  # 5. PixArt-Alpha blocks.
145
+ self.adaln_single = AdaLayerNormSingle(
146
+ inner_dim, use_additional_conditions=False
147
+ )
148
  if adaptive_norm == "single_scale":
149
  # Use 4 channels instead of the 6 for the PixArt-Alpha scale + shift ada norm.
150
  self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True)
151
 
152
  self.caption_projection = None
153
  if caption_channels is not None:
154
+ self.caption_projection = PixArtAlphaTextProjection(
155
+ in_features=caption_channels, hidden_size=inner_dim
156
+ )
157
 
158
  self.gradient_checkpointing = False
159
 
 
180
  self.apply(_basic_init)
181
 
182
  # Initialize timestep embedding MLP:
183
+ nn.init.normal_(
184
+ self.adaln_single.emb.timestep_embedder.linear_1.weight, std=embedding_std
185
+ )
186
+ nn.init.normal_(
187
+ self.adaln_single.emb.timestep_embedder.linear_2.weight, std=embedding_std
188
+ )
189
  nn.init.normal_(self.adaln_single.linear.weight, std=embedding_std)
190
 
191
  if hasattr(self.adaln_single.emb, "resolution_embedder"):
192
+ nn.init.normal_(
193
+ self.adaln_single.emb.resolution_embedder.linear_1.weight,
194
+ std=embedding_std,
195
+ )
196
+ nn.init.normal_(
197
+ self.adaln_single.emb.resolution_embedder.linear_2.weight,
198
+ std=embedding_std,
199
+ )
200
  if hasattr(self.adaln_single.emb, "aspect_ratio_embedder"):
201
+ nn.init.normal_(
202
+ self.adaln_single.emb.aspect_ratio_embedder.linear_1.weight,
203
+ std=embedding_std,
204
+ )
205
+ nn.init.normal_(
206
+ self.adaln_single.emb.aspect_ratio_embedder.linear_2.weight,
207
+ std=embedding_std,
208
+ )
209
 
210
  # Initialize caption embedding MLP:
211
  nn.init.normal_(self.caption_projection.linear_1.weight, std=embedding_std)
 
247
 
248
  def get_fractional_positions(self, indices_grid):
249
  fractional_positions = torch.stack(
250
+ [
251
+ indices_grid[:, i] / self.positional_embedding_max_pos[i]
252
+ for i in range(3)
253
+ ],
254
+ dim=-1,
255
  )
256
  return fractional_positions
257
 
 
267
  device = fractional_positions.device
268
  if spacing == "exp":
269
  indices = theta ** (
270
+ torch.linspace(
271
+ math.log(start, theta),
272
+ math.log(end, theta),
273
+ dim // 6,
274
+ device=device,
275
+ dtype=dtype,
276
+ )
277
  )
278
  indices = indices.to(dtype=dtype)
279
  elif spacing == "exp_2":
 
282
  elif spacing == "linear":
283
  indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype)
284
  elif spacing == "sqrt":
285
+ indices = torch.linspace(
286
+ start**2, end**2, dim // 6, device=device, dtype=dtype
287
+ ).sqrt()
288
 
289
  indices = indices * math.pi / 2
290
 
291
  if spacing == "exp_2":
292
+ freqs = (
293
+ (indices * fractional_positions.unsqueeze(-1))
294
+ .transpose(-1, -2)
295
+ .flatten(2)
296
+ )
297
  else:
298
+ freqs = (
299
+ (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
300
+ .transpose(-1, -2)
301
+ .flatten(2)
302
+ )
303
 
304
  cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
305
  sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
 
383
 
384
  # convert encoder_attention_mask to a bias the same way we do for attention_mask
385
  if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
386
+ encoder_attention_mask = (
387
+ 1 - encoder_attention_mask.to(hidden_states.dtype)
388
+ ) * -10000.0
389
  encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
390
 
391
  # 1. Input
 
395
  timestep = self.timestep_scale_multiplier * timestep
396
 
397
  if self.positional_embedding_type == "absolute":
398
+ pos_embed_3d = self.get_absolute_pos_embed(indices_grid).to(
399
+ hidden_states.device
400
+ )
401
  if self.project_to_2d_pos:
402
  pos_embed = self.to_2d_proj(pos_embed_3d)
403
  hidden_states = (hidden_states + pos_embed).to(hidden_states.dtype)
 
414
  )
415
  # Second dimension is 1 or number of tokens (if timestep_per_token)
416
  timestep = timestep.view(batch_size, -1, timestep.shape[-1])
417
+ embedded_timestep = embedded_timestep.view(
418
+ batch_size, -1, embedded_timestep.shape[-1]
419
+ )
420
 
421
  # 2. Blocks
422
  if self.caption_projection is not None:
423
  batch_size = hidden_states.shape[0]
424
  encoder_hidden_states = self.caption_projection(encoder_hidden_states)
425
+ encoder_hidden_states = encoder_hidden_states.view(
426
+ batch_size, -1, hidden_states.shape[-1]
427
+ )
428
 
429
  for block in self.transformer_blocks:
430
  if self.training and self.gradient_checkpointing:
 
438
 
439
  return custom_forward
440
 
441
+ ckpt_kwargs: Dict[str, Any] = (
442
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
443
+ )
444
  hidden_states = torch.utils.checkpoint.checkpoint(
445
  create_custom_forward(block),
446
  hidden_states,
 
466
  )
467
 
468
  # 3. Output
469
+ scale_shift_values = (
470
+ self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
471
+ )
472
  shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
473
  hidden_states = self.norm_out(hidden_states)
474
  # Modulation
 
481
 
482
  def get_absolute_pos_embed(self, grid):
483
  grid_np = grid[0].cpu().numpy()
484
+ embed_dim_3d = (
485
+ math.ceil((self.inner_dim / 2) * 3)
486
+ if self.project_to_2d_pos
487
+ else self.inner_dim
488
+ )
489
  pos_embed = get_3d_sincos_pos_embed( # (f h w)
490
  embed_dim_3d,
491
  grid_np,
xora/pipelines/pipeline_video_pixart_alpha.py CHANGED
@@ -5,12 +5,10 @@ import math
5
  import re
6
  import urllib.parse as ul
7
  from typing import Callable, Dict, List, Optional, Tuple, Union
8
- from abc import ABC, abstractmethod
9
 
10
 
11
  import torch
12
  import torch.nn.functional as F
13
- from torch import Tensor
14
  from diffusers.image_processor import VaeImageProcessor
15
  from diffusers.models import AutoencoderKL
16
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -29,7 +27,11 @@ from transformers import T5EncoderModel, T5Tokenizer
29
 
30
  from xora.models.transformers.transformer3d import Transformer3DModel
31
  from xora.models.transformers.symmetric_patchifier import Patchifier
32
- from xora.models.autoencoders.vae_encode import get_vae_size_scale_factor, vae_decode, vae_encode
 
 
 
 
33
  from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
34
  from xora.schedulers.rf import TimestepShifter
35
  from xora.utils.conditioning_method import ConditioningMethod
@@ -161,7 +163,9 @@ def retrieve_timesteps(
161
  second element is the number of inference steps.
162
  """
163
  if timesteps is not None:
164
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
 
 
165
  if not accepts_timesteps:
166
  raise ValueError(
167
  f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
@@ -238,7 +242,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
238
  patchifier=patchifier,
239
  )
240
 
241
- self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor(self.vae)
 
 
242
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
243
 
244
  # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
@@ -320,12 +326,16 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
320
  return_tensors="pt",
321
  )
322
  text_input_ids = text_inputs.input_ids
323
- untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
324
-
325
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
326
- text_input_ids, untruncated_ids
327
- ):
328
- removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
 
 
 
 
329
  logger.warning(
330
  "The following part of your input was truncated because CLIP can only handle sequences up to"
331
  f" {max_length} tokens: {removed_text}"
@@ -334,7 +344,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
334
  prompt_attention_mask = text_inputs.attention_mask
335
  prompt_attention_mask = prompt_attention_mask.to(device)
336
 
337
- prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
 
 
338
  prompt_embeds = prompt_embeds[0]
339
 
340
  if self.text_encoder is not None:
@@ -349,14 +361,20 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
349
  bs_embed, seq_len, _ = prompt_embeds.shape
350
  # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
351
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
352
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
 
353
  prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
354
- prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
 
 
355
 
356
  # get unconditional embeddings for classifier free guidance
357
  if do_classifier_free_guidance and negative_prompt_embeds is None:
358
  uncond_tokens = [negative_prompt] * batch_size
359
- uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
 
 
360
  max_length = prompt_embeds.shape[1]
361
  uncond_input = self.tokenizer(
362
  uncond_tokens,
@@ -371,7 +389,8 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
371
  negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
372
 
373
  negative_prompt_embeds = self.text_encoder(
374
- uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
 
375
  )
376
  negative_prompt_embeds = negative_prompt_embeds[0]
377
 
@@ -379,18 +398,33 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
379
  # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
380
  seq_len = negative_prompt_embeds.shape[1]
381
 
382
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
 
 
383
 
384
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
385
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
 
 
 
386
 
387
- negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
388
- negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
 
 
 
 
389
  else:
390
  negative_prompt_embeds = None
391
  negative_prompt_attention_mask = None
392
 
393
- return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
 
 
 
 
 
394
 
395
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
396
  def prepare_extra_step_kwargs(self, generator, eta):
@@ -399,13 +433,17 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
399
  # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
400
  # and should be between [0, 1]
401
 
402
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
 
 
403
  extra_step_kwargs = {}
404
  if accepts_eta:
405
  extra_step_kwargs["eta"] = eta
406
 
407
  # check if the scheduler accepts generator
408
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
 
 
409
  if accepts_generator:
410
  extra_step_kwargs["generator"] = generator
411
  return extra_step_kwargs
@@ -422,7 +460,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
422
  negative_prompt_attention_mask=None,
423
  ):
424
  if height % 8 != 0 or width % 8 != 0:
425
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
 
426
 
427
  if prompt is not None and prompt_embeds is not None:
428
  raise ValueError(
@@ -433,8 +473,12 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
433
  raise ValueError(
434
  "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
435
  )
436
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
437
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
 
 
 
438
 
439
  if prompt is not None and negative_prompt_embeds is not None:
440
  raise ValueError(
@@ -449,10 +493,17 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
449
  )
450
 
451
  if prompt_embeds is not None and prompt_attention_mask is None:
452
- raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
 
 
453
 
454
- if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
455
- raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
 
 
 
 
 
456
 
457
  if prompt_embeds is not None and negative_prompt_embeds is not None:
458
  if prompt_embeds.shape != negative_prompt_embeds.shape:
@@ -471,12 +522,16 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
471
  # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
472
  def _text_preprocessing(self, text, clean_caption=False):
473
  if clean_caption and not is_bs4_available():
474
- logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
 
 
475
  logger.warn("Setting `clean_caption` to False...")
476
  clean_caption = False
477
 
478
  if clean_caption and not is_ftfy_available():
479
- logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
 
 
480
  logger.warn("Setting `clean_caption` to False...")
481
  clean_caption = False
482
 
@@ -564,13 +619,17 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
564
  # "123456.."
565
  caption = re.sub(r"\b\d{6,}\b", "", caption)
566
  # filenames:
567
- caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
 
 
568
 
569
  #
570
  caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
571
  caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
572
 
573
- caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
 
 
574
  caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
575
 
576
  # this-is-my-cute-cat / this_is_my_cute_cat
@@ -588,10 +647,14 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
588
  caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
589
  caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
590
  caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
591
- caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
 
 
592
  caption = re.sub(r"\bpage\s+\d+\b", "", caption)
593
 
594
- caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
 
 
595
 
596
  caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
597
 
@@ -610,7 +673,15 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
610
 
611
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
612
  def prepare_latents(
613
- self, batch_size, num_latent_channels, num_patches, dtype, device, generator, latents=None, latents_mask=None
 
 
 
 
 
 
 
 
614
  ):
615
  shape = (
616
  batch_size,
@@ -625,10 +696,14 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
625
  )
626
 
627
  if latents is None:
628
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
 
 
629
  elif latents_mask is not None:
630
  noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
631
- latents = latents * latents_mask[..., None] + noise * (1 - latents_mask[..., None])
 
 
632
  else:
633
  latents = latents.to(device)
634
 
@@ -637,7 +712,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
637
  return latents
638
 
639
  @staticmethod
640
- def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
 
 
641
  """Returns binned height and width."""
642
  ar = float(height / width)
643
  closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
@@ -645,7 +722,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
645
  return int(default_hw[0]), int(default_hw[1])
646
 
647
  @staticmethod
648
- def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
 
 
649
  n_frames, orig_height, orig_width = samples.shape[-3:]
650
 
651
  # Check if resizing is needed
@@ -656,7 +735,12 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
656
 
657
  # Resize
658
  samples = rearrange(samples, "b c n h w -> (b n) c h w")
659
- samples = F.interpolate(samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False)
 
 
 
 
 
660
  samples = rearrange(samples, "(b n) c h w -> b c n h w", n=n_frames)
661
 
662
  # Center Crop
@@ -821,14 +905,21 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
821
  )
822
  if do_classifier_free_guidance:
823
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
824
- prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
 
 
825
 
826
  # 3b. Encode and prepare conditioning data
827
  self.video_scale_factor = self.video_scale_factor if is_video else 1
828
  conditioning_method = kwargs.get("conditioning_method", None)
829
  vae_per_channel_normalize = kwargs.get("vae_per_channel_normalize", False)
830
  init_latents, conditioning_mask = self.prepare_conditioning(
831
- media_items, num_frames, height, width, conditioning_method, vae_per_channel_normalize
 
 
 
 
 
832
  )
833
 
834
  # 4. Prepare latents.
@@ -851,29 +942,46 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
851
  )
852
  if conditioning_mask is not None and is_video:
853
  assert num_images_per_prompt == 1
854
- conditioning_mask = torch.cat([conditioning_mask] * 2) if do_classifier_free_guidance else conditioning_mask
 
 
 
 
855
 
856
  # 5. Prepare timesteps
857
  retrieve_timesteps_kwargs = {}
858
  if isinstance(self.scheduler, TimestepShifter):
859
  retrieve_timesteps_kwargs["samples"] = latents
860
  timesteps, num_inference_steps = retrieve_timesteps(
861
- self.scheduler, num_inference_steps, device, timesteps, **retrieve_timesteps_kwargs
 
 
 
 
862
  )
863
 
864
  # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
865
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
866
 
867
  # 7. Denoising loop
868
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
 
 
869
 
870
  with self.progress_bar(total=num_inference_steps) as progress_bar:
871
  for i, t in enumerate(timesteps):
872
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
873
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
 
 
 
874
 
875
  latent_frame_rates = (
876
- torch.ones(latent_model_input.shape[0], 1, device=latent_model_input.device) * latent_frame_rate
 
 
 
877
  )
878
 
879
  current_timestep = t
@@ -885,13 +993,25 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
885
  dtype = torch.float32 if is_mps else torch.float64
886
  else:
887
  dtype = torch.int32 if is_mps else torch.int64
888
- current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
 
 
 
 
889
  elif len(current_timestep.shape) == 0:
890
- current_timestep = current_timestep[None].to(latent_model_input.device)
 
 
891
  # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
892
- current_timestep = current_timestep.expand(latent_model_input.shape[0]).unsqueeze(-1)
 
 
893
  scale_grid = (
894
- (1 / latent_frame_rates, self.vae_scale_factor, self.vae_scale_factor)
 
 
 
 
895
  if self.transformer.use_rope
896
  else None
897
  )
@@ -920,11 +1040,16 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
920
  # perform guidance
921
  if do_classifier_free_guidance:
922
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
923
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
 
924
  current_timestep, _ = current_timestep.chunk(2)
925
 
926
  # learned sigma
927
- if self.transformer.config.out_channels // 2 == self.transformer.config.in_channels:
 
 
 
928
  noise_pred = noise_pred.chunk(2, dim=1)[0]
929
 
930
  # compute previous image: x_t -> x_t-1
@@ -937,7 +1062,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
937
  )[0]
938
 
939
  # call the callback, if provided
940
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
 
 
941
  progress_bar.update()
942
 
943
  if callback_on_step_end is not None:
@@ -948,11 +1075,15 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
948
  output_height=latent_height,
949
  output_width=latent_width,
950
  output_num_frames=latent_num_frames,
951
- out_channels=self.transformer.in_channels // math.prod(self.patchifier.patch_size),
 
952
  )
953
  if output_type != "latent":
954
  image = vae_decode(
955
- latents, self.vae, is_video, vae_per_channel_normalize=kwargs["vae_per_channel_normalize"]
 
 
 
956
  )
957
  image = self.image_processor.postprocess(image, output_type=output_type)
958
 
@@ -1005,20 +1136,31 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
1005
  vae_per_channel_normalize=vae_per_channel_normalize,
1006
  ).float()
1007
 
1008
- init_len, target_len = init_latents.shape[2], num_frames // self.video_scale_factor
 
 
 
1009
  if isinstance(self.vae, CausalVideoAutoencoder):
1010
  target_len += 1
1011
  init_latents = init_latents[:, :, :target_len]
1012
  if target_len > init_len:
1013
  repeat_factor = (target_len + init_len - 1) // init_len # Ceiling division
1014
- init_latents = init_latents.repeat(1, 1, repeat_factor, 1, 1)[:, :, :target_len]
 
 
1015
 
1016
  # Prepare the conditioning mask (1.0 = condition on this token)
1017
  b, n, f, h, w = init_latents.shape
1018
  conditioning_mask = torch.zeros([b, 1, f, h, w], device=init_latents.device)
1019
- if method in [ConditioningMethod.FIRST_FRAME, ConditioningMethod.FIRST_AND_LAST_FRAME]:
 
 
 
1020
  conditioning_mask[:, :, 0] = 1.0
1021
- if method in [ConditioningMethod.LAST_FRAME, ConditioningMethod.FIRST_AND_LAST_FRAME]:
 
 
 
1022
  conditioning_mask[:, :, -1] = 1.0
1023
 
1024
  # Patchify the init latents and the mask
 
5
  import re
6
  import urllib.parse as ul
7
  from typing import Callable, Dict, List, Optional, Tuple, Union
 
8
 
9
 
10
  import torch
11
  import torch.nn.functional as F
 
12
  from diffusers.image_processor import VaeImageProcessor
13
  from diffusers.models import AutoencoderKL
14
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
27
 
28
  from xora.models.transformers.transformer3d import Transformer3DModel
29
  from xora.models.transformers.symmetric_patchifier import Patchifier
30
+ from xora.models.autoencoders.vae_encode import (
31
+ get_vae_size_scale_factor,
32
+ vae_decode,
33
+ vae_encode,
34
+ )
35
  from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
36
  from xora.schedulers.rf import TimestepShifter
37
  from xora.utils.conditioning_method import ConditioningMethod
 
163
  second element is the number of inference steps.
164
  """
165
  if timesteps is not None:
166
+ accepts_timesteps = "timesteps" in set(
167
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
168
+ )
169
  if not accepts_timesteps:
170
  raise ValueError(
171
  f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
 
242
  patchifier=patchifier,
243
  )
244
 
245
+ self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor(
246
+ self.vae
247
+ )
248
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
249
 
250
  # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
 
326
  return_tensors="pt",
327
  )
328
  text_input_ids = text_inputs.input_ids
329
+ untruncated_ids = self.tokenizer(
330
+ prompt, padding="longest", return_tensors="pt"
331
+ ).input_ids
332
+
333
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[
334
+ -1
335
+ ] and not torch.equal(text_input_ids, untruncated_ids):
336
+ removed_text = self.tokenizer.batch_decode(
337
+ untruncated_ids[:, max_length - 1 : -1]
338
+ )
339
  logger.warning(
340
  "The following part of your input was truncated because CLIP can only handle sequences up to"
341
  f" {max_length} tokens: {removed_text}"
 
344
  prompt_attention_mask = text_inputs.attention_mask
345
  prompt_attention_mask = prompt_attention_mask.to(device)
346
 
347
+ prompt_embeds = self.text_encoder(
348
+ text_input_ids.to(device), attention_mask=prompt_attention_mask
349
+ )
350
  prompt_embeds = prompt_embeds[0]
351
 
352
  if self.text_encoder is not None:
 
361
  bs_embed, seq_len, _ = prompt_embeds.shape
362
  # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
363
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
364
+ prompt_embeds = prompt_embeds.view(
365
+ bs_embed * num_images_per_prompt, seq_len, -1
366
+ )
367
  prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
368
+ prompt_attention_mask = prompt_attention_mask.view(
369
+ bs_embed * num_images_per_prompt, -1
370
+ )
371
 
372
  # get unconditional embeddings for classifier free guidance
373
  if do_classifier_free_guidance and negative_prompt_embeds is None:
374
  uncond_tokens = [negative_prompt] * batch_size
375
+ uncond_tokens = self._text_preprocessing(
376
+ uncond_tokens, clean_caption=clean_caption
377
+ )
378
  max_length = prompt_embeds.shape[1]
379
  uncond_input = self.tokenizer(
380
  uncond_tokens,
 
389
  negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
390
 
391
  negative_prompt_embeds = self.text_encoder(
392
+ uncond_input.input_ids.to(device),
393
+ attention_mask=negative_prompt_attention_mask,
394
  )
395
  negative_prompt_embeds = negative_prompt_embeds[0]
396
 
 
398
  # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
399
  seq_len = negative_prompt_embeds.shape[1]
400
 
401
+ negative_prompt_embeds = negative_prompt_embeds.to(
402
+ dtype=dtype, device=device
403
+ )
404
 
405
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
406
+ 1, num_images_per_prompt, 1
407
+ )
408
+ negative_prompt_embeds = negative_prompt_embeds.view(
409
+ batch_size * num_images_per_prompt, seq_len, -1
410
+ )
411
 
412
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(
413
+ 1, num_images_per_prompt
414
+ )
415
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(
416
+ bs_embed * num_images_per_prompt, -1
417
+ )
418
  else:
419
  negative_prompt_embeds = None
420
  negative_prompt_attention_mask = None
421
 
422
+ return (
423
+ prompt_embeds,
424
+ prompt_attention_mask,
425
+ negative_prompt_embeds,
426
+ negative_prompt_attention_mask,
427
+ )
428
 
429
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
430
  def prepare_extra_step_kwargs(self, generator, eta):
 
433
  # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
434
  # and should be between [0, 1]
435
 
436
+ accepts_eta = "eta" in set(
437
+ inspect.signature(self.scheduler.step).parameters.keys()
438
+ )
439
  extra_step_kwargs = {}
440
  if accepts_eta:
441
  extra_step_kwargs["eta"] = eta
442
 
443
  # check if the scheduler accepts generator
444
+ accepts_generator = "generator" in set(
445
+ inspect.signature(self.scheduler.step).parameters.keys()
446
+ )
447
  if accepts_generator:
448
  extra_step_kwargs["generator"] = generator
449
  return extra_step_kwargs
 
460
  negative_prompt_attention_mask=None,
461
  ):
462
  if height % 8 != 0 or width % 8 != 0:
463
+ raise ValueError(
464
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
465
+ )
466
 
467
  if prompt is not None and prompt_embeds is not None:
468
  raise ValueError(
 
473
  raise ValueError(
474
  "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
475
  )
476
+ elif prompt is not None and (
477
+ not isinstance(prompt, str) and not isinstance(prompt, list)
478
+ ):
479
+ raise ValueError(
480
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
481
+ )
482
 
483
  if prompt is not None and negative_prompt_embeds is not None:
484
  raise ValueError(
 
493
  )
494
 
495
  if prompt_embeds is not None and prompt_attention_mask is None:
496
+ raise ValueError(
497
+ "Must provide `prompt_attention_mask` when specifying `prompt_embeds`."
498
+ )
499
 
500
+ if (
501
+ negative_prompt_embeds is not None
502
+ and negative_prompt_attention_mask is None
503
+ ):
504
+ raise ValueError(
505
+ "Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`."
506
+ )
507
 
508
  if prompt_embeds is not None and negative_prompt_embeds is not None:
509
  if prompt_embeds.shape != negative_prompt_embeds.shape:
 
522
  # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
523
  def _text_preprocessing(self, text, clean_caption=False):
524
  if clean_caption and not is_bs4_available():
525
+ logger.warn(
526
+ BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")
527
+ )
528
  logger.warn("Setting `clean_caption` to False...")
529
  clean_caption = False
530
 
531
  if clean_caption and not is_ftfy_available():
532
+ logger.warn(
533
+ BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")
534
+ )
535
  logger.warn("Setting `clean_caption` to False...")
536
  clean_caption = False
537
 
 
619
  # "123456.."
620
  caption = re.sub(r"\b\d{6,}\b", "", caption)
621
  # filenames:
622
+ caption = re.sub(
623
+ r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption
624
+ )
625
 
626
  #
627
  caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
628
  caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
629
 
630
+ caption = re.sub(
631
+ self.bad_punct_regex, r" ", caption
632
+ ) # ***AUSVERKAUFT***, #AUSVERKAUFT
633
  caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
634
 
635
  # this-is-my-cute-cat / this_is_my_cute_cat
 
647
  caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
648
  caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
649
  caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
650
+ caption = re.sub(
651
+ r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption
652
+ )
653
  caption = re.sub(r"\bpage\s+\d+\b", "", caption)
654
 
655
+ caption = re.sub(
656
+ r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption
657
+ ) # j2d1a2a...
658
 
659
  caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
660
 
 
673
 
674
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
675
  def prepare_latents(
676
+ self,
677
+ batch_size,
678
+ num_latent_channels,
679
+ num_patches,
680
+ dtype,
681
+ device,
682
+ generator,
683
+ latents=None,
684
+ latents_mask=None,
685
  ):
686
  shape = (
687
  batch_size,
 
696
  )
697
 
698
  if latents is None:
699
+ latents = randn_tensor(
700
+ shape, generator=generator, device=device, dtype=dtype
701
+ )
702
  elif latents_mask is not None:
703
  noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
704
+ latents = latents * latents_mask[..., None] + noise * (
705
+ 1 - latents_mask[..., None]
706
+ )
707
  else:
708
  latents = latents.to(device)
709
 
 
712
  return latents
713
 
714
  @staticmethod
715
+ def classify_height_width_bin(
716
+ height: int, width: int, ratios: dict
717
+ ) -> Tuple[int, int]:
718
  """Returns binned height and width."""
719
  ar = float(height / width)
720
  closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
 
722
  return int(default_hw[0]), int(default_hw[1])
723
 
724
  @staticmethod
725
+ def resize_and_crop_tensor(
726
+ samples: torch.Tensor, new_width: int, new_height: int
727
+ ) -> torch.Tensor:
728
  n_frames, orig_height, orig_width = samples.shape[-3:]
729
 
730
  # Check if resizing is needed
 
735
 
736
  # Resize
737
  samples = rearrange(samples, "b c n h w -> (b n) c h w")
738
+ samples = F.interpolate(
739
+ samples,
740
+ size=(resized_height, resized_width),
741
+ mode="bilinear",
742
+ align_corners=False,
743
+ )
744
  samples = rearrange(samples, "(b n) c h w -> b c n h w", n=n_frames)
745
 
746
  # Center Crop
 
905
  )
906
  if do_classifier_free_guidance:
907
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
908
+ prompt_attention_mask = torch.cat(
909
+ [negative_prompt_attention_mask, prompt_attention_mask], dim=0
910
+ )
911
 
912
  # 3b. Encode and prepare conditioning data
913
  self.video_scale_factor = self.video_scale_factor if is_video else 1
914
  conditioning_method = kwargs.get("conditioning_method", None)
915
  vae_per_channel_normalize = kwargs.get("vae_per_channel_normalize", False)
916
  init_latents, conditioning_mask = self.prepare_conditioning(
917
+ media_items,
918
+ num_frames,
919
+ height,
920
+ width,
921
+ conditioning_method,
922
+ vae_per_channel_normalize,
923
  )
924
 
925
  # 4. Prepare latents.
 
942
  )
943
  if conditioning_mask is not None and is_video:
944
  assert num_images_per_prompt == 1
945
+ conditioning_mask = (
946
+ torch.cat([conditioning_mask] * 2)
947
+ if do_classifier_free_guidance
948
+ else conditioning_mask
949
+ )
950
 
951
  # 5. Prepare timesteps
952
  retrieve_timesteps_kwargs = {}
953
  if isinstance(self.scheduler, TimestepShifter):
954
  retrieve_timesteps_kwargs["samples"] = latents
955
  timesteps, num_inference_steps = retrieve_timesteps(
956
+ self.scheduler,
957
+ num_inference_steps,
958
+ device,
959
+ timesteps,
960
+ **retrieve_timesteps_kwargs,
961
  )
962
 
963
  # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
964
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
965
 
966
  # 7. Denoising loop
967
+ num_warmup_steps = max(
968
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
969
+ )
970
 
971
  with self.progress_bar(total=num_inference_steps) as progress_bar:
972
  for i, t in enumerate(timesteps):
973
+ latent_model_input = (
974
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
975
+ )
976
+ latent_model_input = self.scheduler.scale_model_input(
977
+ latent_model_input, t
978
+ )
979
 
980
  latent_frame_rates = (
981
+ torch.ones(
982
+ latent_model_input.shape[0], 1, device=latent_model_input.device
983
+ )
984
+ * latent_frame_rate
985
  )
986
 
987
  current_timestep = t
 
993
  dtype = torch.float32 if is_mps else torch.float64
994
  else:
995
  dtype = torch.int32 if is_mps else torch.int64
996
+ current_timestep = torch.tensor(
997
+ [current_timestep],
998
+ dtype=dtype,
999
+ device=latent_model_input.device,
1000
+ )
1001
  elif len(current_timestep.shape) == 0:
1002
+ current_timestep = current_timestep[None].to(
1003
+ latent_model_input.device
1004
+ )
1005
  # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1006
+ current_timestep = current_timestep.expand(
1007
+ latent_model_input.shape[0]
1008
+ ).unsqueeze(-1)
1009
  scale_grid = (
1010
+ (
1011
+ 1 / latent_frame_rates,
1012
+ self.vae_scale_factor,
1013
+ self.vae_scale_factor,
1014
+ )
1015
  if self.transformer.use_rope
1016
  else None
1017
  )
 
1040
  # perform guidance
1041
  if do_classifier_free_guidance:
1042
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1043
+ noise_pred = noise_pred_uncond + guidance_scale * (
1044
+ noise_pred_text - noise_pred_uncond
1045
+ )
1046
  current_timestep, _ = current_timestep.chunk(2)
1047
 
1048
  # learned sigma
1049
+ if (
1050
+ self.transformer.config.out_channels // 2
1051
+ == self.transformer.config.in_channels
1052
+ ):
1053
  noise_pred = noise_pred.chunk(2, dim=1)[0]
1054
 
1055
  # compute previous image: x_t -> x_t-1
 
1062
  )[0]
1063
 
1064
  # call the callback, if provided
1065
+ if i == len(timesteps) - 1 or (
1066
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
1067
+ ):
1068
  progress_bar.update()
1069
 
1070
  if callback_on_step_end is not None:
 
1075
  output_height=latent_height,
1076
  output_width=latent_width,
1077
  output_num_frames=latent_num_frames,
1078
+ out_channels=self.transformer.in_channels
1079
+ // math.prod(self.patchifier.patch_size),
1080
  )
1081
  if output_type != "latent":
1082
  image = vae_decode(
1083
+ latents,
1084
+ self.vae,
1085
+ is_video,
1086
+ vae_per_channel_normalize=kwargs["vae_per_channel_normalize"],
1087
  )
1088
  image = self.image_processor.postprocess(image, output_type=output_type)
1089
 
 
1136
  vae_per_channel_normalize=vae_per_channel_normalize,
1137
  ).float()
1138
 
1139
+ init_len, target_len = (
1140
+ init_latents.shape[2],
1141
+ num_frames // self.video_scale_factor,
1142
+ )
1143
  if isinstance(self.vae, CausalVideoAutoencoder):
1144
  target_len += 1
1145
  init_latents = init_latents[:, :, :target_len]
1146
  if target_len > init_len:
1147
  repeat_factor = (target_len + init_len - 1) // init_len # Ceiling division
1148
+ init_latents = init_latents.repeat(1, 1, repeat_factor, 1, 1)[
1149
+ :, :, :target_len
1150
+ ]
1151
 
1152
  # Prepare the conditioning mask (1.0 = condition on this token)
1153
  b, n, f, h, w = init_latents.shape
1154
  conditioning_mask = torch.zeros([b, 1, f, h, w], device=init_latents.device)
1155
+ if method in [
1156
+ ConditioningMethod.FIRST_FRAME,
1157
+ ConditioningMethod.FIRST_AND_LAST_FRAME,
1158
+ ]:
1159
  conditioning_mask[:, :, 0] = 1.0
1160
+ if method in [
1161
+ ConditioningMethod.LAST_FRAME,
1162
+ ConditioningMethod.FIRST_AND_LAST_FRAME,
1163
+ ]:
1164
  conditioning_mask[:, :, -1] = 1.0
1165
 
1166
  # Patchify the init latents and the mask
xora/schedulers/rf.py CHANGED
@@ -22,7 +22,9 @@ def simple_diffusion_resolution_dependent_timestep_shift(
22
  elif len(samples.shape) in [4, 5]:
23
  m = math.prod(samples.shape[2:])
24
  else:
25
- raise ValueError("Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)")
 
 
26
  snr = (timesteps / (1 - timesteps)) ** 2
27
  shift_snr = torch.log(snr) + 2 * math.log(m / n)
28
  shifted_timesteps = torch.sigmoid(0.5 * shift_snr)
@@ -46,7 +48,9 @@ def get_normal_shift(
46
  return m * n_tokens + b
47
 
48
 
49
- def sd3_resolution_dependent_timestep_shift(samples: Tensor, timesteps: Tensor) -> Tensor:
 
 
50
  """
51
  Shifts the timestep schedule as a function of the generated resolution.
52
 
@@ -70,7 +74,9 @@ def sd3_resolution_dependent_timestep_shift(samples: Tensor, timesteps: Tensor)
70
  elif len(samples.shape) in [4, 5]:
71
  m = math.prod(samples.shape[2:])
72
  else:
73
- raise ValueError("Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)")
 
 
74
 
75
  shift = get_normal_shift(m)
76
  return time_shift(shift, 1, timesteps)
@@ -104,12 +110,21 @@ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
104
  order = 1
105
 
106
  @register_to_config
107
- def __init__(self, num_train_timesteps=1000, shifting: Optional[str] = None, base_resolution: int = 32**2):
 
 
 
 
 
108
  super().__init__()
109
  self.init_noise_sigma = 1.0
110
  self.num_inference_steps = None
111
- self.timesteps = self.sigmas = torch.linspace(1, 1 / num_train_timesteps, num_train_timesteps)
112
- self.delta_timesteps = self.timesteps - torch.cat([self.timesteps[1:], torch.zeros_like(self.timesteps[-1:])])
 
 
 
 
113
  self.shifting = shifting
114
  self.base_resolution = base_resolution
115
 
@@ -117,10 +132,17 @@ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
117
  if self.shifting == "SD3":
118
  return sd3_resolution_dependent_timestep_shift(samples, timesteps)
119
  elif self.shifting == "SimpleDiffusion":
120
- return simple_diffusion_resolution_dependent_timestep_shift(samples, timesteps, self.base_resolution)
 
 
121
  return timesteps
122
 
123
- def set_timesteps(self, num_inference_steps: int, samples: Tensor, device: Union[str, torch.device] = None):
 
 
 
 
 
124
  """
125
  Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
126
 
@@ -130,13 +152,19 @@ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
130
  device (`Union[str, torch.device]`, *optional*): The device to which the timesteps tensor will be moved.
131
  """
132
  num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
133
- timesteps = torch.linspace(1, 1 / num_inference_steps, num_inference_steps).to(device)
 
 
134
  self.timesteps = self.shift_timesteps(samples, timesteps)
135
- self.delta_timesteps = self.timesteps - torch.cat([self.timesteps[1:], torch.zeros_like(self.timesteps[-1:])])
 
 
136
  self.num_inference_steps = num_inference_steps
137
  self.sigmas = self.timesteps
138
 
139
- def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
 
 
140
  # pylint: disable=unused-argument
141
  """
142
  Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
@@ -206,7 +234,9 @@ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
206
  else:
207
  # Timestep per token
208
  assert timestep.ndim == 2
209
- current_index = (self.timesteps[:, None, None] - timestep[None]).abs().argmin(dim=0)
 
 
210
  dt = self.delta_timesteps[current_index]
211
  # Special treatment for zero timestep tokens - set dt to 0 so prev_sample = sample
212
  dt = torch.where(timestep == 0.0, torch.zeros_like(dt), dt)[..., None]
@@ -228,4 +258,4 @@ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
228
  sigmas = append_dims(sigmas, original_samples.ndim)
229
  alphas = 1 - sigmas
230
  noisy_samples = alphas * original_samples + sigmas * noise
231
- return noisy_samples
 
22
  elif len(samples.shape) in [4, 5]:
23
  m = math.prod(samples.shape[2:])
24
  else:
25
+ raise ValueError(
26
+ "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)"
27
+ )
28
  snr = (timesteps / (1 - timesteps)) ** 2
29
  shift_snr = torch.log(snr) + 2 * math.log(m / n)
30
  shifted_timesteps = torch.sigmoid(0.5 * shift_snr)
 
48
  return m * n_tokens + b
49
 
50
 
51
+ def sd3_resolution_dependent_timestep_shift(
52
+ samples: Tensor, timesteps: Tensor
53
+ ) -> Tensor:
54
  """
55
  Shifts the timestep schedule as a function of the generated resolution.
56
 
 
74
  elif len(samples.shape) in [4, 5]:
75
  m = math.prod(samples.shape[2:])
76
  else:
77
+ raise ValueError(
78
+ "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)"
79
+ )
80
 
81
  shift = get_normal_shift(m)
82
  return time_shift(shift, 1, timesteps)
 
110
  order = 1
111
 
112
  @register_to_config
113
+ def __init__(
114
+ self,
115
+ num_train_timesteps=1000,
116
+ shifting: Optional[str] = None,
117
+ base_resolution: int = 32**2,
118
+ ):
119
  super().__init__()
120
  self.init_noise_sigma = 1.0
121
  self.num_inference_steps = None
122
+ self.timesteps = self.sigmas = torch.linspace(
123
+ 1, 1 / num_train_timesteps, num_train_timesteps
124
+ )
125
+ self.delta_timesteps = self.timesteps - torch.cat(
126
+ [self.timesteps[1:], torch.zeros_like(self.timesteps[-1:])]
127
+ )
128
  self.shifting = shifting
129
  self.base_resolution = base_resolution
130
 
 
132
  if self.shifting == "SD3":
133
  return sd3_resolution_dependent_timestep_shift(samples, timesteps)
134
  elif self.shifting == "SimpleDiffusion":
135
+ return simple_diffusion_resolution_dependent_timestep_shift(
136
+ samples, timesteps, self.base_resolution
137
+ )
138
  return timesteps
139
 
140
+ def set_timesteps(
141
+ self,
142
+ num_inference_steps: int,
143
+ samples: Tensor,
144
+ device: Union[str, torch.device] = None,
145
+ ):
146
  """
147
  Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
148
 
 
152
  device (`Union[str, torch.device]`, *optional*): The device to which the timesteps tensor will be moved.
153
  """
154
  num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
155
+ timesteps = torch.linspace(1, 1 / num_inference_steps, num_inference_steps).to(
156
+ device
157
+ )
158
  self.timesteps = self.shift_timesteps(samples, timesteps)
159
+ self.delta_timesteps = self.timesteps - torch.cat(
160
+ [self.timesteps[1:], torch.zeros_like(self.timesteps[-1:])]
161
+ )
162
  self.num_inference_steps = num_inference_steps
163
  self.sigmas = self.timesteps
164
 
165
+ def scale_model_input(
166
+ self, sample: torch.FloatTensor, timestep: Optional[int] = None
167
+ ) -> torch.FloatTensor:
168
  # pylint: disable=unused-argument
169
  """
170
  Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
 
234
  else:
235
  # Timestep per token
236
  assert timestep.ndim == 2
237
+ current_index = (
238
+ (self.timesteps[:, None, None] - timestep[None]).abs().argmin(dim=0)
239
+ )
240
  dt = self.delta_timesteps[current_index]
241
  # Special treatment for zero timestep tokens - set dt to 0 so prev_sample = sample
242
  dt = torch.where(timestep == 0.0, torch.zeros_like(dt), dt)[..., None]
 
258
  sigmas = append_dims(sigmas, original_samples.ndim)
259
  alphas = 1 - sigmas
260
  noisy_samples = alphas * original_samples + sigmas * noise
261
+ return noisy_samples
xora/utils/conditioning_method.py CHANGED
@@ -1,7 +1,8 @@
1
  from enum import Enum
2
 
 
3
  class ConditioningMethod(Enum):
4
  UNCONDITIONAL = "unconditional"
5
  FIRST_FRAME = "first_frame"
6
  LAST_FRAME = "last_frame"
7
- FIRST_AND_LAST_FRAME = "first_and_last_frame"
 
1
  from enum import Enum
2
 
3
+
4
  class ConditioningMethod(Enum):
5
  UNCONDITIONAL = "unconditional"
6
  FIRST_FRAME = "first_frame"
7
  LAST_FRAME = "last_frame"
8
+ FIRST_AND_LAST_FRAME = "first_and_last_frame"
xora/utils/torch_utils.py CHANGED
@@ -1,15 +1,19 @@
1
  import torch
2
  from torch import nn
3
 
 
4
  def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
5
  """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
6
  dims_to_append = target_dims - x.ndim
7
  if dims_to_append < 0:
8
- raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
 
 
9
  elif dims_to_append == 0:
10
  return x
11
  return x[(...,) + (None,) * dims_to_append]
12
 
 
13
  class Identity(nn.Module):
14
  """A placeholder identity operator that is argument-insensitive."""
15
 
 
1
  import torch
2
  from torch import nn
3
 
4
+
5
  def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
6
  """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
7
  dims_to_append = target_dims - x.ndim
8
  if dims_to_append < 0:
9
+ raise ValueError(
10
+ f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
11
+ )
12
  elif dims_to_append == 0:
13
  return x
14
  return x[(...,) + (None,) * dims_to_append]
15
 
16
+
17
  class Identity(nn.Module):
18
  """A placeholder identity operator that is argument-insensitive."""
19