YulianSa committed
Commit ef198e0 · 1 Parent(s): 51a80da

Add application file

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the complete change set.
Files changed (50)
  1. .gitignore +4 -0
  2. app.py +139 -0
  3. blender/blender_lrm_script.py +1387 -0
  4. blender/distributed_uniform_lrm.py +122 -0
  5. blender/install_addon.py +15 -0
  6. canonicalize/__init__.py +0 -0
  7. canonicalize/models/attention.py +344 -0
  8. canonicalize/models/imageproj.py +118 -0
  9. canonicalize/models/refunet.py +127 -0
  10. canonicalize/models/resnet.py +209 -0
  11. canonicalize/models/transformer_mv2d.py +976 -0
  12. canonicalize/models/unet.py +475 -0
  13. canonicalize/models/unet_blocks.py +596 -0
  14. canonicalize/models/unet_mv2d_blocks.py +924 -0
  15. canonicalize/models/unet_mv2d_condition.py +1502 -0
  16. canonicalize/models/unet_mv2d_ref.py +1543 -0
  17. canonicalize/pipeline_canonicalize.py +518 -0
  18. canonicalize/util.py +128 -0
  19. configs/canonicalization-infer.yaml +22 -0
  20. configs/mesh-slrm-infer.yaml +25 -0
  21. data/test_list.json +111 -0
  22. data/train_list.json +0 -0
  23. infer_api.py +881 -0
  24. infer_canonicalize.py +215 -0
  25. infer_multiview.py +274 -0
  26. infer_refine.py +353 -0
  27. infer_slrm.py +199 -0
  28. input_cases/1.png +0 -0
  29. input_cases/2.png +0 -0
  30. input_cases/3.png +0 -0
  31. input_cases/4.png +0 -0
  32. input_cases/ayaka.png +0 -0
  33. input_cases/firefly2.png +0 -0
  34. input_cases_apose/1.png +0 -0
  35. input_cases_apose/2.png +0 -0
  36. input_cases_apose/3.png +0 -0
  37. input_cases_apose/4.png +0 -0
  38. input_cases_apose/ayaka.png +0 -0
  39. input_cases_apose/belle.png +0 -0
  40. input_cases_apose/firefly.png +0 -0
  41. multiview/__init__.py +0 -0
  42. multiview/fixed_prompt_embeds_6view/clr_embeds.pt +3 -0
  43. multiview/fixed_prompt_embeds_6view/normal_embeds.pt +3 -0
  44. multiview/models/transformer_mv2d_image.py +995 -0
  45. multiview/models/transformer_mv2d_rowwise.py +972 -0
  46. multiview/models/transformer_mv2d_self_rowwise.py +1042 -0
  47. multiview/models/unet_mv2d_blocks.py +980 -0
  48. multiview/models/unet_mv2d_condition.py +1685 -0
  49. multiview/pipeline_multiclass.py +656 -0
  50. refine/func.py +427 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
+ ckpt
+ result
+ **/__pycache__/
+ **/.DS_Store
app.py ADDED
@@ -0,0 +1,139 @@
+ import gradio as gr
+ import numpy as np
+ import glob
+ import torch
+ import random
+ from tempfile import NamedTemporaryFile
+ from infer_api import InferAPI
+ from PIL import Image
+
+ config_canocalize = {
+     'config_path': './configs/canonicalization-infer.yaml',
+ }
+ config_multiview = {}
+ config_slrm = {
+     'config_path': './configs/mesh-slrm-infer.yaml'
+ }
+ config_refine = {}
+
+ EXAMPLE_IMAGES = glob.glob("./input_cases/*")
+ EXAMPLE_APOSE_IMAGES = glob.glob("./input_cases_apose/*")
+
+ infer_api = InferAPI(config_canocalize, config_multiview, config_slrm, config_refine)
+
+ REMINDER = """
+ ### Reminder:
+ 1. **Reference Image**:
+    - You can upload any reference image (with or without background).
+    - If the image has an alpha channel (transparency), background segmentation will be performed automatically.
+    - Alternatively, you can pre-segment the background with other tools and upload the result directly.
+    - A-pose images are also supported.
+
+ 2. Real-person images generally work well, but note that normals may appear smoother than expected. You can try other monocular normal estimation models.
+
+ 3. The base human model in the output is uncolored due to potential NSFW concerns. If you need colored results, please refer to the official GitHub repository for instructions.
+ """
+
+ # Placeholder demo functions - replace with the actual models
+ def arbitrary_to_apose(image, seed):
+     # convert image to PIL.Image
+     image = Image.fromarray(image)
+     return infer_api.genStage1(image, seed)
+
+ def apose_to_multiview(apose_img, seed):
+     # convert image to PIL.Image
+     apose_img = Image.fromarray(apose_img)
+     return infer_api.genStage2(apose_img, seed, num_levels=1)[0]["images"]
+
+ def multiview_to_mesh(images):
+     mesh_files = infer_api.genStage3(images)
+     return mesh_files
+
+ def refine_mesh(apose_img, mesh1, mesh2, mesh3, seed):
+     apose_img = Image.fromarray(apose_img)
+     infer_api.genStage2(apose_img, seed, num_levels=2)
+     print(infer_api.multiview_infer.results.keys())
+     refined = infer_api.genStage4([mesh1, mesh2, mesh3], infer_api.multiview_infer.results)
+     return refined
+
+ with gr.Blocks(title="StdGEN: Semantically Decomposed 3D Character Generation from Single Images") as demo:
+     gr.Markdown(REMINDER)
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown("## 1. Reference Image to A-pose Image")
+             input_image = gr.Image(label="Input Reference Image", type="numpy", width=384, height=384)
+             gr.Examples(
+                 examples=EXAMPLE_IMAGES,
+                 inputs=input_image,
+                 label="Click to use sample images",
+             )
+             seed_input = gr.Number(
+                 label="Seed",
+                 value=42,
+                 precision=0,
+                 interactive=True
+             )
+             pose_btn = gr.Button("Convert")
+         with gr.Column():
+             gr.Markdown("## 2. Multi-view Generation")
+             a_pose_image = gr.Image(label="A-pose Result", type="numpy", width=384, height=384)
+             gr.Examples(
+                 examples=EXAMPLE_APOSE_IMAGES,
+                 inputs=a_pose_image,
+                 label="Click to use sample A-pose images",
+             )
+             seed_input2 = gr.Number(
+                 label="Seed",
+                 value=42,
+                 precision=0,
+                 interactive=True
+             )
+             view_btn = gr.Button("Generate Multi-view Images")
+
+         with gr.Column():
+             gr.Markdown("## 3. Semantic-aware Reconstruction")
+             multiview_gallery = gr.Gallery(
+                 label="Multi-view results",
+                 columns=2,
+                 interactive=False,
+                 height=None
+             )
+             mesh_btn = gr.Button("Reconstruct")
+
+     with gr.Row():
+         mesh_cols = [gr.Model3D(label=f"Mesh {i+1}", interactive=False, height=384) for i in range(3)]
+         full_mesh = gr.Model3D(label="Whole Mesh", height=384)
+     refine_btn = gr.Button("Refine")
+
+     gr.Markdown("## 4. Mesh Refinement")
+     with gr.Row():
+         refined_meshes = [gr.Model3D(label=f"Refined Mesh {i+1}", height=384) for i in range(3)]
+         refined_full_mesh = gr.Model3D(label="Refined Whole Mesh", height=384)
+
+     # Interaction logic
+     pose_btn.click(
+         arbitrary_to_apose,
+         inputs=[input_image, seed_input],
+         outputs=a_pose_image
+     )
+
+     view_btn.click(
+         apose_to_multiview,
+         inputs=[a_pose_image, seed_input2],
+         outputs=multiview_gallery
+     )
+
+     mesh_btn.click(
+         multiview_to_mesh,
+         inputs=multiview_gallery,
+         outputs=[*mesh_cols, full_mesh]
+     )
+
+     refine_btn.click(
+         refine_mesh,
+         inputs=[a_pose_image, *mesh_cols, seed_input2],
+         outputs=[refined_meshes[2], refined_meshes[0], refined_meshes[1], refined_full_mesh]
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
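The Blocks above wire a four-stage flow onto InferAPI: A-pose canonicalization, multi-view generation, semantic-aware reconstruction, and refinement. The snippet below is a minimal, untested sketch of driving the same callbacks without the UI; it assumes that importing app only constructs InferAPI (demo.launch() sits behind the __main__ guard), that genStage1 returns an image np.array() accepts, and that genStage2's image list is a valid input for genStage3 (the Gallery normally sits in between).

```python
# Headless sketch reusing the callbacks defined in app.py (assumptions noted above).
import numpy as np
from PIL import Image

from app import arbitrary_to_apose, apose_to_multiview, multiview_to_mesh, refine_mesh

seed = 42
ref = np.array(Image.open("./input_cases/1.png"))          # gr.Image(type="numpy") passes arrays like this

apose = arbitrary_to_apose(ref, seed)                      # stage 1: canonicalize to A-pose
views = apose_to_multiview(np.array(apose), seed)          # stage 2: multi-view images
meshes = multiview_to_mesh(views)                          # stage 3: decomposed meshes (+ whole mesh)
refined = refine_mesh(np.array(apose), *meshes[:3], seed)  # stage 4: refinement
```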
blender/blender_lrm_script.py ADDED
@@ -0,0 +1,1387 @@
1
+ """Blender script to render images of 3D models."""
2
+
3
+ import argparse
4
+ import json
5
+ import math
6
+ import os
7
+ import random
8
+ import sys
9
+ from typing import Any, Callable, Dict, Generator, List, Literal, Optional, Set, Tuple
10
+
11
+ import bpy
12
+ import numpy as np
13
+ from mathutils import Matrix, Vector
14
+ import pdb
15
+ MAX_DEPTH = 5.0
16
+ import shutil
17
+ IMPORT_FUNCTIONS: Dict[str, Callable] = {
18
+ "obj": bpy.ops.import_scene.obj,
19
+ "glb": bpy.ops.import_scene.gltf,
20
+ "gltf": bpy.ops.import_scene.gltf,
21
+ "usd": bpy.ops.import_scene.usd,
22
+ "fbx": bpy.ops.import_scene.fbx,
23
+ "stl": bpy.ops.import_mesh.stl,
24
+ "usda": bpy.ops.import_scene.usda,
25
+ "dae": bpy.ops.wm.collada_import,
26
+ "ply": bpy.ops.import_mesh.ply,
27
+ "abc": bpy.ops.wm.alembic_import,
28
+ "blend": bpy.ops.wm.append,
29
+ "vrm": bpy.ops.import_scene.vrm,
30
+ }
31
+
32
+ configs = {
33
+ "custom2": {"camera_pose": "z-circular-elevated", 'elevation_range': [0,0], "rotate": 0.0},
34
+ "custom_top": {"camera_pose": "z-circular-elevated", 'elevation_range': [90,90], "rotate": 0.0, "render_num": 1},
35
+ "custom_bottom": {"camera_pose": "z-circular-elevated", 'elevation_range': [-90,-90], "rotate": 0.0, "render_num": 1},
36
+ "custom_face": {"camera_pose": "z-circular-elevated", 'elevation_range': [0,0], "rotate": 0.0, "render_num": 8},
37
+ "random": {"camera_pose": "random", 'elevation_range': [-90,90], "rotate": 0.0, "render_num": 20},
38
+ }
39
+
40
+
41
+ def reset_cameras() -> None:
42
+ """Resets the cameras in the scene to a single default camera."""
43
+ # Delete all existing cameras
44
+ bpy.ops.object.select_all(action="DESELECT")
45
+ bpy.ops.object.select_by_type(type="CAMERA")
46
+ bpy.ops.object.delete()
47
+
48
+ # Create a new camera with default properties
49
+ bpy.ops.object.camera_add()
50
+
51
+ # Rename the new camera to 'NewDefaultCamera'
52
+ new_camera = bpy.context.active_object
53
+ new_camera.name = "Camera"
54
+
55
+ # Set the new camera as the active camera for the scene
56
+ scene.camera = new_camera
57
+
58
+
59
+ def _sample_spherical(
60
+ radius_min: float = 1.5,
61
+ radius_max: float = 2.0,
62
+ maxz: float = 1.6,
63
+ minz: float = -0.75,
64
+ ) -> np.ndarray:
65
+ """Sample a random point in a spherical shell.
66
+
67
+ Args:
68
+ radius_min (float): Minimum radius of the spherical shell.
69
+ radius_max (float): Maximum radius of the spherical shell.
70
+ maxz (float): Maximum z value of the spherical shell.
71
+ minz (float): Minimum z value of the spherical shell.
72
+
73
+ Returns:
74
+ np.ndarray: A random (x, y, z) point in the spherical shell.
75
+ """
76
+ correct = False
77
+ vec = np.array([0, 0, 0])
78
+ while not correct:
79
+ vec = np.random.uniform(-1, 1, 3)
80
+ # vec[2] = np.abs(vec[2])
81
+ radius = np.random.uniform(radius_min, radius_max, 1)
82
+ vec = vec / np.linalg.norm(vec, axis=0) * radius[0]
83
+ if maxz > vec[2] > minz:
84
+ correct = True
85
+ return vec
86
+
87
+
88
+ def randomize_camera(
89
+ radius_min: float = 1.5,
90
+ radius_max: float = 2.2,
91
+ maxz: float = 2.2,
92
+ minz: float = -2.2,
93
+ only_northern_hemisphere: bool = False,
94
+ ) -> bpy.types.Object:
95
+ """Randomizes the camera location and rotation inside of a spherical shell.
96
+
97
+ Args:
98
+ radius_min (float, optional): Minimum radius of the spherical shell. Defaults to
99
+ 1.5.
100
+ radius_max (float, optional): Maximum radius of the spherical shell. Defaults to
101
+ 2.2.
102
+ maxz (float, optional): Maximum z value of the spherical shell. Defaults to 2.2.
103
+ minz (float, optional): Minimum z value of the spherical shell. Defaults to
104
+ -2.2.
105
+ only_northern_hemisphere (bool, optional): Whether to only sample points in the
106
+ northern hemisphere. Defaults to False.
107
+
108
+ Returns:
109
+ bpy.types.Object: The camera object.
110
+ """
111
+
112
+ x, y, z = _sample_spherical(
113
+ radius_min=radius_min, radius_max=radius_max, maxz=maxz, minz=minz
114
+ )
115
+ camera = bpy.data.objects["Camera"]
116
+
117
+ # only positive z
118
+ if only_northern_hemisphere:
119
+ z = abs(z)
120
+
121
+ camera.location = Vector(np.array([x, y, z]))
122
+
123
+ direction = -camera.location
124
+ rot_quat = direction.to_track_quat("-Z", "Y")
125
+ camera.rotation_euler = rot_quat.to_euler()
126
+
127
+ return camera
128
+
129
+
130
+ cached_cameras = []
131
+
132
+ def randomize_camera_with_cache(
133
+ radius_min: float = 1.5,
134
+ radius_max: float = 2.2,
135
+ maxz: float = 2.2,
136
+ minz: float = -2.2,
137
+ only_northern_hemisphere: bool = False,
138
+ idx: int = 0,
139
+ ) -> bpy.types.Object:
140
+
141
+ assert len(cached_cameras) >= idx
142
+
143
+ if len(cached_cameras) == idx:
144
+ x, y, z = _sample_spherical(
145
+ radius_min=radius_min, radius_max=radius_max, maxz=maxz, minz=minz
146
+ )
147
+ cached_cameras.append((x, y, z))
148
+ else:
149
+ x, y, z = cached_cameras[idx]
150
+
151
+ camera = bpy.data.objects["Camera"]
152
+
153
+ # only positive z
154
+ if only_northern_hemisphere:
155
+ z = abs(z)
156
+
157
+ camera.location = Vector(np.array([x, y, z]))
158
+
159
+ direction = -camera.location
160
+ rot_quat = direction.to_track_quat("-Z", "Y")
161
+ camera.rotation_euler = rot_quat.to_euler()
162
+
163
+ return camera
164
+
165
+
166
+ def set_camera(direction, camera_dist=2.0, camera_offset=0.0):
167
+ camera = bpy.data.objects["Camera"]
168
+ camera_pos = -camera_dist * direction
169
+ if type(camera_offset) == float:
170
+ camera_offset = Vector(np.array([0., 0., 0.]))
171
+ camera_pos += camera_offset
172
+ camera.location = camera_pos
173
+
174
+ # https://blender.stackexchange.com/questions/5210/pointing-the-camera-in-a-particular-direction-programmatically
175
+ rot_quat = direction.to_track_quat("-Z", "Y")
176
+ camera.rotation_euler = rot_quat.to_euler()
177
+ return camera
178
+
179
+
180
+ def _set_camera_at_size(i: int, scale: float = 1.5) -> bpy.types.Object:
181
+ """Debugging function to set the camera on the 6 faces of a cube.
182
+
183
+ Args:
184
+ i (int): Index of the face of the cube.
185
+ scale (float, optional): Scale of the cube. Defaults to 1.5.
186
+
187
+ Returns:
188
+ bpy.types.Object: The camera object.
189
+ """
190
+ if i == 0:
191
+ x, y, z = scale, 0, 0
192
+ elif i == 1:
193
+ x, y, z = -scale, 0, 0
194
+ elif i == 2:
195
+ x, y, z = 0, scale, 0
196
+ elif i == 3:
197
+ x, y, z = 0, -scale, 0
198
+ elif i == 4:
199
+ x, y, z = 0, 0, scale
200
+ elif i == 5:
201
+ x, y, z = 0, 0, -scale
202
+ else:
203
+ raise ValueError(f"Invalid index: i={i}, must be int in range [0, 5].")
204
+ camera = bpy.data.objects["Camera"]
205
+ camera.location = Vector(np.array([x, y, z]))
206
+ direction = -camera.location
207
+ rot_quat = direction.to_track_quat("-Z", "Y")
208
+ camera.rotation_euler = rot_quat.to_euler()
209
+ return camera
210
+
211
+
212
+ def _create_light(
213
+ name: str,
214
+ light_type: Literal["POINT", "SUN", "SPOT", "AREA"],
215
+ location: Tuple[float, float, float],
216
+ rotation: Tuple[float, float, float],
217
+ energy: float,
218
+ use_shadow: bool = False,
219
+ specular_factor: float = 1.0,
220
+ ):
221
+ """Creates a light object.
222
+
223
+ Args:
224
+ name (str): Name of the light object.
225
+ light_type (Literal["POINT", "SUN", "SPOT", "AREA"]): Type of the light.
226
+ location (Tuple[float, float, float]): Location of the light.
227
+ rotation (Tuple[float, float, float]): Rotation of the light.
228
+ energy (float): Energy of the light.
229
+ use_shadow (bool, optional): Whether to use shadows. Defaults to False.
230
+ specular_factor (float, optional): Specular factor of the light. Defaults to 1.0.
231
+
232
+ Returns:
233
+ bpy.types.Object: The light object.
234
+ """
235
+
236
+ light_data = bpy.data.lights.new(name=name, type=light_type)
237
+ light_object = bpy.data.objects.new(name, light_data)
238
+ bpy.context.collection.objects.link(light_object)
239
+ light_object.location = location
240
+ light_object.rotation_euler = rotation
241
+ light_data.use_shadow = use_shadow
242
+ light_data.specular_factor = specular_factor
243
+ light_data.energy = energy
244
+ return light_object
245
+
246
+
247
+ def reset_scene() -> None:
248
+ """Resets the scene to a clean state.
249
+
250
+ Returns:
251
+ None
252
+ """
253
+ # delete everything that isn't part of a camera or a light
254
+ for obj in bpy.data.objects:
255
+ if obj.type not in {"CAMERA", "LIGHT"}:
256
+ bpy.data.objects.remove(obj, do_unlink=True)
257
+
258
+ # delete all the materials
259
+ for material in bpy.data.materials:
260
+ bpy.data.materials.remove(material, do_unlink=True)
261
+
262
+ # delete all the textures
263
+ for texture in bpy.data.textures:
264
+ bpy.data.textures.remove(texture, do_unlink=True)
265
+
266
+ # delete all the images
267
+ for image in bpy.data.images:
268
+ bpy.data.images.remove(image, do_unlink=True)
269
+
270
+ # delete all the collider collections
271
+ for collider in bpy.data.collections:
272
+ if collider.name != "Collection":
273
+ bpy.data.collections.remove(collider, do_unlink=True)
274
+
275
+
276
+ def load_object(object_path: str) -> None:
277
+ """Loads a model with a supported file extension into the scene.
278
+
279
+ Args:
280
+ object_path (str): Path to the model file.
281
+
282
+ Raises:
283
+ ValueError: If the file extension is not supported.
284
+
285
+ Returns:
286
+ None
287
+ """
288
+ file_extension = object_path.split(".")[-1].lower()
289
+ if file_extension is None:
290
+ raise ValueError(f"Unsupported file type: {object_path}")
291
+
292
+ if file_extension == "usdz":
293
+ # install usdz io package
294
+ dirname = os.path.dirname(os.path.realpath(__file__))
295
+ usdz_package = os.path.join(dirname, "io_scene_usdz.zip")
296
+ bpy.ops.preferences.addon_install(filepath=usdz_package)
297
+ # enable it
298
+ addon_name = "io_scene_usdz"
299
+ bpy.ops.preferences.addon_enable(module=addon_name)
300
+ # import the usdz
301
+ from io_scene_usdz.import_usdz import import_usdz
302
+
303
+ import_usdz(context, filepath=object_path, materials=True, animations=True)
304
+ return None
305
+
306
+ # load from existing import functions
307
+ import_function = IMPORT_FUNCTIONS[file_extension]
308
+
309
+ if file_extension == "blend":
310
+ import_function(directory=object_path, link=False)
311
+ elif file_extension in {"glb", "gltf"}:
312
+ import_function(filepath=object_path, merge_vertices=True)
313
+ else:
314
+ import_function(filepath=object_path)
315
+
316
+
317
+ def scene_bbox(
318
+ single_obj: Optional[bpy.types.Object] = None, ignore_matrix: bool = False
319
+ ) -> Tuple[Vector, Vector]:
320
+ """Returns the bounding box of the scene.
321
+
322
+ Taken from Shap-E rendering script
323
+ (https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82)
324
+
325
+ Args:
326
+ single_obj (Optional[bpy.types.Object], optional): If not None, only computes
327
+ the bounding box for the given object. Defaults to None.
328
+ ignore_matrix (bool, optional): Whether to ignore the object's matrix. Defaults
329
+ to False.
330
+
331
+ Raises:
332
+ RuntimeError: If there are no objects in the scene.
333
+
334
+ Returns:
335
+ Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box.
336
+ """
337
+ bbox_min = (math.inf,) * 3
338
+ bbox_max = (-math.inf,) * 3
339
+ found = False
340
+ for obj in get_scene_meshes() if single_obj is None else [single_obj]:
341
+ found = True
342
+ for coord in obj.bound_box:
343
+ coord = Vector(coord)
344
+ if not ignore_matrix:
345
+ coord = obj.matrix_world @ coord
346
+ bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
347
+ bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
348
+
349
+ if not found:
350
+ raise RuntimeError("no objects in scene to compute bounding box for")
351
+
352
+ return Vector(bbox_min), Vector(bbox_max)
353
+
354
+
355
+ def get_scene_root_objects() -> Generator[bpy.types.Object, None, None]:
356
+ """Returns all root objects in the scene.
357
+
358
+ Yields:
359
+ Generator[bpy.types.Object, None, None]: Generator of all root objects in the
360
+ scene.
361
+ """
362
+ for obj in bpy.context.scene.objects.values():
363
+ if not obj.parent:
364
+ yield obj
365
+
366
+
367
+ def get_scene_meshes() -> Generator[bpy.types.Object, None, None]:
368
+ """Returns all meshes in the scene.
369
+
370
+ Yields:
371
+ Generator[bpy.types.Object, None, None]: Generator of all meshes in the scene.
372
+ """
373
+ for obj in bpy.context.scene.objects.values():
374
+ if isinstance(obj.data, (bpy.types.Mesh)):
375
+ yield obj
376
+
377
+
378
+ def get_3x4_RT_matrix_from_blender(cam: bpy.types.Object) -> Matrix:
379
+ """Returns the 3x4 RT matrix from the given camera.
380
+
381
+ Taken from Zero123, which in turn was taken from
382
+ https://github.com/panmari/stanford-shapenet-renderer/blob/master/render_blender.py
383
+
384
+ Args:
385
+ cam (bpy.types.Object): The camera object.
386
+
387
+ Returns:
388
+ Matrix: The 3x4 RT matrix from the given camera.
389
+ """
390
+ # Use matrix_world instead to account for all constraints
391
+ location, rotation = cam.matrix_world.decompose()[0:2]
392
+ R_world2bcam = rotation.to_matrix().transposed()
393
+
394
+ # Use location from matrix_world to account for constraints:
395
+ T_world2bcam = -1 * R_world2bcam @ location
396
+
397
+ # put into 3x4 matrix
398
+ RT = Matrix(
399
+ (
400
+ R_world2bcam[0][:] + (T_world2bcam[0],),
401
+ R_world2bcam[1][:] + (T_world2bcam[1],),
402
+ R_world2bcam[2][:] + (T_world2bcam[2],),
403
+ )
404
+ )
405
+ return RT
406
+
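In matrix form, the extrinsics assembled above are the standard world-to-camera transform: with camera-to-world rotation $R_c$ and camera center $C$,

```latex
R = R_c^{\top}, \qquad t = -R_c^{\top} C, \qquad
\mathrm{RT} = \left[\, R \mid t \,\right] \in \mathbb{R}^{3 \times 4}, \qquad
x_{\mathrm{cam}} = R\, x_{\mathrm{world}} + t .
```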
407
+
408
+ def delete_invisible_objects() -> None:
409
+ """Deletes all invisible objects in the scene.
410
+
411
+ Returns:
412
+ None
413
+ """
414
+ bpy.ops.object.select_all(action="DESELECT")
415
+ for obj in scene.objects:
416
+ if obj.hide_viewport or obj.hide_render:
417
+ obj.hide_viewport = False
418
+ obj.hide_render = False
419
+ obj.hide_select = False
420
+ obj.select_set(True)
421
+ bpy.ops.object.delete()
422
+
423
+ # Delete invisible collections
424
+ invisible_collections = [col for col in bpy.data.collections if col.hide_viewport]
425
+ for col in invisible_collections:
426
+ bpy.data.collections.remove(col)
427
+
428
+
429
+ def normalize_scene() -> None:
430
+ """Normalizes the scene by scaling and translating it to fit in a unit cube centered
431
+ at the origin.
432
+
433
+ Mostly taken from the Point-E / Shap-E rendering script
434
+ (https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112),
435
+ but fixed for multiple root objects (see the bug report here:
436
+ https://github.com/openai/shap-e/pull/60).
437
+
438
+ Returns:
439
+ None
440
+ """
441
+ if len(list(get_scene_root_objects())) > 1:
442
+ # create an empty object to be used as a parent for all root objects
443
+ parent_empty = bpy.data.objects.new("ParentEmpty", None)
444
+ bpy.context.scene.collection.objects.link(parent_empty)
445
+
446
+ # parent all root objects to the empty object
447
+ for obj in get_scene_root_objects():
448
+ if obj != parent_empty:
449
+ obj.parent = parent_empty
450
+
451
+ bbox_min, bbox_max = scene_bbox()
452
+ scale = 1 / max(bbox_max - bbox_min)
453
+ for obj in get_scene_root_objects():
454
+ obj.scale = obj.scale * scale
455
+
456
+ # Apply scale to matrix_world.
457
+ bpy.context.view_layer.update()
458
+ bbox_min, bbox_max = scene_bbox()
459
+ offset = -(bbox_min + bbox_max) / 2
460
+ for obj in get_scene_root_objects():
461
+ obj.matrix_world.translation += offset
462
+ bpy.ops.object.select_all(action="DESELECT")
463
+
464
+ # unparent the camera
465
+ bpy.data.objects["Camera"].parent = None
466
+
467
+
468
+ def delete_missing_textures() -> Dict[str, Any]:
469
+ """Deletes all missing textures in the scene.
470
+
471
+ Returns:
472
+ Dict[str, Any]: Dictionary with keys "count", "files", and "file_path_to_color".
473
+ "count" is the number of missing textures, "files" is a list of the missing
474
+ texture file paths, and "file_path_to_color" is a dictionary mapping the
475
+ missing texture file paths to a random color.
476
+ """
477
+ missing_file_count = 0
478
+ out_files = []
479
+ file_path_to_color = {}
480
+
481
+ # Check all materials in the scene
482
+ for material in bpy.data.materials:
483
+ if material.use_nodes:
484
+ for node in material.node_tree.nodes:
485
+ if node.type == "TEX_IMAGE":
486
+ image = node.image
487
+ if image is not None:
488
+ file_path = bpy.path.abspath(image.filepath)
489
+ if file_path == "":
490
+ # means it's embedded
491
+ continue
492
+
493
+ if not os.path.exists(file_path):
494
+ # Find the connected Principled BSDF node
495
+ connected_node = node.outputs[0].links[0].to_node
496
+
497
+ if connected_node.type == "BSDF_PRINCIPLED":
498
+ if file_path not in file_path_to_color:
499
+ # Set a random color for the unique missing file path
500
+ random_color = [random.random() for _ in range(3)]
501
+ file_path_to_color[file_path] = random_color + [1]
502
+
503
+ connected_node.inputs[
504
+ "Base Color"
505
+ ].default_value = file_path_to_color[file_path]
506
+
507
+ # Delete the TEX_IMAGE node
508
+ material.node_tree.nodes.remove(node)
509
+ missing_file_count += 1
510
+ out_files.append(image.filepath)
511
+ return {
512
+ "count": missing_file_count,
513
+ "files": out_files,
514
+ "file_path_to_color": file_path_to_color,
515
+ }
516
+
517
+
518
+ def _get_random_color() -> Tuple[float, float, float, float]:
519
+ """Generates a random RGB-A color.
520
+
521
+ The alpha value is always 1.
522
+
523
+ Returns:
524
+ Tuple[float, float, float, float]: A random RGB-A color. Each value is in the
525
+ range [0, 1].
526
+ """
527
+ return (random.random(), random.random(), random.random(), 1)
528
+
529
+
530
+ def _apply_color_to_object(
531
+ obj: bpy.types.Object, color: Tuple[float, float, float, float]
532
+ ) -> None:
533
+ """Applies the given color to the object.
534
+
535
+ Args:
536
+ obj (bpy.types.Object): The object to apply the color to.
537
+ color (Tuple[float, float, float, float]): The color to apply to the object.
538
+
539
+ Returns:
540
+ None
541
+ """
542
+ mat = bpy.data.materials.new(name=f"RandomMaterial_{obj.name}")
543
+ mat.use_nodes = True
544
+ nodes = mat.node_tree.nodes
545
+ principled_bsdf = nodes.get("Principled BSDF")
546
+ if principled_bsdf:
547
+ principled_bsdf.inputs["Base Color"].default_value = color
548
+ obj.data.materials.append(mat)
549
+
550
+
551
+ class MetadataExtractor:
552
+ """Class to extract metadata from a Blender scene."""
553
+
554
+ def __init__(
555
+ self, object_path: str, scene: bpy.types.Scene, bdata: bpy.types.BlendData
556
+ ) -> None:
557
+ """Initializes the MetadataExtractor.
558
+
559
+ Args:
560
+ object_path (str): Path to the object file.
561
+ scene (bpy.types.Scene): The current scene object from `bpy.context.scene`.
562
+ bdata (bpy.types.BlendData): The current blender data from `bpy.data`.
563
+
564
+ Returns:
565
+ None
566
+ """
567
+ self.object_path = object_path
568
+ self.scene = scene
569
+ self.bdata = bdata
570
+
571
+ def get_poly_count(self) -> int:
572
+ """Returns the total number of polygons in the scene."""
573
+ total_poly_count = 0
574
+ for obj in self.scene.objects:
575
+ if obj.type == "MESH":
576
+ total_poly_count += len(obj.data.polygons)
577
+ return total_poly_count
578
+
579
+ def get_vertex_count(self) -> int:
580
+ """Returns the total number of vertices in the scene."""
581
+ total_vertex_count = 0
582
+ for obj in self.scene.objects:
583
+ if obj.type == "MESH":
584
+ total_vertex_count += len(obj.data.vertices)
585
+ return total_vertex_count
586
+
587
+ def get_edge_count(self) -> int:
588
+ """Returns the total number of edges in the scene."""
589
+ total_edge_count = 0
590
+ for obj in self.scene.objects:
591
+ if obj.type == "MESH":
592
+ total_edge_count += len(obj.data.edges)
593
+ return total_edge_count
594
+
595
+ def get_lamp_count(self) -> int:
596
+ """Returns the number of lamps in the scene."""
597
+ return sum(1 for obj in self.scene.objects if obj.type == "LIGHT")
598
+
599
+ def get_mesh_count(self) -> int:
600
+ """Returns the number of meshes in the scene."""
601
+ return sum(1 for obj in self.scene.objects if obj.type == "MESH")
602
+
603
+ def get_material_count(self) -> int:
604
+ """Returns the number of materials in the scene."""
605
+ return len(self.bdata.materials)
606
+
607
+ def get_object_count(self) -> int:
608
+ """Returns the number of objects in the scene."""
609
+ return len(self.bdata.objects)
610
+
611
+ def get_animation_count(self) -> int:
612
+ """Returns the number of animations in the scene."""
613
+ return len(self.bdata.actions)
614
+
615
+ def get_linked_files(self) -> List[str]:
616
+ """Returns the filepaths of all linked files."""
617
+ image_filepaths = self._get_image_filepaths()
618
+ material_filepaths = self._get_material_filepaths()
619
+ linked_libraries_filepaths = self._get_linked_libraries_filepaths()
620
+
621
+ all_filepaths = (
622
+ image_filepaths | material_filepaths | linked_libraries_filepaths
623
+ )
624
+ if "" in all_filepaths:
625
+ all_filepaths.remove("")
626
+ return list(all_filepaths)
627
+
628
+ def _get_image_filepaths(self) -> Set[str]:
629
+ """Returns the filepaths of all images used in the scene."""
630
+ filepaths = set()
631
+ for image in self.bdata.images:
632
+ if image.source == "FILE":
633
+ filepaths.add(bpy.path.abspath(image.filepath))
634
+ return filepaths
635
+
636
+ def _get_material_filepaths(self) -> Set[str]:
637
+ """Returns the filepaths of all images used in materials."""
638
+ filepaths = set()
639
+ for material in self.bdata.materials:
640
+ if material.use_nodes:
641
+ for node in material.node_tree.nodes:
642
+ if node.type == "TEX_IMAGE":
643
+ image = node.image
644
+ if image is not None:
645
+ filepaths.add(bpy.path.abspath(image.filepath))
646
+ return filepaths
647
+
648
+ def _get_linked_libraries_filepaths(self) -> Set[str]:
649
+ """Returns the filepaths of all linked libraries."""
650
+ filepaths = set()
651
+ for library in self.bdata.libraries:
652
+ filepaths.add(bpy.path.abspath(library.filepath))
653
+ return filepaths
654
+
655
+ def get_scene_size(self) -> Dict[str, list]:
656
+ """Returns the size of the scene bounds in meters."""
657
+ bbox_min, bbox_max = scene_bbox()
658
+ return {"bbox_max": list(bbox_max), "bbox_min": list(bbox_min)}
659
+
660
+ def get_shape_key_count(self) -> int:
661
+ """Returns the number of shape keys in the scene."""
662
+ total_shape_key_count = 0
663
+ for obj in self.scene.objects:
664
+ if obj.type == "MESH":
665
+ shape_keys = obj.data.shape_keys
666
+ if shape_keys is not None:
667
+ total_shape_key_count += (
668
+ len(shape_keys.key_blocks) - 1
669
+ ) # Subtract 1 to exclude the Basis shape key
670
+ return total_shape_key_count
671
+
672
+ def get_armature_count(self) -> int:
673
+ """Returns the number of armatures in the scene."""
674
+ total_armature_count = 0
675
+ for obj in self.scene.objects:
676
+ if obj.type == "ARMATURE":
677
+ total_armature_count += 1
678
+ return total_armature_count
679
+
680
+ def read_file_size(self) -> int:
681
+ """Returns the size of the file in bytes."""
682
+ return os.path.getsize(self.object_path)
683
+
684
+ def get_metadata(self) -> Dict[str, Any]:
685
+ """Returns the metadata of the scene.
686
+
687
+ Returns:
688
+ Dict[str, Any]: Dictionary of the metadata with keys for "file_size",
689
+ "poly_count", "vert_count", "edge_count", "material_count", "object_count",
690
+ "lamp_count", "mesh_count", "animation_count", "linked_files", "scene_size",
691
+ "shape_key_count", and "armature_count".
692
+ """
693
+ return {
694
+ "file_size": self.read_file_size(),
695
+ "poly_count": self.get_poly_count(),
696
+ "vert_count": self.get_vertex_count(),
697
+ "edge_count": self.get_edge_count(),
698
+ "material_count": self.get_material_count(),
699
+ "object_count": self.get_object_count(),
700
+ "lamp_count": self.get_lamp_count(),
701
+ "mesh_count": self.get_mesh_count(),
702
+ "animation_count": self.get_animation_count(),
703
+ "linked_files": self.get_linked_files(),
704
+ "scene_size": self.get_scene_size(),
705
+ "shape_key_count": self.get_shape_key_count(),
706
+ "armature_count": self.get_armature_count(),
707
+ }
708
+
709
+ def pan_camera(time, axis="Z", camera_dist=2.0, elevation=-0.1, camera_offset=0.0):
710
+ angle = time * math.pi * 2 - math.pi / 2 # start from -90 degree
711
+ direction = [-math.cos(angle), -math.sin(angle), -elevation]
712
+ assert axis in ["X", "Y", "Z"]
713
+ if axis == "X":
714
+ direction = [direction[2], *direction[:2]]
715
+ elif axis == "Y":
716
+ direction = [direction[0], -elevation, direction[1]]
717
+ direction = Vector(direction).normalized()
718
+ camera = set_camera(direction, camera_dist=camera_dist, camera_offset=camera_offset)
719
+ return camera
720
+
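Written out, pan_camera() sweeps one full turn around the Z axis. With $t \in [0, 1)$, elevation $e$, and the $-\pi/2$ offset so that $t = 0$ is the front view, the viewing direction and camera position (for orbit distance $d_{\mathrm{cam}}$ = camera_dist and offset = camera_offset) are:

```latex
\theta(t) = 2\pi t - \tfrac{\pi}{2}, \qquad
d(t) = \frac{\left(-\cos\theta(t),\, -\sin\theta(t),\, -e\right)}
            {\left\lVert \left(-\cos\theta(t),\, -\sin\theta(t),\, -e\right) \right\rVert}, \qquad
p(t) = -\, d_{\mathrm{cam}}\, d(t) + \mathrm{offset} .
```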
721
+
722
+ def pan_camera_along(time, pose="alone-x-rotate", camera_dist=2.0, rotate=0.0):
723
+ angle = time * math.pi * 2
724
+ # direction_plane = [-math.cos(angle), -math.sin(angle), 0]
725
+ x_new = math.cos(angle)
726
+ y_new = math.cos(rotate) * math.sin(angle)
727
+ z_new = math.sin(rotate) * math.sin(angle)
728
+ direction = [-x_new, -y_new, -z_new]
729
+ assert pose in ["alone-x-rotate"]
730
+ direction = Vector(direction).normalized()
731
+ camera = set_camera(direction, camera_dist=camera_dist)
732
+ return camera
733
+
734
+ def pan_camera_by_angle(angle, axis="Z", camera_dist=2.0, elevation=-0.1 ):
735
+ direction = [-math.cos(angle), -math.sin(angle), -elevation]
736
+ assert axis in ["X", "Y", "Z"]
737
+ if axis == "X":
738
+ direction = [direction[2], *direction[:2]]
739
+ elif axis == "Y":
740
+ direction = [direction[0], -elevation, direction[1]]
741
+ direction = Vector(direction).normalized()
742
+ camera = set_camera(direction, camera_dist=camera_dist)
743
+ return camera
744
+
745
+ def z_circular_custom_track(time,
746
+ camera_dist,
747
+ azimuth_shift = [-9, 9],
748
+ init_elevation = 0.0,
749
+ elevation_shift = [-5, 5]):
750
+
751
+ adjusted_azimuth = (-math.degrees(math.pi / 2) +
752
+ time * 360 +
753
+ np.random.uniform(low=azimuth_shift[0], high=azimuth_shift[1]))
754
+
755
+ # Add random noise to the elevation
756
+ adjusted_elevation = init_elevation + np.random.uniform(low=elevation_shift[0], high=elevation_shift[1])
757
+ return math.radians(adjusted_azimuth), math.radians(adjusted_elevation), camera_dist
758
+
759
+
760
+ def place_camera(time, camera_pose_mode="random", camera_dist=2.0, rotate=0.0, elevation=0.0, camera_offset=0.0, idx=0):
761
+ if camera_pose_mode == "z-circular-elevated":
762
+ cam = pan_camera(time, axis="Z", camera_dist=camera_dist, elevation=elevation, camera_offset=camera_offset)
763
+ elif camera_pose_mode == 'alone-x-rotate':
764
+ cam = pan_camera_along(time, pose=camera_pose_mode, camera_dist=camera_dist, rotate=rotate)
765
+ elif camera_pose_mode == 'z-circular-elevated-noise':
766
+ angle, elevation, camera_dist = z_circular_custom_track(time, camera_dist=camera_dist, init_elevation=elevation)
767
+ cam = pan_camera_by_angle(angle, axis="Z", camera_dist=camera_dist, elevation=elevation)
768
+ elif camera_pose_mode == 'random':
769
+ cam = randomize_camera_with_cache(radius_min=camera_dist, radius_max=camera_dist, maxz=114514., minz=-114514., idx=idx)
770
+ else:
771
+ raise ValueError(f"Unknown camera pose mode: {camera_pose_mode}")
772
+ return cam
773
+
774
+
775
+ def setup_nodes(output_path, capturing_material_alpha: bool = False):
776
+ tree = bpy.context.scene.node_tree
777
+ links = tree.links
778
+
779
+ for node in tree.nodes:
780
+ tree.nodes.remove(node)
781
+
782
+ # Helpers to perform math on links and constants.
783
+ def node_op(op: str, *args, clamp=False):
784
+ node = tree.nodes.new(type="CompositorNodeMath")
785
+ node.operation = op
786
+ if clamp:
787
+ node.use_clamp = True
788
+ for i, arg in enumerate(args):
789
+ if isinstance(arg, (int, float)):
790
+ node.inputs[i].default_value = arg
791
+ else:
792
+ links.new(arg, node.inputs[i])
793
+ return node.outputs[0]
794
+
795
+ def node_clamp(x, maximum=1.0):
796
+ return node_op("MINIMUM", x, maximum)
797
+
798
+ def node_mul(x, y, **kwargs):
799
+ return node_op("MULTIPLY", x, y, **kwargs)
800
+
801
+ input_node = tree.nodes.new(type="CompositorNodeRLayers")
802
+ input_node.scene = bpy.context.scene
803
+
804
+ input_sockets = {}
805
+ for output in input_node.outputs:
806
+ input_sockets[output.name] = output
807
+
808
+ if capturing_material_alpha:
809
+ color_socket = input_sockets["Image"]
810
+ else:
811
+ raw_color_socket = input_sockets["Image"]
812
+
813
+ # We apply sRGB here so that our fixed-point depth map and material
814
+ # alpha values are not sRGB, and so that we perform ambient+diffuse
815
+ # lighting in linear RGB space.
816
+ color_node = tree.nodes.new(type="CompositorNodeConvertColorSpace")
817
+ color_node.from_color_space = "Linear"
818
+ color_node.to_color_space = "sRGB"
819
+ tree.links.new(raw_color_socket, color_node.inputs[0])
820
+ color_socket = color_node.outputs[0]
821
+ split_node = tree.nodes.new(type="CompositorNodeSepRGBA")
822
+ tree.links.new(color_socket, split_node.inputs[0])
823
+ # Create separate file output nodes for every channel we care about.
824
+ # The process calling this script must decide how to recombine these
825
+ # channels, possibly into a single image.
826
+ for i, channel in enumerate("rgba") if not capturing_material_alpha else [(0, "MatAlpha")]:
827
+ output_node = tree.nodes.new(type="CompositorNodeOutputFile")
828
+ output_node.base_path = f"{output_path}_{channel}"
829
+ links.new(split_node.outputs[i], output_node.inputs[0])
830
+ if capturing_material_alpha:
831
+ # No need to re-write depth here.
832
+ return
833
+
834
+ depth_out = node_clamp(node_mul(input_sockets["Depth"], 1 / MAX_DEPTH))
835
+ output_node = tree.nodes.new(type="CompositorNodeOutputFile")
836
+ output_node.format.file_format = 'OPEN_EXR'
837
+ output_node.base_path = f"{output_path}_depth"
838
+ links.new(depth_out, output_node.inputs[0])
839
+
840
+ # Add normal map output
841
+ normal_out = input_sockets["Normal"]
842
+
843
+ # Scale normal by 0.5
844
+ scale_normal = tree.nodes.new(type="CompositorNodeMixRGB")
845
+ scale_normal.blend_type = 'MULTIPLY'
846
+ scale_normal.inputs[2].default_value = (0.5, 0.5, 0.5, 1)
847
+ links.new(normal_out, scale_normal.inputs[1])
848
+
849
+ # Bias normal by 0.5
850
+ bias_normal = tree.nodes.new(type="CompositorNodeMixRGB")
851
+ bias_normal.blend_type = 'ADD'
852
+ bias_normal.inputs[2].default_value = (0.5, 0.5, 0.5, 0)
853
+ links.new(scale_normal.outputs[0], bias_normal.inputs[1])
854
+
855
+ # Output the transformed normal map
856
+ normal_file_output = tree.nodes.new(type="CompositorNodeOutputFile")
857
+ normal_file_output.base_path = f"{output_path}_normal"
858
+ normal_file_output.format.file_format = 'OPEN_EXR'
859
+ links.new(bias_normal.outputs[0], normal_file_output.inputs[0])
860
+
861
+
862
+ def setup_nodes_semantic(output_path, capturing_material_alpha: bool = False):
863
+ tree = bpy.context.scene.node_tree
864
+ links = tree.links
865
+
866
+ for node in tree.nodes:
867
+ tree.nodes.remove(node)
868
+
869
+ # Helpers to perform math on links and constants.
870
+ def node_op(op: str, *args, clamp=False):
871
+ node = tree.nodes.new(type="CompositorNodeMath")
872
+ node.operation = op
873
+ if clamp:
874
+ node.use_clamp = True
875
+ for i, arg in enumerate(args):
876
+ if isinstance(arg, (int, float)):
877
+ node.inputs[i].default_value = arg
878
+ else:
879
+ links.new(arg, node.inputs[i])
880
+ return node.outputs[0]
881
+
882
+ def node_clamp(x, maximum=1.0):
883
+ return node_op("MINIMUM", x, maximum)
884
+
885
+ def node_mul(x, y, **kwargs):
886
+ return node_op("MULTIPLY", x, y, **kwargs)
887
+
888
+ input_node = tree.nodes.new(type="CompositorNodeRLayers")
889
+ input_node.scene = bpy.context.scene
890
+
891
+ input_sockets = {}
892
+ for output in input_node.outputs:
893
+ input_sockets[output.name] = output
894
+
895
+ if capturing_material_alpha:
896
+ color_socket = input_sockets["Image"]
897
+ else:
898
+ raw_color_socket = input_sockets["Image"]
899
+ # We apply sRGB here so that our fixed-point depth map and material
900
+ # alpha values are not sRGB, and so that we perform ambient+diffuse
901
+ # lighting in linear RGB space.
902
+ color_node = tree.nodes.new(type="CompositorNodeConvertColorSpace")
903
+ color_node.from_color_space = "Linear"
904
+ color_node.to_color_space = "sRGB"
905
+ tree.links.new(raw_color_socket, color_node.inputs[0])
906
+ color_socket = color_node.outputs[0]
907
+
908
+
909
+ def render_object(
910
+ object_file: str,
911
+ num_renders: int,
912
+ only_northern_hemisphere: bool,
913
+ output_dir: str,
914
+ ) -> None:
915
+ """Saves rendered images with its camera matrix and metadata of the object.
916
+
917
+ Args:
918
+ object_file (str): Path to the object file.
919
+ num_renders (int): Number of renders to save of the object.
920
+ only_northern_hemisphere (bool): Whether to only render sides of the object that
921
+ are in the northern hemisphere. This is useful for rendering objects that
922
+ are photogrammetrically scanned, as the bottom of the object often has
923
+ holes.
924
+ output_dir (str): Path to the directory where the rendered images and metadata
925
+ will be saved.
926
+
927
+ Returns:
928
+ None
929
+ """
930
+ os.makedirs(output_dir, exist_ok=True)
931
+
932
+ # load the object
933
+ if object_file.endswith(".blend"):
934
+ bpy.ops.object.mode_set(mode="OBJECT")
935
+ reset_cameras()
936
+ delete_invisible_objects()
937
+ else:
938
+ reset_scene()
939
+ load_object(object_file)
940
+
941
+ # Set up cameras
942
+ cam = scene.objects["Camera"]
943
+ cam.data.lens = 35
944
+ cam.data.sensor_width = 32
945
+
946
+ # Set up camera constraints
947
+ cam_constraint = cam.constraints.new(type="TRACK_TO")
948
+ cam_constraint.track_axis = "TRACK_NEGATIVE_Z"
949
+ cam_constraint.up_axis = "UP_Y"
950
+
951
+ # Extract the metadata. This must be done before normalizing the scene to get
952
+ # accurate bounding box information.
953
+ metadata_extractor = MetadataExtractor(
954
+ object_path=object_file, scene=scene, bdata=bpy.data
955
+ )
956
+ metadata = metadata_extractor.get_metadata()
957
+
958
+ # delete all objects that are not meshes
959
+ if object_file.lower().endswith(".usdz") or object_file.lower().endswith(".vrm"):
960
+ # don't delete missing textures on usdz files, lots of them are embedded
961
+ missing_textures = None
962
+ else:
963
+ missing_textures = delete_missing_textures()
964
+ metadata["missing_textures"] = missing_textures
965
+ metadata["random_color"] = None
966
+
967
+ # save metadata
968
+ metadata_path = os.path.join(output_dir, "metadata.json")
969
+ os.makedirs(os.path.dirname(metadata_path), exist_ok=True)
970
+ with open(metadata_path, "w", encoding="utf-8") as f:
971
+ json.dump(metadata, f, sort_keys=True, indent=2)
972
+
973
+ # normalize the scene
974
+ normalize_scene()
975
+
976
+ # cancel edge rim lighting in vrm files
977
+ if object_file.endswith(".vrm"):
978
+ for i in bpy.data.materials:
979
+ i.vrm_addon_extension.mtoon1.extensions.vrmc_materials_mtoon.rim_lighting_mix_factor = 0.0
980
+ i.vrm_addon_extension.mtoon1.extensions.vrmc_materials_mtoon.matcap_texture.index.source = None
981
+ i.vrm_addon_extension.mtoon1.extensions.vrmc_materials_mtoon.outline_width_factor = 0.0
982
+
983
+ # rotate two arms to A-pose
984
+ if object_file.endswith(".vrm"):
985
+ armature = [ i for i in bpy.data.objects if 'Armature' in i.name ][0]
986
+ bpy.context.view_layer.objects.active = armature
987
+ bpy.ops.object.mode_set(mode='POSE')
988
+ pbone1 = armature.pose.bones['J_Bip_L_UpperArm']
989
+ pbone2 = armature.pose.bones['J_Bip_R_UpperArm']
990
+ pbone1.rotation_mode = 'XYZ'
991
+ pbone2.rotation_mode = 'XYZ'
992
+ pbone1.rotation_euler.rotate_axis('X', math.radians(-45))
993
+ pbone2.rotation_euler.rotate_axis('X', math.radians(-45))
994
+ bpy.ops.object.mode_set(mode='OBJECT')
995
+
996
+ def printInfo():
997
+ print("====== Objects ======")
998
+ for i in bpy.data.objects:
999
+ print(i.name)
1000
+ print("====== Materials ======")
1001
+ for i in bpy.data.materials:
1002
+ print(i.name)
1003
+
1004
+ def parse_material():
1005
+ hair_mats = []
1006
+ cloth_mats = []
1007
+ face_mats = []
1008
+ body_mats = []
1009
+
1010
+ # main hair material
1011
+ if 'Hair' in bpy.data.objects:
1012
+ hair_mats = [i.name for i in bpy.data.objects['Hair'].data.materials if 'MToon Outline' not in i.name]
1013
+ else:
1014
+ flag = False
1015
+ for i in bpy.data.objects:
1016
+ if i.name[:4] == 'Hair' and bpy.data.objects[i.name].data:
1017
+ hair_mats += [i.name for i in bpy.data.objects[i.name].data.materials if 'MToon Outline' not in i.name]
1018
+ flag = True
1019
+ if not flag:
1020
+ if 'Hairs' in bpy.data.objects and bpy.data.objects['Hairs'].data:
1021
+ hair_mats = [i.name for i in bpy.data.objects['Hairs'].data.materials if 'MToon Outline' not in i.name]
1022
+ else:
1023
+ for i in bpy.data.materials:
1024
+ if 'HAIR' in i.name and 'MToon Outline' not in i.name:
1025
+ hair_mats.append(i.name)
1026
+ if len(hair_mats) == 0:
1027
+ printInfo()
1028
+ with open('error.txt', 'a+') as f:
1029
+ f.write(object_file + '\t' + 'Cannot find main hair material\t' + str([iii.name for iii in bpy.data.objects]) + '\n')
1030
+ raise ValueError("Cannot find main hair material")
1031
+
1032
+ # face material
1033
+ if 'Face' in bpy.data.objects:
1034
+ face_mats = [i.name for i in bpy.data.objects['Face'].data.materials if 'MToon Outline' not in i.name]
1035
+ else:
1036
+ for i in bpy.data.materials:
1037
+ if 'FACE' in i.name and 'MToon Outline' not in i.name:
1038
+ face_mats.append(i.name)
1039
+ elif 'Face' in i.name and 'SKIN' in i.name and 'MToon Outline' not in i.name:
1040
+ face_mats.append(i.name)
1041
+ if len(face_mats) == 0:
1042
+ printInfo()
1043
+ with open('error.txt', 'a+') as f:
1044
+ f.write(object_file + '\t' + 'Cannot find face material\t' + str([iii.name for iii in bpy.data.objects]) + '\n')
1045
+ raise ValueError("Cannot find face material")
1046
+
1047
+ # loop
1048
+ for i in bpy.data.materials:
1049
+ if 'MToon Outline' in i.name:
1050
+ continue
1051
+ elif 'CLOTH' in i.name:
1052
+ if 'Shoes' in i.name:
1053
+ body_mats.append(i.name)
1054
+ elif 'Accessory' in i.name:
1055
+ if 'CatEar' in i.name:
1056
+ hair_mats.append(i.name)
1057
+ else:
1058
+ cloth_mats.append(i.name)
1059
+ elif any( name in i.name for name in ['Tops', 'Bottoms', 'Onepice'] ):
1060
+ cloth_mats.append(i.name)
1061
+ else:
1062
+ raise ValueError(f"Unknown cloth material: {i.name}")
1063
+ elif 'Body' in i.name and 'SKIN' in i.name:
1064
+ body_mats.append(i.name)
1065
+ elif i.name in hair_mats or i.name in face_mats:
1066
+ continue
1067
+ elif 'HairBack' in i.name and 'HAIR' in i.name:
1068
+ hair_mats.append(i.name)
1069
+ elif 'EYE' in i.name:
1070
+ face_mats.append(i.name)
1071
+ elif 'Face' in i.name and 'SKIN' in i.name:
1072
+ face_mats.append(i.name)
1073
+ else:
1074
+ print("hair_mats", hair_mats)
1075
+ print("cloth_mats", cloth_mats)
1076
+ print("face_mats", face_mats)
1077
+ print("body_mats", body_mats)
1078
+ with open('error.txt', 'a+') as f:
1079
+ f.write(object_file + '\t' + 'Cannot find material\t' + i.name + '\n')
1080
+ raise ValueError(f"Unknown material: {i.name}")
1081
+
1082
+ return hair_mats, cloth_mats, face_mats, body_mats
1083
+
1084
+ hair_mats, cloth_mats, face_mats, body_mats = parse_material()
1085
+
1086
+ # get bounding box of face
1087
+ def get_face_bbox():
1088
+ if 'Face' in bpy.data.objects:
1089
+ face = bpy.data.objects['Face']
1090
+ bbox_min, bbox_max = scene_bbox(face)
1091
+ return bbox_min, bbox_max
1092
+ else:
1093
+ bbox_min, bbox_max = scene_bbox()
1094
+ for i in bpy.data.objects:
1095
+ if i.data.materials and i.data.materials[0].name in face_mats:
1096
+ face = i
1097
+ cur_bbox_min, cur_bbox_max = scene_bbox(face)
1098
+ bbox_min = np.minimum(bbox_min, cur_bbox_min)
1099
+ bbox_max = np.maximum(bbox_max, cur_bbox_max)
1100
+ return bbox_min, bbox_max
1101
+
1102
+ def assign_color(material_name, color):
1103
+ material = bpy.data.materials.get(material_name)
1104
+ if material:
1105
+ material.vrm_addon_extension.mtoon1.pbr_metallic_roughness.base_color_factor = (1, 1, 1, 1)
1106
+ image = material.vrm_addon_extension.mtoon1.pbr_metallic_roughness.base_color_texture.index.source
1107
+ if image:
1108
+ pixels = np.array(image.pixels[:])
1109
+ width, height = image.size
1110
+ num_channels = 4
1111
+ pixels = pixels.reshape((height, width, num_channels))
1112
+ srgb_pixels = np.clip(np.power(pixels, 1/2.2), 0.0, 1.0)
1113
+ print("Image converted to NumPy array")
1114
+
1115
+ # Step 2: Edit the NumPy array
1116
+ srgb_pixels[..., 0] = color[0]
1117
+ srgb_pixels[..., 1] = color[1]
1118
+ srgb_pixels[..., 2] = color[2]
1119
+ edited_image_rgba = srgb_pixels
1120
+
1121
+ # Step 3: Convert the edited NumPy array back to a Blender image
1122
+ edited_image_flat = edited_image_rgba.astype(np.float32)
1123
+ edited_image_flat = edited_image_flat.flatten()
1124
+ edited_image_name = "Edited_Texture"
1125
+ edited_blender_image = bpy.data.images.new(edited_image_name, width, height, alpha=True)
1126
+ edited_blender_image.pixels = edited_image_flat
1127
+ material.vrm_addon_extension.mtoon1.pbr_metallic_roughness.base_color_texture.index.source = edited_blender_image
1128
+ print(f"Edited image assigned to {material_name}")
1129
+
1130
+ material.vrm_addon_extension.mtoon1.extensions.vrmc_materials_mtoon.shade_color_factor = (1, 1, 1)
1131
+ image = material.vrm_addon_extension.mtoon1.extensions.vrmc_materials_mtoon.shade_multiply_texture.index.source
1132
+ if image:
1133
+ pixels = np.array(image.pixels[:])
1134
+ width, height = image.size
1135
+ num_channels = 4
1136
+ pixels = pixels.reshape((height, width, num_channels))
1137
+ srgb_pixels = np.clip(np.power(pixels, 1/2.2), 0.0, 1.0)
1138
+ print("Image converted to NumPy array")
1139
+
1140
+ # Step 2: Edit the NumPy array
1141
+ srgb_pixels[..., 0] = color[0]
1142
+ srgb_pixels[..., 1] = color[1]
1143
+ srgb_pixels[..., 2] = color[2]
1144
+ edited_image_rgba = srgb_pixels
1145
+
1146
+ # Step 3: Convert the edited NumPy array back to a Blender image
1147
+ edited_image_flat = edited_image_rgba.astype(np.float32)
1148
+ edited_image_flat = edited_image_flat.flatten()
1149
+ edited_image_name = "Edited_Texture"
1150
+ edited_blender_image = bpy.data.images.new(edited_image_name, width, height, alpha=True)
1151
+ edited_blender_image.pixels = edited_image_flat
1152
+ material.vrm_addon_extension.mtoon1.extensions.vrmc_materials_mtoon.shade_multiply_texture.index.source = edited_blender_image
1153
+ print(f"Edited image assigned to {material_name}")
1154
+ material.vrm_addon_extension.mtoon1.extensions.khr_materials_emissive_strength.emissive_strength = 0.0
1155
+
1156
+ def assign_transparency(material_name, alpha):
1157
+ material = bpy.data.materials.get(material_name)
1158
+ if material:
1159
+ material.vrm_addon_extension.mtoon1.pbr_metallic_roughness.base_color_factor = (1, 1, 1, alpha)
1160
+
1161
+ # render the images
1162
+ use_workbench = bpy.context.scene.render.engine == "BLENDER_WORKBENCH"
1163
+
1164
+ face_bbox_min, face_bbox_max = get_face_bbox()
1165
+ face_bbox_center = (face_bbox_min + face_bbox_max) / 2
1166
+ face_bbox_size = face_bbox_max - face_bbox_min
1167
+ print("face_bbox_center", face_bbox_center)
1168
+ print("face_bbox_size", face_bbox_size)
1169
+
1170
+ config_names = ["custom2", "custom_top", "custom_bottom", "custom_face", "random"]
1171
+
1172
+ # normal rendering
1173
+ for l in range(3): # 3 levels: all; no hair; no hair and no cloth
1174
+ if l == 0:
1175
+ pass
1176
+ elif l == 1:
1177
+ for i in hair_mats:
1178
+ bpy.data.materials[i].vrm_addon_extension.mtoon1.pbr_metallic_roughness.base_color_factor = (0, 0, 0, 0)
1179
+ elif l == 2:
1180
+ for i in cloth_mats:
1181
+ bpy.data.materials[i].vrm_addon_extension.mtoon1.pbr_metallic_roughness.base_color_factor = (0, 0, 0, 0)
1182
+
1183
+ for j in range(5): # 5 track
1184
+ config = configs[config_names[j]]
1185
+ if "render_num" in config:
1186
+ new_num_renders = config["render_num"]
1187
+ else:
1188
+ new_num_renders = num_renders
1189
+
1190
+ for i in range(new_num_renders):
1191
+ camera_dist = 1.4
1192
+ if config_names[j] == "custom_face":
1193
+ camera_dist = 0.6
1194
+ if i not in [0, 1, 2, 6, 7]:
1195
+ continue
1196
+ t = i / num_renders
1197
+ elevation_range = config["elevation_range"]
1198
+ init_elevation = elevation_range[0]
1199
+ # set camera
1200
+ camera = place_camera(
1201
+ t,
1202
+ camera_pose_mode=config["camera_pose"],
1203
+ camera_dist=camera_dist,
1204
+ rotate=config["rotate"],
1205
+ elevation=init_elevation,
1206
+ camera_offset=face_bbox_center if config_names[j] == "custom_face" else 0.0,
1207
+ idx=i
1208
+ )
1209
+
1210
+ # set camera to ortho
1211
+ bpy.data.objects["Camera"].data.type = 'ORTHO'
1212
+ bpy.data.objects["Camera"].data.ortho_scale = 1.2 if config_names[j] != "custom_face" else np.max(face_bbox_size) * 1.2
1213
+
1214
+ # render the image
1215
+ render_path = os.path.join(output_dir, f"{(i + j * 100 + l * 1000):05}.png")
1216
+ scene.render.filepath = render_path
1217
+ setup_nodes(render_path)
1218
+ bpy.ops.render.render(write_still=True)
1219
+
1220
+ # save camera RT matrix
1221
+ rt_matrix = get_3x4_RT_matrix_from_blender(camera)
1222
+ rt_matrix_path = os.path.join(output_dir, f"{(i + j * 100 + l * 1000):05}.npy")
1223
+ np.save(rt_matrix_path, rt_matrix)
1224
+
1225
+ for channel_name in ["r", "g", "b", "a", "depth", "normal"]:
1226
+ sub_dir = f"{render_path}_{channel_name}"
1227
+ if channel_name in ['r', 'g', 'b']:
1228
+ # remove path
1229
+ shutil.rmtree(sub_dir)
1230
+ continue
1231
+
1232
+ image_path = os.path.join(sub_dir, os.listdir(sub_dir)[0])
1233
+ name, ext = os.path.splitext(render_path)
1234
+ if channel_name == "a":
1235
+ os.rename(image_path, f"{name}_{channel_name}.png")
1236
+ elif channel_name == 'depth':
1237
+ os.rename(image_path, f"{name}_{channel_name}.exr")
1238
+ elif channel_name == "normal":
1239
+ os.rename(image_path, f"{name}_{channel_name}.exr")
1240
+ else:
1241
+ os.remove(image_path)
1242
+
1243
+ os.removedirs(sub_dir)
1244
+
1245
+ # reset
1246
+ for i in hair_mats:
1247
+ bpy.data.materials[i].vrm_addon_extension.mtoon1.pbr_metallic_roughness.base_color_factor = (1, 1, 1, 1)
1248
+ for i in cloth_mats:
1249
+ bpy.data.materials[i].vrm_addon_extension.mtoon1.pbr_metallic_roughness.base_color_factor = (1, 1, 1, 1)
1250
+
1251
+ # switch to semantic rendering
1252
+ for i in hair_mats:
1253
+ assign_color(i, [1.0, 0.0, 0.0])
1254
+ for i in cloth_mats:
1255
+ assign_color(i, [0.0, 0.0, 1.0])
1256
+ for i in face_mats:
1257
+ assign_color(i, [0.0, 1.0, 1.0])
1258
+ if any( ii in i for ii in ['Eyeline', 'Eyelash', 'Brow', 'Highlight'] ):
1259
+ assign_transparency(i, 0.0)
1260
+ for i in body_mats:
1261
+ assign_color(i, [0.0, 1.0, 0.0])
1262
+
1263
+ for l in range(3): # 3 levels: all; no hair; no hair and no cloth
1264
+ if l == 0:
1265
+ pass
1266
+ elif l == 1:
1267
+ for i in hair_mats:
1268
+ bpy.data.materials[i].vrm_addon_extension.mtoon1.pbr_metallic_roughness.base_color_factor = (0, 0, 0, 0)
1269
+ elif l == 2:
1270
+ for i in cloth_mats:
1271
+ bpy.data.materials[i].vrm_addon_extension.mtoon1.pbr_metallic_roughness.base_color_factor = (0, 0, 0, 0)
1272
+ for j in range(5): # 5 track
1273
+ config = configs[config_names[j]]
1274
+ if "render_num" in config:
1275
+ new_num_renders = config["render_num"]
1276
+ else:
1277
+ new_num_renders = num_renders
1278
+
1279
+ for i in range(new_num_renders):
1280
+ camera_dist = 1.4
1281
+ if config_names[j] == "custom_face":
1282
+ camera_dist = 0.6
1283
+ if i not in [0, 1, 2, 6, 7]:
1284
+ continue
1285
+ t = i / num_renders
1286
+ elevation_range = config["elevation_range"]
1287
+ init_elevation = elevation_range[0]
1288
+ # set camera
1289
+ camera = place_camera(
1290
+ t,
1291
+ camera_pose_mode=config["camera_pose"],
1292
+ camera_dist=camera_dist,
1293
+ rotate=config["rotate"],
1294
+ elevation=init_elevation,
1295
+ camera_offset=face_bbox_center if config_names[j] == "custom_face" else 0.0,
1296
+ idx=i
1297
+ )
1298
+
1299
+ # set camera to ortho
1300
+ bpy.data.objects["Camera"].data.type = 'ORTHO'
1301
+ bpy.data.objects["Camera"].data.ortho_scale = 1.2 if config_names[j] != "custom_face" else np.max(face_bbox_size) * 1.2
1302
+
1303
+ # render the image
1304
+ render_path = os.path.join(output_dir, f"{(i + j * 100 + l * 1000):05}_semantic.png")
1305
+ scene.render.filepath = render_path
1306
+ setup_nodes_semantic(render_path)
1307
+ bpy.ops.render.render(write_still=True)
1308
+
1309
+
1310
+ if __name__ == "__main__":
1311
+ parser = argparse.ArgumentParser()
1312
+ parser.add_argument(
1313
+ "--object_path",
1314
+ type=str,
1315
+ required=True,
1316
+ help="Path to the object file",
1317
+ )
1318
+ parser.add_argument(
1319
+ "--output_dir",
1320
+ type=str,
1321
+ required=True,
1322
+ help="Path to the directory where the rendered images and metadata will be saved.",
1323
+ )
1324
+ parser.add_argument(
1325
+ "--engine",
1326
+ type=str,
1327
+ default="BLENDER_EEVEE",
1328
+ choices=["CYCLES", "BLENDER_EEVEE"],
1329
+ )
1330
+ parser.add_argument(
1331
+ "--only_northern_hemisphere",
1332
+ action="store_true",
1333
+ help="Only render the northern hemisphere of the object.",
1334
+ default=False,
1335
+ )
1336
+ parser.add_argument(
1337
+ "--num_renders",
1338
+ type=int,
1339
+ default=8,
1340
+ help="Number of renders to save of the object.",
1341
+ )
1342
+ argv = sys.argv[sys.argv.index("--") + 1 :]
1343
+ args = parser.parse_args(argv)
1344
+
1345
+ context = bpy.context
1346
+ scene = context.scene
1347
+ render = scene.render
1348
+
1349
+ # Set render settings
1350
+ render.engine = args.engine
1351
+ render.image_settings.file_format = "PNG"
1352
+ render.image_settings.color_mode = "RGB"
1353
+ render.resolution_x = 1024
1354
+ render.resolution_y = 1024
1355
+ render.resolution_percentage = 100
1356
+
1357
+ # Set EEVEE settings
1358
+ scene.eevee.taa_render_samples = 64
1359
+ scene.eevee.use_taa_reprojection = True
1360
+
1361
+ # Set cycles settings
1362
+ scene.cycles.device = "GPU"
1363
+ scene.cycles.samples = 128
1364
+ scene.cycles.diffuse_bounces = 9
1365
+ scene.cycles.glossy_bounces = 9
1366
+ scene.cycles.transparent_max_bounces = 9
1367
+ scene.cycles.transmission_bounces = 9
1368
+ scene.cycles.filter_width = 0.01
1369
+ scene.cycles.use_denoising = True
1370
+ scene.render.film_transparent = True
1371
+ bpy.context.preferences.addons["cycles"].preferences.get_devices()
1372
+ bpy.context.preferences.addons[
1373
+ "cycles"
1374
+ ].preferences.compute_device_type = "CUDA" # or "OPENCL"
1375
+ bpy.context.scene.view_layers["ViewLayer"].use_pass_z = True
1376
+
1377
+ bpy.context.view_layer.use_pass_normal = True
1378
+ render.image_settings.color_depth = "16"
1379
+ bpy.context.scene.use_nodes = True
1380
+
1381
+ # Render the images
1382
+ render_object(
1383
+ object_file=args.object_path,
1384
+ num_renders=args.num_renders,
1385
+ only_northern_hemisphere=args.only_northern_hemisphere,
1386
+ output_dir=args.output_dir,
1387
+ )
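
Note on the outputs above: renders are indexed as view + track * 100 + level * 1000, zero-padded to five digits, and the camera matrix plus the extra channels are saved under the same stem. A minimal sketch of that naming scheme (standalone, not part of the diff; the concrete indices are only examples):

def render_stem(view_i: int, track_j: int, level_l: int) -> str:
    # tracks: custom2, custom_top, custom_bottom, custom_face, random (j = 0..4)
    # levels: all, no hair, no hair + no cloth (l = 0..2)
    return f"{view_i + track_j * 100 + level_l * 1000:05}"

stem = render_stem(7, 3, 1)   # view 7, "custom_face" track, "no hair" level -> "01307"
outputs = [f"{stem}.png", f"{stem}.npy", f"{stem}_a.png",
           f"{stem}_depth.exr", f"{stem}_normal.exr", f"{stem}_semantic.png"]
print(outputs)
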
blender/distributed_uniform_lrm.py ADDED
@@ -0,0 +1,122 @@
+ import json
+ import multiprocessing
+ import subprocess
+ import time
+ from dataclasses import dataclass
+ import os
+ import tyro
+ import concurrent.futures
+
+
+ @dataclass
+ class Args:
+     # non-default fields must precede fields with defaults
+     workers_per_gpu: int
+     """number of workers per gpu"""
+     input_dir: str
+     save_dir: str
+     num_gpus: int = 8
+     """number of gpus to use. -1 means all available gpus"""
+     engine: str = "BLENDER_EEVEE"
+
+
+ def check_already_rendered(save_path):
+     # the last semantic render written by blender_lrm_script.py marks a finished job
+     if not os.path.exists(os.path.join(save_path, '02419_semantic.png')):
+         return False
+     return True
+
+ def process_file(file):
+     if not check_already_rendered(file[1]):
+         return file
+     return None
+
+ def worker(queue, count, gpu):
+     while True:
+         try:
+             item = queue.get()
+             if item is None:
+                 queue.task_done()
+                 break
+             data_path, save_path, engine, log_name = item
+             print(f"Processing: {data_path} on GPU {gpu}")
+             start = time.time()
+             if check_already_rendered(save_path):
+                 queue.task_done()
+                 print('========', item, 'rendered', '========')
+                 continue
+             else:
+                 os.makedirs(save_path, exist_ok=True)
+                 command = (f"export DISPLAY=:0.{gpu} &&"
+                            f" CUDA_VISIBLE_DEVICES={gpu} "
+                            f" blender -b -P blender_lrm_script.py --"
+                            f" --object_path {data_path} --output_dir {save_path} --engine {engine}")
+
+                 try:
+                     subprocess.run(command, shell=True, timeout=3600, check=True)
+                     count.value += 1
+                     end = time.time()
+                     with open(log_name, 'a') as f:
+                         f.write(f'{end - start}\n')
+                 except subprocess.CalledProcessError as e:
+                     print(f"Subprocess error processing {item}: {e}")
+                 except subprocess.TimeoutExpired as e:
+                     print(f"Timeout expired processing {item}: {e}")
+                 except Exception as e:
+                     print(f"Error processing {item}: {e}")
+                 finally:
+                     queue.task_done()
+
+         except Exception as e:
+             print(f"Error processing {item}: {e}")
+             queue.task_done()
+
+
+ if __name__ == "__main__":
+     args = tyro.cli(Args)
+     queue = multiprocessing.JoinableQueue()
+     count = multiprocessing.Value("i", 0)
+     log_name = f'time_log_{args.workers_per_gpu}_{args.num_gpus}_{args.engine}.txt'
+
+     if args.num_gpus == -1:
+         result = subprocess.run(['nvidia-smi', '--list-gpus'], stdout=subprocess.PIPE)
+         output = result.stdout.decode('utf-8')
+         args.num_gpus = output.count('GPU')
+
+     files = []
+
+     for group in [ str(i) for i in range(10) ]:
+         for folder in os.listdir(f'{args.input_dir}/{group}'):
+             filename = f'{args.input_dir}/{group}/{folder}/{folder}.vrm'
+             outputdir = f'{args.save_dir}/{group}/{folder}'
+             files.append([filename, outputdir])
+
+     # sort the files
+     files = sorted(files, key=lambda x: x[0])
+
+     # Use ThreadPoolExecutor for parallel processing
+     with concurrent.futures.ThreadPoolExecutor() as executor:
+         # Map the process_file function to the files
+         results = list(executor.map(process_file, files))
+
+     # Filter out None values from the results
+     unprocess_files = [file for file in results if file is not None]
+
+     # Print the number of unprocessed files
+     print(f'Unprocessed files: {len(unprocess_files)}')
+
+     # Start worker processes on each of the GPUs
+     for gpu_i in range(args.num_gpus):
+         for worker_i in range(args.workers_per_gpu):
+             worker_i = gpu_i * args.workers_per_gpu + worker_i
+             process = multiprocessing.Process(
+                 target=worker, args=(queue, count, gpu_i)
+             )
+             process.daemon = True
+             process.start()
+
+     for file in unprocess_files:
+         queue.put((file[0], file[1], args.engine, log_name))
+
+     # Add sentinels to the queue to stop the worker processes
+     for i in range(args.num_gpus * args.workers_per_gpu * 10):
+         queue.put(None)
+     # Wait for all tasks to be completed
+     queue.join()
+     end = time.time()
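
The group loop above implies a fixed on-disk layout for the VRM inputs. A minimal sketch of the directory structure this launcher expects (the root paths are hypothetical placeholders):

import os

input_dir = "/data/vrm"        # hypothetical input root
save_dir = "/data/renders"     # hypothetical output root

# expected layout: {input_dir}/{group 0-9}/{avatar_id}/{avatar_id}.vrm
for group in map(str, range(10)):
    group_dir = os.path.join(input_dir, group)
    if not os.path.isdir(group_dir):
        continue
    for avatar_id in os.listdir(group_dir):
        src = os.path.join(group_dir, avatar_id, f"{avatar_id}.vrm")
        dst = os.path.join(save_dir, group, avatar_id)
        print(src, "->", dst)
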
blender/install_addon.py ADDED
@@ -0,0 +1,15 @@
+ import bpy
+ import sys
+
+ def install_addon(addon_path):
+     bpy.ops.preferences.addon_install(filepath=addon_path)
+     bpy.ops.preferences.addon_enable(module=addon_path.split('/')[-1].replace('.py', '').replace('.zip', ''))
+     bpy.ops.wm.save_userpref()
+
+ if __name__ == "__main__":
+     if len(sys.argv) < 2:
+         print("Usage: blender --background --python install_addon.py -- <path_to_addon>")
+         sys.exit(1)
+
+     addon_path = sys.argv[-1]
+     install_addon(addon_path)
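
As the usage string above indicates, this helper is meant to be run through Blender in background mode. A minimal sketch of invoking it from Python (assumes blender is on PATH; the add-on archive path is a hypothetical placeholder):

import subprocess

subprocess.run(
    [
        "blender", "--background",
        "--python", "blender/install_addon.py",
        "--", "/path/to/VRM_Addon_for_Blender.zip",  # hypothetical VRM add-on archive
    ],
    check=True,
)
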
canonicalize/__init__.py ADDED
File without changes
canonicalize/models/attention.py ADDED
@@ -0,0 +1,344 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
11
+ from diffusers import ModelMixin
12
+ from diffusers.utils import BaseOutput
13
+ from diffusers.utils.import_utils import is_xformers_available
14
+ from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm
15
+
16
+ from einops import rearrange, repeat
17
+
18
+
19
+ @dataclass
20
+ class Transformer3DModelOutput(BaseOutput):
21
+ sample: torch.FloatTensor
22
+
23
+
24
+ if is_xformers_available():
25
+ import xformers
26
+ import xformers.ops
27
+ else:
28
+ xformers = None
29
+
30
+
31
+ class Transformer3DModel(ModelMixin, ConfigMixin):
32
+ @register_to_config
33
+ def __init__(
34
+ self,
35
+ num_attention_heads: int = 16,
36
+ attention_head_dim: int = 88,
37
+ in_channels: Optional[int] = None,
38
+ num_layers: int = 1,
39
+ dropout: float = 0.0,
40
+ norm_num_groups: int = 32,
41
+ cross_attention_dim: Optional[int] = None,
42
+ attention_bias: bool = False,
43
+ activation_fn: str = "geglu",
44
+ num_embeds_ada_norm: Optional[int] = None,
45
+ use_linear_projection: bool = False,
46
+ only_cross_attention: bool = False,
47
+ upcast_attention: bool = False,
48
+ use_attn_temp: bool = False,
49
+ ):
50
+ super().__init__()
51
+ self.use_linear_projection = use_linear_projection
52
+ self.num_attention_heads = num_attention_heads
53
+ self.attention_head_dim = attention_head_dim
54
+ inner_dim = num_attention_heads * attention_head_dim
55
+
56
+ # Define input layers
57
+ self.in_channels = in_channels
58
+
59
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
60
+ if use_linear_projection:
61
+ self.proj_in = nn.Linear(in_channels, inner_dim)
62
+ else:
63
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
64
+
65
+ # Define transformers blocks
66
+ self.transformer_blocks = nn.ModuleList(
67
+ [
68
+ BasicTransformerBlock(
69
+ inner_dim,
70
+ num_attention_heads,
71
+ attention_head_dim,
72
+ dropout=dropout,
73
+ cross_attention_dim=cross_attention_dim,
74
+ activation_fn=activation_fn,
75
+ num_embeds_ada_norm=num_embeds_ada_norm,
76
+ attention_bias=attention_bias,
77
+ only_cross_attention=only_cross_attention,
78
+ upcast_attention=upcast_attention,
79
+ use_attn_temp = use_attn_temp,
80
+ )
81
+ for d in range(num_layers)
82
+ ]
83
+ )
84
+
85
+ # 4. Define output layers
86
+ if use_linear_projection:
87
+ self.proj_out = nn.Linear(in_channels, inner_dim)
88
+ else:
89
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
90
+
91
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
92
+ # Input
93
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
94
+ video_length = hidden_states.shape[2]
95
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
96
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
97
+
98
+ batch, channel, height, weight = hidden_states.shape
99
+ residual = hidden_states
100
+
101
+ hidden_states = self.norm(hidden_states)
102
+ if not self.use_linear_projection:
103
+ hidden_states = self.proj_in(hidden_states)
104
+ inner_dim = hidden_states.shape[1]
105
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
106
+ else:
107
+ inner_dim = hidden_states.shape[1]
108
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
109
+ hidden_states = self.proj_in(hidden_states)
110
+
111
+ # Blocks
112
+ for block in self.transformer_blocks:
113
+ hidden_states = block(
114
+ hidden_states,
115
+ encoder_hidden_states=encoder_hidden_states,
116
+ timestep=timestep,
117
+ video_length=video_length
118
+ )
119
+
120
+ # Output
121
+ if not self.use_linear_projection:
122
+ hidden_states = (
123
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
124
+ )
125
+ hidden_states = self.proj_out(hidden_states)
126
+ else:
127
+ hidden_states = self.proj_out(hidden_states)
128
+ hidden_states = (
129
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
130
+ )
131
+
132
+ output = hidden_states + residual
133
+
134
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
135
+ if not return_dict:
136
+ return (output,)
137
+
138
+ return Transformer3DModelOutput(sample=output)
139
+
140
+
141
+ class BasicTransformerBlock(nn.Module):
142
+ def __init__(
143
+ self,
144
+ dim: int,
145
+ num_attention_heads: int,
146
+ attention_head_dim: int,
147
+ dropout=0.0,
148
+ cross_attention_dim: Optional[int] = None,
149
+ activation_fn: str = "geglu",
150
+ num_embeds_ada_norm: Optional[int] = None,
151
+ attention_bias: bool = False,
152
+ only_cross_attention: bool = False,
153
+ upcast_attention: bool = False,
154
+ use_attn_temp: bool = False
155
+ ):
156
+ super().__init__()
157
+ self.only_cross_attention = only_cross_attention
158
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
159
+ self.use_attn_temp = use_attn_temp
160
+ # SC-Attn
161
+ self.attn1 = SparseCausalAttention(
162
+ query_dim=dim,
163
+ heads=num_attention_heads,
164
+ dim_head=attention_head_dim,
165
+ dropout=dropout,
166
+ bias=attention_bias,
167
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
168
+ upcast_attention=upcast_attention,
169
+ )
170
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
171
+
172
+ # Cross-Attn
173
+ if cross_attention_dim is not None:
174
+ self.attn2 = CrossAttention(
175
+ query_dim=dim,
176
+ cross_attention_dim=cross_attention_dim,
177
+ heads=num_attention_heads,
178
+ dim_head=attention_head_dim,
179
+ dropout=dropout,
180
+ bias=attention_bias,
181
+ upcast_attention=upcast_attention,
182
+ )
183
+ else:
184
+ self.attn2 = None
185
+
186
+ if cross_attention_dim is not None:
187
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
188
+ else:
189
+ self.norm2 = None
190
+
191
+ # Feed-forward
192
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
193
+ self.norm3 = nn.LayerNorm(dim)
194
+
195
+ # Temp-Attn
196
+ if self.use_attn_temp:
197
+ self.attn_temp = CrossAttention(
198
+ query_dim=dim,
199
+ heads=num_attention_heads,
200
+ dim_head=attention_head_dim,
201
+ dropout=dropout,
202
+ bias=attention_bias,
203
+ upcast_attention=upcast_attention,
204
+ )
205
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
206
+ self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
207
+
208
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
209
+ if not is_xformers_available():
210
+ print("Here is how to install it")
211
+ raise ModuleNotFoundError(
212
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
213
+ " xformers",
214
+ name="xformers",
215
+ )
216
+ elif not torch.cuda.is_available():
217
+ raise ValueError(
218
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
219
+ " available for GPU "
220
+ )
221
+ else:
222
+ try:
223
+ # Make sure we can run the memory efficient attention
224
+ _ = xformers.ops.memory_efficient_attention(
225
+ torch.randn((1, 2, 40), device="cuda"),
226
+ torch.randn((1, 2, 40), device="cuda"),
227
+ torch.randn((1, 2, 40), device="cuda"),
228
+ )
229
+ except Exception as e:
230
+ raise e
231
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
232
+ if self.attn2 is not None:
233
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
234
+ #self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
235
+
236
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
237
+ # SparseCausal-Attention
238
+ norm_hidden_states = (
239
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
240
+ )
241
+
242
+ if self.only_cross_attention:
243
+ hidden_states = (
244
+ self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
245
+ )
246
+ else:
247
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
248
+
249
+ if self.attn2 is not None:
250
+ # Cross-Attention
251
+ norm_hidden_states = (
252
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
253
+ )
254
+ hidden_states = (
255
+ self.attn2(
256
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
257
+ )
258
+ + hidden_states
259
+ )
260
+
261
+ # Feed-forward
262
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
263
+
264
+ # Temporal-Attention
265
+ if self.use_attn_temp:
266
+ d = hidden_states.shape[1]
267
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
268
+ norm_hidden_states = (
269
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
270
+ )
271
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
272
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
273
+
274
+ return hidden_states
275
+
276
+
277
+ class SparseCausalAttention(CrossAttention):
278
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, use_full_attn=True):
279
+ batch_size, sequence_length, _ = hidden_states.shape
280
+
281
+ encoder_hidden_states = encoder_hidden_states
282
+
283
+ if self.group_norm is not None:
284
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
285
+
286
+ query = self.to_q(hidden_states)
287
+ # query = rearrange(query, "(b f) d c -> b (f d) c", f=video_length)
288
+ dim = query.shape[-1]
289
+ query = self.reshape_heads_to_batch_dim(query)
290
+
291
+ if self.added_kv_proj_dim is not None:
292
+ raise NotImplementedError
293
+
294
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
295
+ key = self.to_k(encoder_hidden_states)
296
+ value = self.to_v(encoder_hidden_states)
297
+
298
+ former_frame_index = torch.arange(video_length) - 1
299
+ former_frame_index[0] = 0
300
+
301
+ key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
302
+ if not use_full_attn:
303
+ key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2)
304
+ else:
305
+ # key = torch.cat([key[:, [0] * video_length], key[:, [1] * video_length], key[:, [2] * video_length], key[:, [3] * video_length]], dim=2)
306
+ key_video_length = [key[:, [i] * video_length] for i in range(video_length)]
307
+ key = torch.cat(key_video_length, dim=2)
308
+ key = rearrange(key, "b f d c -> (b f) d c")
309
+
310
+ value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
311
+ if not use_full_attn:
312
+ value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2)
313
+ else:
314
+ # value = torch.cat([value[:, [0] * video_length], value[:, [1] * video_length], value[:, [2] * video_length], value[:, [3] * video_length]], dim=2)
315
+ value_video_length = [value[:, [i] * video_length] for i in range(video_length)]
316
+ value = torch.cat(value_video_length, dim=2)
317
+ value = rearrange(value, "b f d c -> (b f) d c")
318
+
319
+ key = self.reshape_heads_to_batch_dim(key)
320
+ value = self.reshape_heads_to_batch_dim(value)
321
+
322
+ if attention_mask is not None:
323
+ if attention_mask.shape[-1] != query.shape[1]:
324
+ target_length = query.shape[1]
325
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
326
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
327
+
328
+ # attention, what we cannot get enough of
329
+ if self._use_memory_efficient_attention_xformers:
330
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
331
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
332
+ hidden_states = hidden_states.to(query.dtype)
333
+ else:
334
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
335
+ hidden_states = self._attention(query, key, value, attention_mask)
336
+ else:
337
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
338
+
339
+ # linear proj
340
+ hidden_states = self.to_out[0](hidden_states)
341
+
342
+ # dropout
343
+ hidden_states = self.to_out[1](hidden_states)
344
+ return hidden_states
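
For reference, Transformer3DModel.forward expects a 5-D latent of shape (batch, channels, views, height, width) and broadcasts the text condition to every view. A minimal shape sketch, assuming an older diffusers release that still exposes diffusers.models.attention.CrossAttention (which this module imports); all sizes are chosen purely for illustration:

import torch
from canonicalize.models.attention import Transformer3DModel

model = Transformer3DModel(num_attention_heads=8, attention_head_dim=40,
                           in_channels=320, cross_attention_dim=768)
hidden = torch.randn(1, 320, 4, 32, 32)   # (b, c, f = views, h, w)
text = torch.randn(1, 77, 768)            # (b, tokens, cross_attention_dim)

out = model(hidden, encoder_hidden_states=text).sample
print(out.shape)                          # expected: torch.Size([1, 320, 4, 32, 32])
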
canonicalize/models/imageproj.py ADDED
@@ -0,0 +1,118 @@
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ # FFN
8
+ def FeedForward(dim, mult=4):
9
+ inner_dim = int(dim * mult)
10
+ return nn.Sequential(
11
+ nn.LayerNorm(dim),
12
+ nn.Linear(dim, inner_dim, bias=False),
13
+ nn.GELU(),
14
+ nn.Linear(inner_dim, dim, bias=False),
15
+ )
16
+
17
+ def reshape_tensor(x, heads):
18
+ bs, length, width = x.shape
19
+ #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
20
+ x = x.view(bs, length, heads, -1)
21
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
22
+ x = x.transpose(1, 2)
23
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
24
+ x = x.reshape(bs, heads, length, -1)
25
+ return x
26
+
27
+
28
+ class PerceiverAttention(nn.Module):
29
+ def __init__(self, *, dim, dim_head=64, heads=8):
30
+ super().__init__()
31
+ self.scale = dim_head**-0.5
32
+ self.dim_head = dim_head
33
+ self.heads = heads
34
+ inner_dim = dim_head * heads
35
+
36
+ self.norm1 = nn.LayerNorm(dim)
37
+ self.norm2 = nn.LayerNorm(dim)
38
+
39
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
40
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
41
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
42
+
43
+
44
+ def forward(self, x, latents):
45
+ """
46
+ Args:
47
+ x (torch.Tensor): image features
48
+ shape (b, n1, D)
49
+ latent (torch.Tensor): latent features
50
+ shape (b, n2, D)
51
+ """
52
+ x = self.norm1(x)
53
+ latents = self.norm2(latents)
54
+
55
+ b, l, _ = latents.shape
56
+
57
+ q = self.to_q(latents)
58
+ kv_input = torch.cat((x, latents), dim=-2)
59
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
60
+
61
+ q = reshape_tensor(q, self.heads)
62
+ k = reshape_tensor(k, self.heads)
63
+ v = reshape_tensor(v, self.heads)
64
+
65
+ # attention
66
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
67
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
68
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
69
+ out = weight @ v
70
+
71
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
72
+
73
+ return self.to_out(out)
74
+
75
+ class Resampler(nn.Module):
76
+ def __init__(
77
+ self,
78
+ dim=1024,
79
+ depth=8,
80
+ dim_head=64,
81
+ heads=16,
82
+ num_queries=8,
83
+ embedding_dim=768,
84
+ output_dim=1024,
85
+ ff_mult=4,
86
+ ):
87
+ super().__init__()
88
+
89
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
90
+
91
+ self.proj_in = nn.Linear(embedding_dim, dim)
92
+
93
+ self.proj_out = nn.Linear(dim, output_dim)
94
+ self.norm_out = nn.LayerNorm(output_dim)
95
+
96
+ self.layers = nn.ModuleList([])
97
+ for _ in range(depth):
98
+ self.layers.append(
99
+ nn.ModuleList(
100
+ [
101
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
102
+ FeedForward(dim=dim, mult=ff_mult),
103
+ ]
104
+ )
105
+ )
106
+
107
+ def forward(self, x):
108
+
109
+ latents = self.latents.repeat(x.size(0), 1, 1)
110
+
111
+ x = self.proj_in(x)
112
+
113
+ for attn, ff in self.layers:
114
+ latents = attn(x, latents) + latents
115
+ latents = ff(latents) + latents
116
+
117
+ latents = self.proj_out(latents)
118
+ return self.norm_out(latents)
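
The Resampler compresses a variable-length sequence of image embeddings into a fixed number of query tokens. A minimal shape sketch with illustrative sizes (e.g. CLIP patch features of dimension 768):

import torch
from canonicalize.models.imageproj import Resampler

resampler = Resampler(dim=1024, depth=4, dim_head=64, heads=16,
                      num_queries=16, embedding_dim=768, output_dim=1024)
image_feats = torch.randn(2, 257, 768)   # (b, n_patches, embedding_dim)
tokens = resampler(image_feats)
print(tokens.shape)                      # expected: torch.Size([2, 16, 1024])
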
canonicalize/models/refunet.py ADDED
@@ -0,0 +1,127 @@
1
+ import torch
2
+ from einops import rearrange
3
+ from typing import Any, Dict, Optional
4
+ from diffusers.utils.import_utils import is_xformers_available
5
+ from canonicalize.models.transformer_mv2d import XFormersMVAttnProcessor, MVAttnProcessor
6
+
7
+
8
+ class ReferenceOnlyAttnProc(torch.nn.Module):
9
+ def __init__(
10
+ self,
11
+ chained_proc,
12
+ enabled=False,
13
+ name=None
14
+ ) -> None:
15
+ super().__init__()
16
+ self.enabled = enabled
17
+ self.chained_proc = chained_proc
18
+ self.name = name
19
+
20
+ def __call__(
21
+ self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None,
22
+ mode="w", ref_dict: dict = None, is_cfg_guidance = False,num_views=4,
23
+ multiview_attention=True,
24
+ cross_domain_attention=False,
25
+ ) -> Any:
26
+ if encoder_hidden_states is None:
27
+ encoder_hidden_states = hidden_states
28
+
29
+ if self.enabled:
30
+ if mode == 'w':
31
+ ref_dict[self.name] = encoder_hidden_states
32
+ res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, num_views=1,
33
+ multiview_attention=False,
34
+ cross_domain_attention=False,)
35
+ elif mode == 'r':
36
+ encoder_hidden_states = rearrange(encoder_hidden_states, '(b t) d c-> b (t d) c', t=num_views)
37
+ if self.name in ref_dict:
38
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1).unsqueeze(1).repeat(1,num_views,1,1).flatten(0,1)
39
+ res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, num_views=num_views,
40
+ multiview_attention=False,
41
+ cross_domain_attention=False,)
42
+ elif mode == 'm':
43
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1)
44
+ elif mode == 'n':
45
+ encoder_hidden_states = rearrange(encoder_hidden_states, '(b t) d c-> b (t d) c', t=num_views)
46
+ encoder_hidden_states = torch.cat([encoder_hidden_states], dim=1).unsqueeze(1).repeat(1,num_views,1,1).flatten(0,1)
47
+ res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, num_views=num_views,
48
+ multiview_attention=False,
49
+ cross_domain_attention=False,)
50
+ else:
51
+ assert False, mode
52
+ else:
53
+ res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
54
+ return res
55
+
56
+ class RefOnlyNoisedUNet(torch.nn.Module):
57
+ def __init__(self, unet, train_sched, val_sched) -> None:
58
+ super().__init__()
59
+ self.unet = unet
60
+ self.train_sched = train_sched
61
+ self.val_sched = val_sched
62
+
63
+ unet_lora_attn_procs = dict()
64
+ for name, _ in unet.attn_processors.items():
65
+ if is_xformers_available():
66
+ default_attn_proc = XFormersMVAttnProcessor()
67
+ else:
68
+ default_attn_proc = MVAttnProcessor()
69
+ unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(
70
+ default_attn_proc, enabled=name.endswith("attn1.processor"), name=name)
71
+
72
+ self.unet.set_attn_processor(unet_lora_attn_procs)
73
+
74
+ def __getattr__(self, name: str):
75
+ try:
76
+ return super().__getattr__(name)
77
+ except AttributeError:
78
+ return getattr(self.unet, name)
79
+
80
+ def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict, is_cfg_guidance, **kwargs):
81
+ if is_cfg_guidance:
82
+ encoder_hidden_states = encoder_hidden_states[1:]
83
+ class_labels = class_labels[1:]
84
+ self.unet(
85
+ noisy_cond_lat, timestep,
86
+ encoder_hidden_states=encoder_hidden_states,
87
+ class_labels=class_labels,
88
+ cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
89
+ **kwargs
90
+ )
91
+
92
+ def forward(
93
+ self, sample, timestep, encoder_hidden_states, class_labels=None,
94
+ *args, cross_attention_kwargs,
95
+ down_block_res_samples=None, mid_block_res_sample=None,
96
+ **kwargs
97
+ ):
98
+ cond_lat = cross_attention_kwargs['cond_lat']
99
+ is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False)
100
+ noise = torch.randn_like(cond_lat)
101
+ if self.training:
102
+ noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
103
+ noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)
104
+ else:
105
+ noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1))
106
+ noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1))
107
+ ref_dict = {}
108
+ self.forward_cond(
109
+ noisy_cond_lat, timestep,
110
+ encoder_hidden_states, class_labels,
111
+ ref_dict, is_cfg_guidance, **kwargs
112
+ )
113
+ weight_dtype = self.unet.dtype
114
+ return self.unet(
115
+ sample, timestep,
116
+ encoder_hidden_states, *args,
117
+ class_labels=class_labels,
118
+ cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance),
119
+ down_block_additional_residuals=[
120
+ sample.to(dtype=weight_dtype) for sample in down_block_res_samples
121
+ ] if down_block_res_samples is not None else None,
122
+ mid_block_additional_residual=(
123
+ mid_block_res_sample.to(dtype=weight_dtype)
124
+ if mid_block_res_sample is not None else None
125
+ ),
126
+ **kwargs
127
+ )
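
Note that RefOnlyNoisedUNet enables the reference-only path only on self-attention processors, i.e. those whose names end in "attn1.processor"; cross-attention processors keep the plain multi-view processor. A trivial sketch of that routing rule (the processor names are hypothetical but follow the usual diffusers naming):

names = [
    "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor",
    "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor",
]
for name in names:
    mode = "reference-only" if name.endswith("attn1.processor") else "plain multi-view"
    print(f"{name} -> {mode}")
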
canonicalize/models/resnet.py ADDED
@@ -0,0 +1,209 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from einops import rearrange
8
+
9
+
10
+ class InflatedConv3d(nn.Conv2d):
11
+ def forward(self, x):
12
+ video_length = x.shape[2]
13
+
14
+ x = rearrange(x, "b c f h w -> (b f) c h w")
15
+ x = super().forward(x)
16
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
17
+
18
+ return x
19
+
20
+
21
+ class Upsample3D(nn.Module):
22
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
23
+ super().__init__()
24
+ self.channels = channels
25
+ self.out_channels = out_channels or channels
26
+ self.use_conv = use_conv
27
+ self.use_conv_transpose = use_conv_transpose
28
+ self.name = name
29
+
30
+ conv = None
31
+ if use_conv_transpose:
32
+ raise NotImplementedError
33
+ elif use_conv:
34
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
35
+
36
+ if name == "conv":
37
+ self.conv = conv
38
+ else:
39
+ self.Conv2d_0 = conv
40
+
41
+ def forward(self, hidden_states, output_size=None):
42
+ assert hidden_states.shape[1] == self.channels
43
+
44
+ if self.use_conv_transpose:
45
+ raise NotImplementedError
46
+
47
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
48
+ dtype = hidden_states.dtype
49
+ if dtype == torch.bfloat16:
50
+ hidden_states = hidden_states.to(torch.float32)
51
+
52
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
53
+ if hidden_states.shape[0] >= 64:
54
+ hidden_states = hidden_states.contiguous()
55
+
56
+ # if `output_size` is passed we force the interpolation output
57
+ # size and do not make use of `scale_factor=2`
58
+ if output_size is None:
59
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
60
+ else:
61
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
62
+
63
+ # If the input is bfloat16, we cast back to bfloat16
64
+ if dtype == torch.bfloat16:
65
+ hidden_states = hidden_states.to(dtype)
66
+
67
+ if self.use_conv:
68
+ if self.name == "conv":
69
+ hidden_states = self.conv(hidden_states)
70
+ else:
71
+ hidden_states = self.Conv2d_0(hidden_states)
72
+
73
+ return hidden_states
74
+
75
+
76
+ class Downsample3D(nn.Module):
77
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
78
+ super().__init__()
79
+ self.channels = channels
80
+ self.out_channels = out_channels or channels
81
+ self.use_conv = use_conv
82
+ self.padding = padding
83
+ stride = 2
84
+ self.name = name
85
+
86
+ if use_conv:
87
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
88
+ else:
89
+ raise NotImplementedError
90
+
91
+ if name == "conv":
92
+ self.Conv2d_0 = conv
93
+ self.conv = conv
94
+ elif name == "Conv2d_0":
95
+ self.conv = conv
96
+ else:
97
+ self.conv = conv
98
+
99
+ def forward(self, hidden_states):
100
+ assert hidden_states.shape[1] == self.channels
101
+ if self.use_conv and self.padding == 0:
102
+ raise NotImplementedError
103
+
104
+ assert hidden_states.shape[1] == self.channels
105
+ hidden_states = self.conv(hidden_states)
106
+
107
+ return hidden_states
108
+
109
+
110
+ class ResnetBlock3D(nn.Module):
111
+ def __init__(
112
+ self,
113
+ *,
114
+ in_channels,
115
+ out_channels=None,
116
+ conv_shortcut=False,
117
+ dropout=0.0,
118
+ temb_channels=512,
119
+ groups=32,
120
+ groups_out=None,
121
+ pre_norm=True,
122
+ eps=1e-6,
123
+ non_linearity="swish",
124
+ time_embedding_norm="default",
125
+ output_scale_factor=1.0,
126
+ use_in_shortcut=None,
127
+ ):
128
+ super().__init__()
129
+ self.pre_norm = pre_norm
130
+ self.pre_norm = True
131
+ self.in_channels = in_channels
132
+ out_channels = in_channels if out_channels is None else out_channels
133
+ self.out_channels = out_channels
134
+ self.use_conv_shortcut = conv_shortcut
135
+ self.time_embedding_norm = time_embedding_norm
136
+ self.output_scale_factor = output_scale_factor
137
+
138
+ if groups_out is None:
139
+ groups_out = groups
140
+
141
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
142
+
143
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
144
+
145
+ if temb_channels is not None:
146
+ if self.time_embedding_norm == "default":
147
+ time_emb_proj_out_channels = out_channels
148
+ elif self.time_embedding_norm == "scale_shift":
149
+ time_emb_proj_out_channels = out_channels * 2
150
+ else:
151
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
152
+
153
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
154
+ else:
155
+ self.time_emb_proj = None
156
+
157
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
158
+ self.dropout = torch.nn.Dropout(dropout)
159
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
160
+
161
+ if non_linearity == "swish":
162
+ self.nonlinearity = lambda x: F.silu(x)
163
+ elif non_linearity == "mish":
164
+ self.nonlinearity = Mish()
165
+ elif non_linearity == "silu":
166
+ self.nonlinearity = nn.SiLU()
167
+
168
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
169
+
170
+ self.conv_shortcut = None
171
+ if self.use_in_shortcut:
172
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
173
+
174
+ def forward(self, input_tensor, temb):
175
+ hidden_states = input_tensor
176
+
177
+ hidden_states = self.norm1(hidden_states)
178
+ hidden_states = self.nonlinearity(hidden_states)
179
+
180
+ hidden_states = self.conv1(hidden_states)
181
+
182
+ if temb is not None:
183
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, :, None, None].permute(0,2,1,3,4)
184
+
185
+ if temb is not None and self.time_embedding_norm == "default":
186
+ hidden_states = hidden_states + temb
187
+
188
+ hidden_states = self.norm2(hidden_states)
189
+
190
+ if temb is not None and self.time_embedding_norm == "scale_shift":
191
+ scale, shift = torch.chunk(temb, 2, dim=1)
192
+ hidden_states = hidden_states * (1 + scale) + shift
193
+
194
+ hidden_states = self.nonlinearity(hidden_states)
195
+
196
+ hidden_states = self.dropout(hidden_states)
197
+ hidden_states = self.conv2(hidden_states)
198
+
199
+ if self.conv_shortcut is not None:
200
+ input_tensor = self.conv_shortcut(input_tensor)
201
+
202
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
203
+
204
+ return output_tensor
205
+
206
+
207
+ class Mish(torch.nn.Module):
208
+ def forward(self, hidden_states):
209
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
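
These blocks treat the view/frame axis as an extra batch dimension: InflatedConv3d folds (b, c, f, h, w) into (b*f, c, h, w) before the 2-D convolution, and ResnetBlock3D, as written here, expects a per-view time embedding of shape (b, f, temb_channels). A minimal shape sketch with illustrative sizes:

import torch
from canonicalize.models.resnet import InflatedConv3d, ResnetBlock3D

conv = InflatedConv3d(320, 320, kernel_size=3, padding=1)
block = ResnetBlock3D(in_channels=320, out_channels=320, temb_channels=1280)

x = torch.randn(1, 320, 4, 32, 32)   # (b, c, f = views, h, w)
temb = torch.randn(1, 4, 1280)       # per-view timestep embedding

print(conv(x).shape)         # expected: torch.Size([1, 320, 4, 32, 32])
print(block(x, temb).shape)  # expected: torch.Size([1, 320, 4, 32, 32])
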
canonicalize/models/transformer_mv2d.py ADDED
@@ -0,0 +1,976 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
23
+ from diffusers.utils import BaseOutput, deprecate
24
+ try:
25
+ from diffusers.utils import maybe_allow_in_graph
26
+ except:
27
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
28
+ from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
29
+ from diffusers.models.embeddings import PatchEmbed
30
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
31
+ from diffusers.models.modeling_utils import ModelMixin
32
+ from diffusers.utils.import_utils import is_xformers_available
33
+
34
+ from einops import rearrange
35
+ import pdb
36
+ import random
37
+
38
+
39
+ if is_xformers_available():
40
+ import xformers
41
+ import xformers.ops
42
+ else:
43
+ xformers = None
44
+
45
+
46
+ @dataclass
47
+ class TransformerMV2DModelOutput(BaseOutput):
48
+ """
49
+ The output of [`Transformer2DModel`].
50
+
51
+ Args:
52
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
53
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
54
+ distributions for the unnoised latent pixels.
55
+ """
56
+
57
+ sample: torch.FloatTensor
58
+
59
+
60
+ class TransformerMV2DModel(ModelMixin, ConfigMixin):
61
+ """
62
+ A 2D Transformer model for image-like data.
63
+
64
+ Parameters:
65
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
66
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
67
+ in_channels (`int`, *optional*):
68
+ The number of channels in the input and output (specify if the input is **continuous**).
69
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
70
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
71
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
72
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
73
+ This is fixed during training since it is used to learn a number of position embeddings.
74
+ num_vector_embeds (`int`, *optional*):
75
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
76
+ Includes the class for the masked latent pixel.
77
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
78
+ num_embeds_ada_norm ( `int`, *optional*):
79
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
80
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
81
+ added to the hidden states.
82
+
83
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
84
+ attention_bias (`bool`, *optional*):
85
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
86
+ """
87
+
88
+ @register_to_config
89
+ def __init__(
90
+ self,
91
+ num_attention_heads: int = 16,
92
+ attention_head_dim: int = 88,
93
+ in_channels: Optional[int] = None,
94
+ out_channels: Optional[int] = None,
95
+ num_layers: int = 1,
96
+ dropout: float = 0.0,
97
+ norm_num_groups: int = 32,
98
+ cross_attention_dim: Optional[int] = None,
99
+ attention_bias: bool = False,
100
+ sample_size: Optional[int] = None,
101
+ num_vector_embeds: Optional[int] = None,
102
+ patch_size: Optional[int] = None,
103
+ activation_fn: str = "geglu",
104
+ num_embeds_ada_norm: Optional[int] = None,
105
+ use_linear_projection: bool = False,
106
+ only_cross_attention: bool = False,
107
+ upcast_attention: bool = False,
108
+ norm_type: str = "layer_norm",
109
+ norm_elementwise_affine: bool = True,
110
+ num_views: int = 1,
111
+ joint_attention: bool=False,
112
+ joint_attention_twice: bool=False,
113
+ multiview_attention: bool=True,
114
+ cross_domain_attention: bool=False
115
+ ):
116
+ super().__init__()
117
+ self.use_linear_projection = use_linear_projection
118
+ self.num_attention_heads = num_attention_heads
119
+ self.attention_head_dim = attention_head_dim
120
+ inner_dim = num_attention_heads * attention_head_dim
121
+
122
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
123
+ # Define whether input is continuous or discrete depending on configuration
124
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
125
+ self.is_input_vectorized = num_vector_embeds is not None
126
+ self.is_input_patches = in_channels is not None and patch_size is not None
127
+
128
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
129
+ deprecation_message = (
130
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
131
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
132
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
133
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
134
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
135
+ )
136
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
137
+ norm_type = "ada_norm"
138
+
139
+ if self.is_input_continuous and self.is_input_vectorized:
140
+ raise ValueError(
141
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
142
+ " sure that either `in_channels` or `num_vector_embeds` is None."
143
+ )
144
+ elif self.is_input_vectorized and self.is_input_patches:
145
+ raise ValueError(
146
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
147
+ " sure that either `num_vector_embeds` or `num_patches` is None."
148
+ )
149
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
150
+ raise ValueError(
151
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
152
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
153
+ )
154
+
155
+ # 2. Define input layers
156
+ if self.is_input_continuous:
157
+ self.in_channels = in_channels
158
+
159
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
160
+ if use_linear_projection:
161
+ self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
162
+ else:
163
+ self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
164
+ elif self.is_input_vectorized:
165
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
166
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
167
+
168
+ self.height = sample_size
169
+ self.width = sample_size
170
+ self.num_vector_embeds = num_vector_embeds
171
+ self.num_latent_pixels = self.height * self.width
172
+
173
+ self.latent_image_embedding = ImagePositionalEmbeddings(
174
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
175
+ )
176
+ elif self.is_input_patches:
177
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
178
+
179
+ self.height = sample_size
180
+ self.width = sample_size
181
+
182
+ self.patch_size = patch_size
183
+ self.pos_embed = PatchEmbed(
184
+ height=sample_size,
185
+ width=sample_size,
186
+ patch_size=patch_size,
187
+ in_channels=in_channels,
188
+ embed_dim=inner_dim,
189
+ )
190
+
191
+ # 3. Define transformers blocks
192
+ self.transformer_blocks = nn.ModuleList(
193
+ [
194
+ BasicMVTransformerBlock(
195
+ inner_dim,
196
+ num_attention_heads,
197
+ attention_head_dim,
198
+ dropout=dropout,
199
+ cross_attention_dim=cross_attention_dim,
200
+ activation_fn=activation_fn,
201
+ num_embeds_ada_norm=num_embeds_ada_norm,
202
+ attention_bias=attention_bias,
203
+ only_cross_attention=only_cross_attention,
204
+ upcast_attention=upcast_attention,
205
+ norm_type=norm_type,
206
+ norm_elementwise_affine=norm_elementwise_affine,
207
+ num_views=num_views,
208
+ joint_attention=joint_attention,
209
+ joint_attention_twice=joint_attention_twice,
210
+ multiview_attention=multiview_attention,
211
+ cross_domain_attention=cross_domain_attention
212
+ )
213
+ for d in range(num_layers)
214
+ ]
215
+ )
216
+
217
+ # 4. Define output layers
218
+ self.out_channels = in_channels if out_channels is None else out_channels
219
+ if self.is_input_continuous:
220
+ # TODO: should use out_channels for continuous projections
221
+ if use_linear_projection:
222
+ self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
223
+ else:
224
+ self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
225
+ elif self.is_input_vectorized:
226
+ self.norm_out = nn.LayerNorm(inner_dim)
227
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
228
+ elif self.is_input_patches:
229
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
230
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
231
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
232
+
233
+ def forward(
234
+ self,
235
+ hidden_states: torch.Tensor,
236
+ encoder_hidden_states: Optional[torch.Tensor] = None,
237
+ timestep: Optional[torch.LongTensor] = None,
238
+ class_labels: Optional[torch.LongTensor] = None,
239
+ cross_attention_kwargs: Dict[str, Any] = None,
240
+ attention_mask: Optional[torch.Tensor] = None,
241
+ encoder_attention_mask: Optional[torch.Tensor] = None,
242
+ return_dict: bool = True,
243
+ ):
244
+ """
245
+ The [`Transformer2DModel`] forward method.
246
+
247
+ Args:
248
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
249
+ Input `hidden_states`.
250
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
251
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
252
+ self-attention.
253
+ timestep ( `torch.LongTensor`, *optional*):
254
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
255
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
256
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
257
+ `AdaLayerZeroNorm`.
258
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
259
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
260
+
261
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
262
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
263
+
264
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
265
+ above. This bias will be added to the cross-attention scores.
266
+ return_dict (`bool`, *optional*, defaults to `True`):
267
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
268
+ tuple.
269
+
270
+ Returns:
271
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
272
+ `tuple` where the first element is the sample tensor.
273
+ """
274
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
275
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
276
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
277
+ # expects mask of shape:
278
+ # [batch, key_tokens]
279
+ # adds singleton query_tokens dimension:
280
+ # [batch, 1, key_tokens]
281
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
282
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
283
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
284
+ if attention_mask is not None and attention_mask.ndim == 2:
285
+ # assume that mask is expressed as:
286
+ # (1 = keep, 0 = discard)
287
+ # convert mask into a bias that can be added to attention scores:
288
+ # (keep = +0, discard = -10000.0)
289
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
290
+ attention_mask = attention_mask.unsqueeze(1)
291
+
292
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
293
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
294
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
295
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
296
+
297
+ # 1. Input
298
+ if self.is_input_continuous:
299
+ batch, _, height, width = hidden_states.shape
300
+ residual = hidden_states
301
+
302
+ hidden_states = self.norm(hidden_states)
303
+ if not self.use_linear_projection:
304
+ hidden_states = self.proj_in(hidden_states)
305
+ inner_dim = hidden_states.shape[1]
306
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
307
+ else:
308
+ inner_dim = hidden_states.shape[1]
309
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
310
+ hidden_states = self.proj_in(hidden_states)
311
+ elif self.is_input_vectorized:
312
+ hidden_states = self.latent_image_embedding(hidden_states)
313
+ elif self.is_input_patches:
314
+ hidden_states = self.pos_embed(hidden_states)
315
+
316
+ # 2. Blocks
317
+ for block in self.transformer_blocks:
318
+ hidden_states = block(
319
+ hidden_states,
320
+ attention_mask=attention_mask,
321
+ encoder_hidden_states=encoder_hidden_states,
322
+ encoder_attention_mask=encoder_attention_mask,
323
+ timestep=timestep,
324
+ cross_attention_kwargs=cross_attention_kwargs,
325
+ class_labels=class_labels,
326
+ )
327
+
328
+ # 3. Output
329
+ if self.is_input_continuous:
330
+ if not self.use_linear_projection:
331
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
332
+ hidden_states = self.proj_out(hidden_states)
333
+ else:
334
+ hidden_states = self.proj_out(hidden_states)
335
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
336
+
337
+ output = hidden_states + residual
338
+ elif self.is_input_vectorized:
339
+ hidden_states = self.norm_out(hidden_states)
340
+ logits = self.out(hidden_states)
341
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
342
+ logits = logits.permute(0, 2, 1)
343
+
344
+ # log(p(x_0))
345
+ output = F.log_softmax(logits.double(), dim=1).float()
346
+ elif self.is_input_patches:
347
+ # TODO: cleanup!
348
+ conditioning = self.transformer_blocks[0].norm1.emb(
349
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
350
+ )
351
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
352
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
353
+ hidden_states = self.proj_out_2(hidden_states)
354
+
355
+ # unpatchify
356
+ height = width = int(hidden_states.shape[1] ** 0.5)
357
+ hidden_states = hidden_states.reshape(
358
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
359
+ )
360
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
361
+ output = hidden_states.reshape(
362
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
363
+ )
364
+
365
+ if not return_dict:
366
+ return (output,)
367
+
368
+ return TransformerMV2DModelOutput(sample=output)
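The mask handling at the top of this forward pass converts a boolean keep/discard mask into an additive bias before it ever reaches an attention score. A minimal, self-contained sketch of that conversion (shapes and values are chosen for illustration only):

    import torch

    # Illustrative only: mirror the 2-D mask -> additive-bias conversion performed above.
    mask = torch.tensor([[1, 1, 0, 0]])    # (batch, key_tokens); 1 = keep, 0 = discard
    bias = (1 - mask.float()) * -10000.0   # keep -> 0.0, discard -> -10000.0
    bias = bias.unsqueeze(1)               # (batch, 1, key_tokens), broadcastable over query tokens
    # kept positions contribute 0 and discarded positions -10000 to the attention scores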
369
+
370
+
371
+ @maybe_allow_in_graph
372
+ class BasicMVTransformerBlock(nn.Module):
373
+ r"""
374
+ A basic Transformer block.
375
+
376
+ Parameters:
377
+ dim (`int`): The number of channels in the input and output.
378
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
379
+ attention_head_dim (`int`): The number of channels in each head.
380
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
381
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
382
+ only_cross_attention (`bool`, *optional*):
383
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
384
+ double_self_attention (`bool`, *optional*):
385
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
386
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
387
+ num_embeds_ada_norm (:
388
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
389
+ attention_bias (:
390
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
391
+ """
392
+
393
+ def __init__(
394
+ self,
395
+ dim: int,
396
+ num_attention_heads: int,
397
+ attention_head_dim: int,
398
+ dropout=0.0,
399
+ cross_attention_dim: Optional[int] = None,
400
+ activation_fn: str = "geglu",
401
+ num_embeds_ada_norm: Optional[int] = None,
402
+ attention_bias: bool = False,
403
+ only_cross_attention: bool = False,
404
+ double_self_attention: bool = False,
405
+ upcast_attention: bool = False,
406
+ norm_elementwise_affine: bool = True,
407
+ norm_type: str = "layer_norm",
408
+ final_dropout: bool = False,
409
+ num_views: int = 1,
410
+ joint_attention: bool = False,
411
+ joint_attention_twice: bool = False,
412
+ multiview_attention: bool = True,
413
+ cross_domain_attention: bool = False
414
+ ):
415
+ super().__init__()
416
+ self.only_cross_attention = only_cross_attention
417
+
418
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
419
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
420
+
421
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
422
+ raise ValueError(
423
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
424
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
425
+ )
426
+
427
+ # Define 3 blocks. Each block has its own normalization layer.
428
+ # 1. Self-Attn
429
+ if self.use_ada_layer_norm:
430
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
431
+ elif self.use_ada_layer_norm_zero:
432
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
433
+ else:
434
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
435
+
436
+ self.multiview_attention = multiview_attention
437
+ self.cross_domain_attention = cross_domain_attention
438
+ self.attn1 = CustomAttention(
439
+ query_dim=dim,
440
+ heads=num_attention_heads,
441
+ dim_head=attention_head_dim,
442
+ dropout=dropout,
443
+ bias=attention_bias,
444
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
445
+ upcast_attention=upcast_attention,
446
+ processor=MVAttnProcessor()
447
+ )
448
+
449
+ # 2. Cross-Attn
450
+ if cross_attention_dim is not None or double_self_attention:
451
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
452
+ # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
453
+ # the second cross attention block.
454
+ self.norm2 = (
455
+ AdaLayerNorm(dim, num_embeds_ada_norm)
456
+ if self.use_ada_layer_norm
457
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
458
+ )
459
+ self.attn2 = Attention(
460
+ query_dim=dim,
461
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
462
+ heads=num_attention_heads,
463
+ dim_head=attention_head_dim,
464
+ dropout=dropout,
465
+ bias=attention_bias,
466
+ upcast_attention=upcast_attention,
467
+ ) # is self-attn if encoder_hidden_states is none
468
+ else:
469
+ self.norm2 = None
470
+ self.attn2 = None
471
+
472
+ # 3. Feed-forward
473
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
474
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
475
+
476
+ # let chunk size default to None
477
+ self._chunk_size = None
478
+ self._chunk_dim = 0
479
+
480
+ self.num_views = num_views
481
+
482
+ self.joint_attention = joint_attention
483
+
484
+ if self.joint_attention:
485
+ # Joint task -Attn
486
+ self.attn_joint = CustomJointAttention(
487
+ query_dim=dim,
488
+ heads=num_attention_heads,
489
+ dim_head=attention_head_dim,
490
+ dropout=dropout,
491
+ bias=attention_bias,
492
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
493
+ upcast_attention=upcast_attention,
494
+ processor=JointAttnProcessor()
495
+ )
496
+ nn.init.zeros_(self.attn_joint.to_out[0].weight.data)
497
+ self.norm_joint = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
498
+
499
+
500
+ self.joint_attention_twice = joint_attention_twice
501
+
502
+ if self.joint_attention_twice:
503
+ print("joint twice")
504
+ # Joint task -Attn
505
+ self.attn_joint_twice = CustomJointAttention(
506
+ query_dim=dim,
507
+ heads=num_attention_heads,
508
+ dim_head=attention_head_dim,
509
+ dropout=dropout,
510
+ bias=attention_bias,
511
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
512
+ upcast_attention=upcast_attention,
513
+ processor=JointAttnProcessor()
514
+ )
515
+ nn.init.zeros_(self.attn_joint_twice.to_out[0].weight.data)
516
+ self.norm_joint_twice = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
517
+
518
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
519
+ # Sets chunk feed-forward
520
+ self._chunk_size = chunk_size
521
+ self._chunk_dim = dim
522
+
523
+ def forward(
524
+ self,
525
+ hidden_states: torch.FloatTensor,
526
+ attention_mask: Optional[torch.FloatTensor] = None,
527
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
528
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
529
+ timestep: Optional[torch.LongTensor] = None,
530
+ cross_attention_kwargs: Dict[str, Any] = None,
531
+ class_labels: Optional[torch.LongTensor] = None,
532
+ ):
533
+ assert attention_mask is None # not supported yet
534
+ # Notice that normalization is always applied before the real computation in the following blocks.
535
+ # 1. Self-Attention
536
+ if self.use_ada_layer_norm:
537
+ norm_hidden_states = self.norm1(hidden_states, timestep)
538
+ elif self.use_ada_layer_norm_zero:
539
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
540
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
541
+ )
542
+ else:
543
+ norm_hidden_states = self.norm1(hidden_states)
544
+
545
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
546
+ attn_output = self.attn1(
547
+ norm_hidden_states,
548
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
549
+ attention_mask=attention_mask,
550
+ num_views=self.num_views,
551
+ multiview_attention=self.multiview_attention,
552
+ cross_domain_attention=self.cross_domain_attention,
553
+ **cross_attention_kwargs,
554
+ )
555
+
556
+
557
+ if self.use_ada_layer_norm_zero:
558
+ attn_output = gate_msa.unsqueeze(1) * attn_output
559
+ hidden_states = attn_output + hidden_states
560
+
561
+ # joint attention twice
562
+ if self.joint_attention_twice:
563
+ norm_hidden_states = (
564
+ self.norm_joint_twice(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_twice(hidden_states)
565
+ )
566
+ hidden_states = self.attn_joint_twice(norm_hidden_states) + hidden_states
567
+
568
+ # 2. Cross-Attention
569
+ if self.attn2 is not None:
570
+ norm_hidden_states = (
571
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
572
+ )
573
+ attn_output = self.attn2(
574
+ norm_hidden_states,
575
+ encoder_hidden_states=encoder_hidden_states,
576
+ attention_mask=encoder_attention_mask,
577
+ **cross_attention_kwargs,
578
+ )
579
+ hidden_states = attn_output + hidden_states
580
+
581
+ # 3. Feed-forward
582
+ norm_hidden_states = self.norm3(hidden_states)
583
+
584
+ if self.use_ada_layer_norm_zero:
585
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
586
+
587
+ if self._chunk_size is not None:
588
+ # "feed_forward_chunk_size" can be used to save memory
589
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
590
+ raise ValueError(
591
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
592
+ )
593
+
594
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
595
+ ff_output = torch.cat(
596
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
597
+ dim=self._chunk_dim,
598
+ )
599
+ else:
600
+ ff_output = self.ff(norm_hidden_states)
601
+
602
+ if self.use_ada_layer_norm_zero:
603
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
604
+
605
+ hidden_states = ff_output + hidden_states
606
+
607
+ if self.joint_attention:
608
+ norm_hidden_states = (
609
+ self.norm_joint(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint(hidden_states)
610
+ )
611
+ hidden_states = self.attn_joint(norm_hidden_states) + hidden_states
612
+
613
+ return hidden_states
614
+
615
+
616
+ class CustomAttention(Attention):
617
+ def set_use_memory_efficient_attention_xformers(
618
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
619
+ ):
620
+ processor = XFormersMVAttnProcessor()
621
+ self.set_processor(processor)
622
+ # print("using xformers attention processor")
623
+
624
+
625
+ class CustomJointAttention(Attention):
626
+ def set_use_memory_efficient_attention_xformers(
627
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
628
+ ):
629
+ processor = XFormersJointAttnProcessor()
630
+ self.set_processor(processor)
631
+ # print("using xformers attention processor")
632
+
633
+ class MVAttnProcessor:
634
+ r"""
635
+ Default processor for performing attention-related computations.
636
+ """
637
+
638
+ def __call__(
639
+ self,
640
+ attn: Attention,
641
+ hidden_states,
642
+ encoder_hidden_states=None,
643
+ attention_mask=None,
644
+ temb=None,
645
+ num_views=1,
646
+ multiview_attention=True
647
+ ):
648
+ residual = hidden_states
649
+
650
+ if attn.spatial_norm is not None:
651
+ hidden_states = attn.spatial_norm(hidden_states, temb)
652
+
653
+ input_ndim = hidden_states.ndim
654
+
655
+ if input_ndim == 4:
656
+ batch_size, channel, height, width = hidden_states.shape
657
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
658
+
659
+ batch_size, sequence_length, _ = (
660
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
661
+ )
662
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
663
+
664
+ if attn.group_norm is not None:
665
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
666
+
667
+ query = attn.to_q(hidden_states)
668
+
669
+ if encoder_hidden_states is None:
670
+ encoder_hidden_states = hidden_states
671
+ elif attn.norm_cross:
672
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
673
+
674
+ key = attn.to_k(encoder_hidden_states)
675
+ value = attn.to_v(encoder_hidden_states)
676
+
677
+ # multi-view self-attention
678
+ if multiview_attention:
679
+ if num_views <= 6:
680
+ # with xformers enabled it is possible to train with 6 views
681
+ # key = rearrange(key, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
682
+ # value = rearrange(value, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
683
+ key = rearrange(key, '(b t) d c-> b (t d) c', t=num_views).unsqueeze(1).repeat(1,num_views,1,1).flatten(0,1)
684
+ value = rearrange(value, '(b t) d c-> b (t d) c', t=num_views).unsqueeze(1).repeat(1,num_views,1,1).flatten(0,1)
685
+
686
+ else:  # apply sparse attention
687
+ raise NotImplementedError("Sparse attention is not implemented yet.")
688
+
689
+
690
+ query = attn.head_to_batch_dim(query).contiguous()
691
+ key = attn.head_to_batch_dim(key).contiguous()
692
+ value = attn.head_to_batch_dim(value).contiguous()
693
+
694
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
695
+ hidden_states = torch.bmm(attention_probs, value)
696
+ hidden_states = attn.batch_to_head_dim(hidden_states)
697
+
698
+ # linear proj
699
+ hidden_states = attn.to_out[0](hidden_states)
700
+ # dropout
701
+ hidden_states = attn.to_out[1](hidden_states)
702
+
703
+ if input_ndim == 4:
704
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
705
+
706
+ if attn.residual_connection:
707
+ hidden_states = hidden_states + residual
708
+
709
+ hidden_states = hidden_states / attn.rescale_output_factor
710
+
711
+ return hidden_states
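The reshaping above is the core of the multi-view attention: keys and values from all `num_views` frames of a scene are concatenated along the token axis and replicated once per query view, so every view attends to every other view. A small einops sketch with made-up sizes:

    import torch
    from einops import rearrange

    b, t, d, c = 2, 3, 4, 8                                          # scenes, views, tokens, channels (made up)
    key = torch.randn(b * t, d, c)                                   # layout used above: (b t) d c
    key_mv = rearrange(key, '(b t) d c -> b (t d) c', t=t)           # gather all views of one scene
    key_mv = key_mv.unsqueeze(1).repeat(1, t, 1, 1).flatten(0, 1)    # one copy per query view
    print(key_mv.shape)                                              # torch.Size([6, 12, 8]) == ((b t), t*d, c)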
712
+
713
+
714
+ class XFormersMVAttnProcessor:
715
+ r"""
716
+ Default processor for performing attention-related computations.
717
+ """
718
+
719
+ def __call__(
720
+ self,
721
+ attn: Attention,
722
+ hidden_states,
723
+ encoder_hidden_states=None,
724
+ attention_mask=None,
725
+ temb=None,
726
+ num_views=1,
727
+ multiview_attention=True,
728
+ cross_domain_attention=False,
729
+ ):
730
+ residual = hidden_states
731
+
732
+ if attn.spatial_norm is not None:
733
+ hidden_states = attn.spatial_norm(hidden_states, temb)
734
+
735
+ input_ndim = hidden_states.ndim
736
+
737
+ if input_ndim == 4:
738
+ batch_size, channel, height, width = hidden_states.shape
739
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
740
+
741
+ batch_size, sequence_length, _ = (
742
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
743
+ )
744
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
745
+
746
+ # from yuancheng; here attention_mask is None
747
+ if attention_mask is not None:
748
+ # expand our mask's singleton query_tokens dimension:
749
+ # [batch*heads, 1, key_tokens] ->
750
+ # [batch*heads, query_tokens, key_tokens]
751
+ # so that it can be added as a bias onto the attention scores that xformers computes:
752
+ # [batch*heads, query_tokens, key_tokens]
753
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
754
+ _, query_tokens, _ = hidden_states.shape
755
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
756
+
757
+ if attn.group_norm is not None:
758
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
759
+
760
+ query = attn.to_q(hidden_states)
761
+
762
+ if encoder_hidden_states is None:
763
+ encoder_hidden_states = hidden_states
764
+ elif attn.norm_cross:
765
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
766
+
767
+ key_raw = attn.to_k(encoder_hidden_states)
768
+ value_raw = attn.to_v(encoder_hidden_states)
769
+
770
+ # multi-view self-attention
771
+ if multiview_attention:
772
+ key = rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
773
+ value = rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
774
+
775
+ if cross_domain_attention:
776
+ # memory efficient, cross domain attention
777
+ key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2) # keys shape (b t) d c
778
+ value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2)
779
+ key_cross = torch.concat([key_1, key_0], dim=0)
780
+ value_cross = torch.concat([value_1, value_0], dim=0) # shape (b t) d c
781
+ key = torch.cat([key, key_cross], dim=1)
782
+ value = torch.cat([value, value_cross], dim=1) # shape (b t) (t+1 d) c
783
+ else:
784
+ # print("don't use multiview attention.")
785
+ key = key_raw
786
+ value = value_raw
787
+
788
+ query = attn.head_to_batch_dim(query)
789
+ key = attn.head_to_batch_dim(key)
790
+ value = attn.head_to_batch_dim(value)
791
+
792
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
793
+ hidden_states = attn.batch_to_head_dim(hidden_states)
794
+
795
+ # linear proj
796
+ hidden_states = attn.to_out[0](hidden_states)
797
+ # dropout
798
+ hidden_states = attn.to_out[1](hidden_states)
799
+
800
+ if input_ndim == 4:
801
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
802
+
803
+ if attn.residual_connection:
804
+ hidden_states = hidden_states + residual
805
+
806
+ hidden_states = hidden_states / attn.rescale_output_factor
807
+
808
+ return hidden_states
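In the `cross_domain_attention` branch above, the batch stacks two domains (e.g. color and normal) along dim 0; each domain's queries additionally see the other domain's keys and values, obtained by swapping the two halves and appending them along the token axis. A simplified sketch (the multi-view expansion shown earlier is omitted for brevity; sizes are made up):

    import torch

    bt, d, c = 4, 5, 8                                    # two domains of two samples each
    key_raw = torch.randn(bt, d, c)
    key_0, key_1 = torch.chunk(key_raw, chunks=2, dim=0)
    key_cross = torch.cat([key_1, key_0], dim=0)          # swap the two domain halves
    key = torch.cat([key_raw, key_cross], dim=1)          # own tokens + other domain's tokens
    print(key.shape)                                      # torch.Size([4, 10, 8])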
809
+
810
+
811
+
812
+ class XFormersJointAttnProcessor:
813
+ r"""
814
+ Default processor for performing attention-related computations.
815
+ """
816
+
817
+ def __call__(
818
+ self,
819
+ attn: Attention,
820
+ hidden_states,
821
+ encoder_hidden_states=None,
822
+ attention_mask=None,
823
+ temb=None,
824
+ num_tasks=2
825
+ ):
826
+
827
+ residual = hidden_states
828
+
829
+ if attn.spatial_norm is not None:
830
+ hidden_states = attn.spatial_norm(hidden_states, temb)
831
+
832
+ input_ndim = hidden_states.ndim
833
+
834
+ if input_ndim == 4:
835
+ batch_size, channel, height, width = hidden_states.shape
836
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
837
+
838
+ batch_size, sequence_length, _ = (
839
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
840
+ )
841
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
842
+
843
+ # from yuancheng; here attention_mask is None
844
+ if attention_mask is not None:
845
+ # expand our mask's singleton query_tokens dimension:
846
+ # [batch*heads, 1, key_tokens] ->
847
+ # [batch*heads, query_tokens, key_tokens]
848
+ # so that it can be added as a bias onto the attention scores that xformers computes:
849
+ # [batch*heads, query_tokens, key_tokens]
850
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
851
+ _, query_tokens, _ = hidden_states.shape
852
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
853
+
854
+ if attn.group_norm is not None:
855
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
856
+
857
+ query = attn.to_q(hidden_states)
858
+
859
+ if encoder_hidden_states is None:
860
+ encoder_hidden_states = hidden_states
861
+ elif attn.norm_cross:
862
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
863
+
864
+ key = attn.to_k(encoder_hidden_states)
865
+ value = attn.to_v(encoder_hidden_states)
866
+
867
+ assert num_tasks == 2 # only support two tasks now
868
+
869
+ key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
870
+ value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
871
+ key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
872
+ value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
873
+ key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
874
+ value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
875
+
876
+
877
+ query = attn.head_to_batch_dim(query).contiguous()
878
+ key = attn.head_to_batch_dim(key).contiguous()
879
+ value = attn.head_to_batch_dim(value).contiguous()
880
+
881
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
882
+ hidden_states = attn.batch_to_head_dim(hidden_states)
883
+
884
+ # linear proj
885
+ hidden_states = attn.to_out[0](hidden_states)
886
+ # dropout
887
+ hidden_states = attn.to_out[1](hidden_states)
888
+
889
+ if input_ndim == 4:
890
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
891
+
892
+ if attn.residual_connection:
893
+ hidden_states = hidden_states + residual
894
+
895
+ hidden_states = hidden_states / attn.rescale_output_factor
896
+
897
+ return hidden_states
898
+
899
+
900
+ class JointAttnProcessor:
901
+ r"""
902
+ Default processor for performing attention-related computations.
903
+ """
904
+
905
+ def __call__(
906
+ self,
907
+ attn: Attention,
908
+ hidden_states,
909
+ encoder_hidden_states=None,
910
+ attention_mask=None,
911
+ temb=None,
912
+ num_tasks=2
913
+ ):
914
+
915
+ residual = hidden_states
916
+
917
+ if attn.spatial_norm is not None:
918
+ hidden_states = attn.spatial_norm(hidden_states, temb)
919
+
920
+ input_ndim = hidden_states.ndim
921
+
922
+ if input_ndim == 4:
923
+ batch_size, channel, height, width = hidden_states.shape
924
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
925
+
926
+ batch_size, sequence_length, _ = (
927
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
928
+ )
929
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
930
+
931
+
932
+ if attn.group_norm is not None:
933
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
934
+
935
+ query = attn.to_q(hidden_states)
936
+
937
+ if encoder_hidden_states is None:
938
+ encoder_hidden_states = hidden_states
939
+ elif attn.norm_cross:
940
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
941
+
942
+ key = attn.to_k(encoder_hidden_states)
943
+ value = attn.to_v(encoder_hidden_states)
944
+
945
+ assert num_tasks == 2 # only support two tasks now
946
+
947
+ key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
948
+ value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
949
+ key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
950
+ value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
951
+ key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
952
+ value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
953
+
954
+
955
+ query = attn.head_to_batch_dim(query).contiguous()
956
+ key = attn.head_to_batch_dim(key).contiguous()
957
+ value = attn.head_to_batch_dim(value).contiguous()
958
+
959
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
960
+ hidden_states = torch.bmm(attention_probs, value)
961
+ hidden_states = attn.batch_to_head_dim(hidden_states)
962
+
963
+ # linear proj
964
+ hidden_states = attn.to_out[0](hidden_states)
965
+ # dropout
966
+ hidden_states = attn.to_out[1](hidden_states)
967
+
968
+ if input_ndim == 4:
969
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
970
+
971
+ if attn.residual_connection:
972
+ hidden_states = hidden_states + residual
973
+
974
+ hidden_states = hidden_states / attn.rescale_output_factor
975
+
976
+ return hidden_states
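Both joint processors build their keys and values the same way: the two task halves of the batch are concatenated along the token axis and the result is duplicated along the batch, so each task attends over the union of both token sets. A sketch with made-up sizes:

    import torch

    bt, d, c = 4, 5, 8                                    # batch holds 2 tasks x 2 samples
    key = torch.randn(bt, d, c)
    key_0, key_1 = torch.chunk(key, chunks=2, dim=0)      # split into the two tasks
    key = torch.cat([key_0, key_1], dim=1)                # (b t) 2d c
    key = torch.cat([key] * 2, dim=0)                     # (2 b t) 2d c
    print(key.shape)                                      # torch.Size([4, 10, 8])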
canonicalize/models/unet.py ADDED
@@ -0,0 +1,475 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import os
7
+ import json
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+
13
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
14
+ from diffusers import ModelMixin
15
+ from diffusers.utils import BaseOutput, logging
16
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
17
+ from .unet_blocks import (
18
+ CrossAttnDownBlock3D,
19
+ CrossAttnUpBlock3D,
20
+ DownBlock3D,
21
+ UNetMidBlock3DCrossAttn,
22
+ UpBlock3D,
23
+ get_down_block,
24
+ get_up_block,
25
+ )
26
+ from .resnet import InflatedConv3d
27
+
28
+
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+
32
+ @dataclass
33
+ class UNet3DConditionOutput(BaseOutput):
34
+ sample: torch.FloatTensor
35
+
36
+
37
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
38
+ _supports_gradient_checkpointing = True
39
+
40
+ @register_to_config
41
+ def __init__(
42
+ self,
43
+ sample_size: Optional[int] = None,
44
+ in_channels: int = 4,
45
+ out_channels: int = 4,
46
+ center_input_sample: bool = False,
47
+ flip_sin_to_cos: bool = True,
48
+ freq_shift: int = 0,
49
+ down_block_types: Tuple[str] = (
50
+ "CrossAttnDownBlock3D",
51
+ "CrossAttnDownBlock3D",
52
+ "CrossAttnDownBlock3D",
53
+ "DownBlock3D",
54
+ ),
55
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
56
+ up_block_types: Tuple[str] = (
57
+ "UpBlock3D",
58
+ "CrossAttnUpBlock3D",
59
+ "CrossAttnUpBlock3D",
60
+ "CrossAttnUpBlock3D"
61
+ ),
62
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
63
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
64
+ layers_per_block: int = 2,
65
+ downsample_padding: int = 1,
66
+ mid_block_scale_factor: float = 1,
67
+ act_fn: str = "silu",
68
+ norm_num_groups: int = 32,
69
+ norm_eps: float = 1e-5,
70
+ cross_attention_dim: int = 1280,
71
+ attention_head_dim: Union[int, Tuple[int]] = 8,
72
+ dual_cross_attention: bool = False,
73
+ use_linear_projection: bool = False,
74
+ class_embed_type: Optional[str] = None,
75
+ num_class_embeds: Optional[int] = None,
76
+ upcast_attention: bool = False,
77
+ resnet_time_scale_shift: str = "default",
78
+ use_attn_temp: bool = False,
79
+ camera_input_dim: int = 12,
80
+ camera_hidden_dim: int = 320,
81
+ camera_output_dim: int = 1280,
82
+ ):
83
+ super().__init__()
84
+
85
+ self.sample_size = sample_size
86
+ time_embed_dim = block_out_channels[0] * 4
87
+
88
+ # input
89
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
90
+
91
+ # time
92
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
93
+ timestep_input_dim = block_out_channels[0]
94
+
95
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
96
+
97
+ # class embedding
98
+ if class_embed_type is None and num_class_embeds is not None:
99
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
100
+ elif class_embed_type == "timestep":
101
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
102
+ elif class_embed_type == "identity":
103
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
104
+ else:
105
+ self.class_embedding = None
106
+
107
+ self.camera_embedding = nn.Sequential(
108
+ nn.Linear(camera_input_dim, time_embed_dim),
109
+ nn.SiLU(),
110
+ nn.Linear(time_embed_dim, time_embed_dim),
111
+ )
112
+
113
+ self.down_blocks = nn.ModuleList([])
114
+ self.mid_block = None
115
+ self.up_blocks = nn.ModuleList([])
116
+
117
+ if isinstance(only_cross_attention, bool):
118
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
119
+
120
+ if isinstance(attention_head_dim, int):
121
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
122
+
123
+ # down
124
+ output_channel = block_out_channels[0]
125
+ for i, down_block_type in enumerate(down_block_types):
126
+ input_channel = output_channel
127
+ output_channel = block_out_channels[i]
128
+ is_final_block = i == len(block_out_channels) - 1
129
+
130
+ down_block = get_down_block(
131
+ down_block_type,
132
+ num_layers=layers_per_block,
133
+ in_channels=input_channel,
134
+ out_channels=output_channel,
135
+ temb_channels=time_embed_dim,
136
+ add_downsample=not is_final_block,
137
+ resnet_eps=norm_eps,
138
+ resnet_act_fn=act_fn,
139
+ resnet_groups=norm_num_groups,
140
+ cross_attention_dim=cross_attention_dim,
141
+ attn_num_head_channels=attention_head_dim[i],
142
+ downsample_padding=downsample_padding,
143
+ dual_cross_attention=dual_cross_attention,
144
+ use_linear_projection=use_linear_projection,
145
+ only_cross_attention=only_cross_attention[i],
146
+ upcast_attention=upcast_attention,
147
+ resnet_time_scale_shift=resnet_time_scale_shift,
148
+ use_attn_temp=use_attn_temp
149
+ )
150
+ self.down_blocks.append(down_block)
151
+
152
+ # mid
153
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
154
+ self.mid_block = UNetMidBlock3DCrossAttn(
155
+ in_channels=block_out_channels[-1],
156
+ temb_channels=time_embed_dim,
157
+ resnet_eps=norm_eps,
158
+ resnet_act_fn=act_fn,
159
+ output_scale_factor=mid_block_scale_factor,
160
+ resnet_time_scale_shift=resnet_time_scale_shift,
161
+ cross_attention_dim=cross_attention_dim,
162
+ attn_num_head_channels=attention_head_dim[-1],
163
+ resnet_groups=norm_num_groups,
164
+ dual_cross_attention=dual_cross_attention,
165
+ use_linear_projection=use_linear_projection,
166
+ upcast_attention=upcast_attention,
167
+ )
168
+ else:
169
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
170
+
171
+ # count how many layers upsample the videos
172
+ self.num_upsamplers = 0
173
+
174
+ # up
175
+ reversed_block_out_channels = list(reversed(block_out_channels))
176
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
177
+ only_cross_attention = list(reversed(only_cross_attention))
178
+ output_channel = reversed_block_out_channels[0]
179
+ for i, up_block_type in enumerate(up_block_types):
180
+ is_final_block = i == len(block_out_channels) - 1
181
+
182
+ prev_output_channel = output_channel
183
+ output_channel = reversed_block_out_channels[i]
184
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
185
+
186
+ # add upsample block for all BUT final layer
187
+ if not is_final_block:
188
+ add_upsample = True
189
+ self.num_upsamplers += 1
190
+ else:
191
+ add_upsample = False
192
+
193
+ up_block = get_up_block(
194
+ up_block_type,
195
+ num_layers=layers_per_block + 1,
196
+ in_channels=input_channel,
197
+ out_channels=output_channel,
198
+ prev_output_channel=prev_output_channel,
199
+ temb_channels=time_embed_dim,
200
+ add_upsample=add_upsample,
201
+ resnet_eps=norm_eps,
202
+ resnet_act_fn=act_fn,
203
+ resnet_groups=norm_num_groups,
204
+ cross_attention_dim=cross_attention_dim,
205
+ attn_num_head_channels=reversed_attention_head_dim[i],
206
+ dual_cross_attention=dual_cross_attention,
207
+ use_linear_projection=use_linear_projection,
208
+ only_cross_attention=only_cross_attention[i],
209
+ upcast_attention=upcast_attention,
210
+ resnet_time_scale_shift=resnet_time_scale_shift,
211
+ use_attn_temp=use_attn_temp,
212
+ )
213
+ self.up_blocks.append(up_block)
214
+ prev_output_channel = output_channel
215
+
216
+ # out
217
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
218
+ self.conv_act = nn.SiLU()
219
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
220
+
221
+ def set_attention_slice(self, slice_size):
222
+ r"""
223
+ Enable sliced attention computation.
224
+
225
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
226
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
227
+
228
+ Args:
229
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
230
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
231
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
232
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
233
+ must be a multiple of `slice_size`.
234
+ """
235
+ sliceable_head_dims = []
236
+
237
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
238
+ if hasattr(module, "set_attention_slice"):
239
+ sliceable_head_dims.append(module.sliceable_head_dim)
240
+
241
+ for child in module.children():
242
+ fn_recursive_retrieve_slicable_dims(child)
243
+
244
+ # retrieve number of attention layers
245
+ for module in self.children():
246
+ fn_recursive_retrieve_slicable_dims(module)
247
+
248
+ num_slicable_layers = len(sliceable_head_dims)
249
+
250
+ if slice_size == "auto":
251
+ # half the attention head size is usually a good trade-off between
252
+ # speed and memory
253
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
254
+ elif slice_size == "max":
255
+ # make smallest slice possible
256
+ slice_size = num_slicable_layers * [1]
257
+
258
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
259
+
260
+ if len(slice_size) != len(sliceable_head_dims):
261
+ raise ValueError(
262
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
263
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
264
+ )
265
+
266
+ for i in range(len(slice_size)):
267
+ size = slice_size[i]
268
+ dim = sliceable_head_dims[i]
269
+ if size is not None and size > dim:
270
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
271
+
272
+ # Recursively walk through all the children.
273
+ # Any children which exposes the set_attention_slice method
274
+ # gets the message
275
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
276
+ if hasattr(module, "set_attention_slice"):
277
+ module.set_attention_slice(slice_size.pop())
278
+
279
+ for child in module.children():
280
+ fn_recursive_set_attention_slice(child, slice_size)
281
+
282
+ reversed_slice_size = list(reversed(slice_size))
283
+ for module in self.children():
284
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
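A short usage sketch (it assumes an already constructed `unet` instance of the class above): `"auto"` halves every sliceable head dimension, `"max"` runs one slice at a time, and an integer is applied to every sliceable layer and must divide its head dimension.

    # Sketch only; `unet` is assumed to be an instance of UNet3DConditionModel.
    unet.set_attention_slice("auto")   # two-step attention per sliceable layer
    unet.set_attention_slice("max")    # slice size 1 everywhere, lowest memory
    unet.set_attention_slice(4)        # fixed slice size for all sliceable layers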
285
+
286
+ def _set_gradient_checkpointing(self, module, value=False):
287
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
288
+ module.gradient_checkpointing = value
289
+
290
+ def forward(
291
+ self,
292
+ sample: torch.FloatTensor,
293
+ timestep: Union[torch.Tensor, float, int],
294
+ encoder_hidden_states: torch.Tensor,
295
+ camera_matrixs: Optional[torch.Tensor] = None,
296
+ class_labels: Optional[torch.Tensor] = None,
297
+ attention_mask: Optional[torch.Tensor] = None,
298
+ return_dict: bool = True,
299
+ ) -> Union[UNet3DConditionOutput, Tuple]:
300
+ r"""
301
+ Args:
302
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
303
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
304
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
305
+ return_dict (`bool`, *optional*, defaults to `True`):
306
+ Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.
307
+
308
+ Returns:
309
+ [`UNet3DConditionOutput`] or `tuple`:
310
+ [`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
311
+ returning a tuple, the first element is the sample tensor.
312
+ """
313
+ # By default samples have to be at least a multiple of the overall upsampling factor.
314
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
315
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
316
+ # on the fly if necessary.
317
+ default_overall_up_factor = 2**self.num_upsamplers
318
+
319
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
320
+ forward_upsample_size = False
321
+ upsample_size = None
322
+
323
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
324
+ logger.info("Forward upsample size to force interpolation output size.")
325
+ forward_upsample_size = True
326
+
327
+ # prepare attention_mask
328
+ if attention_mask is not None:
329
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
330
+ attention_mask = attention_mask.unsqueeze(1)
331
+
332
+ # center input if necessary
333
+ if self.config.center_input_sample:
334
+ sample = 2 * sample - 1.0
335
+ # time
336
+ timesteps = timestep
337
+ if not torch.is_tensor(timesteps):
338
+ # This would be a good case for the `match` statement (Python 3.10+)
339
+ is_mps = sample.device.type == "mps"
340
+ if isinstance(timestep, float):
341
+ dtype = torch.float32 if is_mps else torch.float64
342
+ else:
343
+ dtype = torch.int32 if is_mps else torch.int64
344
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
345
+ elif len(timesteps.shape) == 0:
346
+ timesteps = timesteps[None].to(sample.device)
347
+
348
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
349
+ timesteps = timesteps.expand(sample.shape[0])
350
+
351
+ t_emb = self.time_proj(timesteps)
352
+
353
+ # timesteps does not contain any weights and will always return f32 tensors
354
+ # but time_embedding might actually be running in fp16. so we need to cast here.
355
+ # there might be better ways to encapsulate this.
356
+ t_emb = t_emb.to(dtype=self.dtype)
357
+ emb = self.time_embedding(t_emb) #torch.Size([32, 1280])
358
+ emb = torch.unsqueeze(emb, 1)
359
+ if camera_matrixs is not None:
360
+ cam_emb = self.camera_embedding(camera_matrixs)
361
+ emb = emb.repeat(1,cam_emb.shape[1],1)
362
+ emb = emb + cam_emb
363
+
364
+ if self.class_embedding is not None:
365
+ if class_labels is not None:
366
+ if self.config.class_embed_type == "timestep":
367
+ class_labels = self.time_proj(class_labels)
368
+ class_emb = self.class_embedding(class_labels)
369
+ emb = emb + class_emb
370
+
371
+ # pre-process
372
+ sample = self.conv_in(sample)
373
+
374
+ # down
375
+ down_block_res_samples = (sample,)
376
+ for downsample_block in self.down_blocks:
377
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
378
+ sample, res_samples = downsample_block(
379
+ hidden_states=sample,
380
+ temb=emb,
381
+ encoder_hidden_states=encoder_hidden_states,
382
+ attention_mask=attention_mask,
383
+ )
384
+ else:
385
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
386
+
387
+ down_block_res_samples += res_samples
388
+
389
+ # mid
390
+ sample = self.mid_block(
391
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
392
+ )
393
+
394
+ # up
395
+ for i, upsample_block in enumerate(self.up_blocks):
396
+ is_final_block = i == len(self.up_blocks) - 1
397
+
398
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
399
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
400
+
401
+ # if we have not reached the final block and need to forward the
402
+ # upsample size, we do it here
403
+ if not is_final_block and forward_upsample_size:
404
+ upsample_size = down_block_res_samples[-1].shape[2:]
405
+
406
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
407
+ sample = upsample_block(
408
+ hidden_states=sample,
409
+ temb=emb,
410
+ res_hidden_states_tuple=res_samples,
411
+ encoder_hidden_states=encoder_hidden_states,
412
+ upsample_size=upsample_size,
413
+ attention_mask=attention_mask,
414
+ )
415
+ else:
416
+ sample = upsample_block(
417
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
418
+ )
419
+ # post-process
420
+ sample = self.conv_norm_out(sample)
421
+ sample = self.conv_act(sample)
422
+ sample = self.conv_out(sample)
423
+
424
+ if not return_dict:
425
+ return (sample,)
426
+
427
+ return UNet3DConditionOutput(sample=sample)
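The camera conditioning in this forward pass is an additive offset on the timestep embedding: each view's 12-value camera vector (presumably a flattened 3x4 matrix; the exact convention is not spelled out here) is projected to `time_embed_dim` and added to a per-view copy of the timestep embedding. A self-contained sketch with made-up shapes:

    import torch
    import torch.nn as nn

    B, num_views, time_embed_dim = 2, 4, 1280
    camera_embedding = nn.Sequential(
        nn.Linear(12, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)
    )
    camera_matrixs = torch.randn(B, num_views, 12)        # one 12-value camera vector per view
    emb = torch.randn(B, 1, time_embed_dim)               # timestep embedding after unsqueeze(1)
    emb = emb.repeat(1, num_views, 1) + camera_embedding(camera_matrixs)
    print(emb.shape)                                      # torch.Size([2, 4, 1280])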
428
+
429
+ @classmethod
430
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None):
431
+ if subfolder is not None:
432
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
433
+
434
+ config_file = os.path.join(pretrained_model_path, 'config.json')
435
+ if not os.path.isfile(config_file):
436
+ raise RuntimeError(f"{config_file} does not exist")
437
+ with open(config_file, "r") as f:
438
+ config = json.load(f)
439
+ config["_class_name"] = cls.__name__
440
+ config["down_block_types"] = [
441
+ "CrossAttnDownBlock3D",
442
+ "CrossAttnDownBlock3D",
443
+ "CrossAttnDownBlock3D",
444
+ "DownBlock3D"
445
+ ]
446
+ config["up_block_types"] = [
447
+ "UpBlock3D",
448
+ "CrossAttnUpBlock3D",
449
+ "CrossAttnUpBlock3D",
450
+ "CrossAttnUpBlock3D"
451
+ ]
452
+
453
+ from diffusers.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME
454
+
455
+ import safetensors
456
+ model = cls.from_config(config)
457
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
458
+ if not os.path.isfile(model_file):
459
+ model_file = os.path.join(pretrained_model_path, SAFETENSORS_WEIGHTS_NAME)
460
+ if not os.path.isfile(model_file):
461
+ raise RuntimeError(f"{model_file} does not exist")
462
+ else:
463
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
464
+ else:
465
+ state_dict = torch.load(model_file, map_location="cpu")
466
+
467
+ for k, v in model.state_dict().items():
468
+ if '_temp.' in k or 'camera_embedding' in k or 'class_embedding' in k:
469
+ state_dict.update({k: v})
470
+ for k in list(state_dict.keys()):
471
+ if 'camera_embedding_' in k:
472
+ v = state_dict.pop(k)
473
+ model.load_state_dict(state_dict)
474
+
475
+ return model
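A hedged usage sketch of the loader above (the checkpoint path is hypothetical): it reads the 2D UNet's `config.json`, swaps the block types for their 3D counterparts, loads the 2D weights, and keeps the camera/class/temporal layers at their fresh initialization before calling `load_state_dict`.

    # Sketch only; the path below is a placeholder for a Stable Diffusion checkpoint directory.
    unet = UNet3DConditionModel.from_pretrained_2d("path/to/sd-checkpoint", subfolder="unet")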
canonicalize/models/unet_blocks.py ADDED
@@ -0,0 +1,596 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from .attention import Transformer3DModel
7
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
8
+
9
+
10
+ def get_down_block(
11
+ down_block_type,
12
+ num_layers,
13
+ in_channels,
14
+ out_channels,
15
+ temb_channels,
16
+ add_downsample,
17
+ resnet_eps,
18
+ resnet_act_fn,
19
+ attn_num_head_channels,
20
+ resnet_groups=None,
21
+ cross_attention_dim=None,
22
+ downsample_padding=None,
23
+ dual_cross_attention=False,
24
+ use_linear_projection=False,
25
+ only_cross_attention=False,
26
+ upcast_attention=False,
27
+ resnet_time_scale_shift="default",
28
+ use_attn_temp=False,
29
+ ):
30
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
31
+ if down_block_type == "DownBlock3D":
32
+ return DownBlock3D(
33
+ num_layers=num_layers,
34
+ in_channels=in_channels,
35
+ out_channels=out_channels,
36
+ temb_channels=temb_channels,
37
+ add_downsample=add_downsample,
38
+ resnet_eps=resnet_eps,
39
+ resnet_act_fn=resnet_act_fn,
40
+ resnet_groups=resnet_groups,
41
+ downsample_padding=downsample_padding,
42
+ resnet_time_scale_shift=resnet_time_scale_shift,
43
+ )
44
+ elif down_block_type == "CrossAttnDownBlock3D":
45
+ if cross_attention_dim is None:
46
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
47
+ return CrossAttnDownBlock3D(
48
+ num_layers=num_layers,
49
+ in_channels=in_channels,
50
+ out_channels=out_channels,
51
+ temb_channels=temb_channels,
52
+ add_downsample=add_downsample,
53
+ resnet_eps=resnet_eps,
54
+ resnet_act_fn=resnet_act_fn,
55
+ resnet_groups=resnet_groups,
56
+ downsample_padding=downsample_padding,
57
+ cross_attention_dim=cross_attention_dim,
58
+ attn_num_head_channels=attn_num_head_channels,
59
+ dual_cross_attention=dual_cross_attention,
60
+ use_linear_projection=use_linear_projection,
61
+ only_cross_attention=only_cross_attention,
62
+ upcast_attention=upcast_attention,
63
+ resnet_time_scale_shift=resnet_time_scale_shift,
64
+ use_attn_temp=use_attn_temp,
65
+ )
66
+ raise ValueError(f"{down_block_type} does not exist.")
67
+
68
+
69
+ def get_up_block(
70
+ up_block_type,
71
+ num_layers,
72
+ in_channels,
73
+ out_channels,
74
+ prev_output_channel,
75
+ temb_channels,
76
+ add_upsample,
77
+ resnet_eps,
78
+ resnet_act_fn,
79
+ attn_num_head_channels,
80
+ resnet_groups=None,
81
+ cross_attention_dim=None,
82
+ dual_cross_attention=False,
83
+ use_linear_projection=False,
84
+ only_cross_attention=False,
85
+ upcast_attention=False,
86
+ resnet_time_scale_shift="default",
87
+ use_attn_temp=False,
88
+ ):
89
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
90
+ if up_block_type == "UpBlock3D":
91
+ return UpBlock3D(
92
+ num_layers=num_layers,
93
+ in_channels=in_channels,
94
+ out_channels=out_channels,
95
+ prev_output_channel=prev_output_channel,
96
+ temb_channels=temb_channels,
97
+ add_upsample=add_upsample,
98
+ resnet_eps=resnet_eps,
99
+ resnet_act_fn=resnet_act_fn,
100
+ resnet_groups=resnet_groups,
101
+ resnet_time_scale_shift=resnet_time_scale_shift,
102
+ )
103
+ elif up_block_type == "CrossAttnUpBlock3D":
104
+ if cross_attention_dim is None:
105
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
106
+ return CrossAttnUpBlock3D(
107
+ num_layers=num_layers,
108
+ in_channels=in_channels,
109
+ out_channels=out_channels,
110
+ prev_output_channel=prev_output_channel,
111
+ temb_channels=temb_channels,
112
+ add_upsample=add_upsample,
113
+ resnet_eps=resnet_eps,
114
+ resnet_act_fn=resnet_act_fn,
115
+ resnet_groups=resnet_groups,
116
+ cross_attention_dim=cross_attention_dim,
117
+ attn_num_head_channels=attn_num_head_channels,
118
+ dual_cross_attention=dual_cross_attention,
119
+ use_linear_projection=use_linear_projection,
120
+ only_cross_attention=only_cross_attention,
121
+ upcast_attention=upcast_attention,
122
+ resnet_time_scale_shift=resnet_time_scale_shift,
123
+ use_attn_temp=use_attn_temp,
124
+ )
125
+ raise ValueError(f"{up_block_type} does not exist.")
126
+
127
+
128
+ class UNetMidBlock3DCrossAttn(nn.Module):
129
+ def __init__(
130
+ self,
131
+ in_channels: int,
132
+ temb_channels: int,
133
+ dropout: float = 0.0,
134
+ num_layers: int = 1,
135
+ resnet_eps: float = 1e-6,
136
+ resnet_time_scale_shift: str = "default",
137
+ resnet_act_fn: str = "swish",
138
+ resnet_groups: int = 32,
139
+ resnet_pre_norm: bool = True,
140
+ attn_num_head_channels=1,
141
+ output_scale_factor=1.0,
142
+ cross_attention_dim=1280,
143
+ dual_cross_attention=False,
144
+ use_linear_projection=False,
145
+ upcast_attention=False,
146
+ ):
147
+ super().__init__()
148
+
149
+ self.has_cross_attention = True
150
+ self.attn_num_head_channels = attn_num_head_channels
151
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
152
+
153
+ # there is always at least one resnet
154
+ resnets = [
155
+ ResnetBlock3D(
156
+ in_channels=in_channels,
157
+ out_channels=in_channels,
158
+ temb_channels=temb_channels,
159
+ eps=resnet_eps,
160
+ groups=resnet_groups,
161
+ dropout=dropout,
162
+ time_embedding_norm=resnet_time_scale_shift,
163
+ non_linearity=resnet_act_fn,
164
+ output_scale_factor=output_scale_factor,
165
+ pre_norm=resnet_pre_norm,
166
+ )
167
+ ]
168
+ attentions = []
169
+
170
+ for _ in range(num_layers):
171
+ if dual_cross_attention:
172
+ raise NotImplementedError
173
+ attentions.append(
174
+ Transformer3DModel(
175
+ attn_num_head_channels,
176
+ in_channels // attn_num_head_channels,
177
+ in_channels=in_channels,
178
+ num_layers=1,
179
+ cross_attention_dim=cross_attention_dim,
180
+ norm_num_groups=resnet_groups,
181
+ use_linear_projection=use_linear_projection,
182
+ upcast_attention=upcast_attention,
183
+ )
184
+ )
185
+ resnets.append(
186
+ ResnetBlock3D(
187
+ in_channels=in_channels,
188
+ out_channels=in_channels,
189
+ temb_channels=temb_channels,
190
+ eps=resnet_eps,
191
+ groups=resnet_groups,
192
+ dropout=dropout,
193
+ time_embedding_norm=resnet_time_scale_shift,
194
+ non_linearity=resnet_act_fn,
195
+ output_scale_factor=output_scale_factor,
196
+ pre_norm=resnet_pre_norm,
197
+ )
198
+ )
199
+
200
+ self.attentions = nn.ModuleList(attentions)
201
+ self.resnets = nn.ModuleList(resnets)
202
+
203
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
204
+ hidden_states = self.resnets[0](hidden_states, temb)
205
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
206
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
207
+ hidden_states = resnet(hidden_states, temb)
208
+
209
+ return hidden_states
210
+
211
+
212
+ class CrossAttnDownBlock3D(nn.Module):
213
+ def __init__(
214
+ self,
215
+ in_channels: int,
216
+ out_channels: int,
217
+ temb_channels: int,
218
+ dropout: float = 0.0,
219
+ num_layers: int = 1,
220
+ resnet_eps: float = 1e-6,
221
+ resnet_time_scale_shift: str = "default",
222
+ resnet_act_fn: str = "swish",
223
+ resnet_groups: int = 32,
224
+ resnet_pre_norm: bool = True,
225
+ attn_num_head_channels=1,
226
+ cross_attention_dim=1280,
227
+ output_scale_factor=1.0,
228
+ downsample_padding=1,
229
+ add_downsample=True,
230
+ dual_cross_attention=False,
231
+ use_linear_projection=False,
232
+ only_cross_attention=False,
233
+ upcast_attention=False,
234
+ use_attn_temp=False,
235
+ ):
236
+ super().__init__()
237
+ resnets = []
238
+ attentions = []
239
+
240
+ self.has_cross_attention = True
241
+ self.attn_num_head_channels = attn_num_head_channels
242
+
243
+ for i in range(num_layers):
244
+ in_channels = in_channels if i == 0 else out_channels
245
+ resnets.append(
246
+ ResnetBlock3D(
247
+ in_channels=in_channels,
248
+ out_channels=out_channels,
249
+ temb_channels=temb_channels,
250
+ eps=resnet_eps,
251
+ groups=resnet_groups,
252
+ dropout=dropout,
253
+ time_embedding_norm=resnet_time_scale_shift,
254
+ non_linearity=resnet_act_fn,
255
+ output_scale_factor=output_scale_factor,
256
+ pre_norm=resnet_pre_norm,
257
+ )
258
+ )
259
+ if dual_cross_attention:
260
+ raise NotImplementedError
261
+ attentions.append(
262
+ Transformer3DModel(
263
+ attn_num_head_channels,
264
+ out_channels // attn_num_head_channels,
265
+ in_channels=out_channels,
266
+ num_layers=1,
267
+ cross_attention_dim=cross_attention_dim,
268
+ norm_num_groups=resnet_groups,
269
+ use_linear_projection=use_linear_projection,
270
+ only_cross_attention=only_cross_attention,
271
+ upcast_attention=upcast_attention,
272
+ use_attn_temp=use_attn_temp,
273
+ )
274
+ )
275
+ self.attentions = nn.ModuleList(attentions)
276
+ self.resnets = nn.ModuleList(resnets)
277
+
278
+ if add_downsample:
279
+ self.downsamplers = nn.ModuleList(
280
+ [
281
+ Downsample3D(
282
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
283
+ )
284
+ ]
285
+ )
286
+ else:
287
+ self.downsamplers = None
288
+
289
+ self.gradient_checkpointing = False
290
+
291
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
292
+ output_states = ()
293
+
294
+ for resnet, attn in zip(self.resnets, self.attentions):
295
+ if self.training and self.gradient_checkpointing:
296
+
297
+ def create_custom_forward(module, return_dict=None):
298
+ def custom_forward(*inputs):
299
+ if return_dict is not None:
300
+ return module(*inputs, return_dict=return_dict)
301
+ else:
302
+ return module(*inputs)
303
+
304
+ return custom_forward
305
+
306
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
307
+ hidden_states = torch.utils.checkpoint.checkpoint(
308
+ create_custom_forward(attn, return_dict=False),
309
+ hidden_states,
310
+ encoder_hidden_states,
311
+ )[0]
312
+ else:
313
+ hidden_states = resnet(hidden_states, temb)
314
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
315
+
316
+ output_states += (hidden_states,)
317
+
318
+ if self.downsamplers is not None:
319
+ for downsampler in self.downsamplers:
320
+ hidden_states = downsampler(hidden_states)
321
+
322
+ output_states += (hidden_states,)
323
+
324
+ return hidden_states, output_states
325
+
326
+
327
+ class DownBlock3D(nn.Module):
328
+ def __init__(
329
+ self,
330
+ in_channels: int,
331
+ out_channels: int,
332
+ temb_channels: int,
333
+ dropout: float = 0.0,
334
+ num_layers: int = 1,
335
+ resnet_eps: float = 1e-6,
336
+ resnet_time_scale_shift: str = "default",
337
+ resnet_act_fn: str = "swish",
338
+ resnet_groups: int = 32,
339
+ resnet_pre_norm: bool = True,
340
+ output_scale_factor=1.0,
341
+ add_downsample=True,
342
+ downsample_padding=1,
343
+ ):
344
+ super().__init__()
345
+ resnets = []
346
+
347
+ for i in range(num_layers):
348
+ in_channels = in_channels if i == 0 else out_channels
349
+ resnets.append(
350
+ ResnetBlock3D(
351
+ in_channels=in_channels,
352
+ out_channels=out_channels,
353
+ temb_channels=temb_channels,
354
+ eps=resnet_eps,
355
+ groups=resnet_groups,
356
+ dropout=dropout,
357
+ time_embedding_norm=resnet_time_scale_shift,
358
+ non_linearity=resnet_act_fn,
359
+ output_scale_factor=output_scale_factor,
360
+ pre_norm=resnet_pre_norm,
361
+ )
362
+ )
363
+
364
+ self.resnets = nn.ModuleList(resnets)
365
+
366
+ if add_downsample:
367
+ self.downsamplers = nn.ModuleList(
368
+ [
369
+ Downsample3D(
370
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
371
+ )
372
+ ]
373
+ )
374
+ else:
375
+ self.downsamplers = None
376
+
377
+ self.gradient_checkpointing = False
378
+
379
+ def forward(self, hidden_states, temb=None):
380
+ output_states = ()
381
+
382
+ for resnet in self.resnets:
383
+ if self.training and self.gradient_checkpointing:
384
+
385
+ def create_custom_forward(module):
386
+ def custom_forward(*inputs):
387
+ return module(*inputs)
388
+
389
+ return custom_forward
390
+
391
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
392
+ else:
393
+ hidden_states = resnet(hidden_states, temb)
394
+
395
+ output_states += (hidden_states,)
396
+
397
+ if self.downsamplers is not None:
398
+ for downsampler in self.downsamplers:
399
+ hidden_states = downsampler(hidden_states)
400
+
401
+ output_states += (hidden_states,)
402
+
403
+ return hidden_states, output_states
404
+
405
+
406
+ class CrossAttnUpBlock3D(nn.Module):
407
+ def __init__(
408
+ self,
409
+ in_channels: int,
410
+ out_channels: int,
411
+ prev_output_channel: int,
412
+ temb_channels: int,
413
+ dropout: float = 0.0,
414
+ num_layers: int = 1,
415
+ resnet_eps: float = 1e-6,
416
+ resnet_time_scale_shift: str = "default",
417
+ resnet_act_fn: str = "swish",
418
+ resnet_groups: int = 32,
419
+ resnet_pre_norm: bool = True,
420
+ attn_num_head_channels=1,
421
+ cross_attention_dim=1280,
422
+ output_scale_factor=1.0,
423
+ add_upsample=True,
424
+ dual_cross_attention=False,
425
+ use_linear_projection=False,
426
+ only_cross_attention=False,
427
+ upcast_attention=False,
428
+ use_attn_temp=False,
429
+ ):
430
+ super().__init__()
431
+ resnets = []
432
+ attentions = []
433
+
434
+ self.has_cross_attention = True
435
+ self.attn_num_head_channels = attn_num_head_channels
436
+
437
+ for i in range(num_layers):
438
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
439
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
440
+
441
+ resnets.append(
442
+ ResnetBlock3D(
443
+ in_channels=resnet_in_channels + res_skip_channels,
444
+ out_channels=out_channels,
445
+ temb_channels=temb_channels,
446
+ eps=resnet_eps,
447
+ groups=resnet_groups,
448
+ dropout=dropout,
449
+ time_embedding_norm=resnet_time_scale_shift,
450
+ non_linearity=resnet_act_fn,
451
+ output_scale_factor=output_scale_factor,
452
+ pre_norm=resnet_pre_norm,
453
+ )
454
+ )
455
+ if dual_cross_attention:
456
+ raise NotImplementedError
457
+ attentions.append(
458
+ Transformer3DModel(
459
+ attn_num_head_channels,
460
+ out_channels // attn_num_head_channels,
461
+ in_channels=out_channels,
462
+ num_layers=1,
463
+ cross_attention_dim=cross_attention_dim,
464
+ norm_num_groups=resnet_groups,
465
+ use_linear_projection=use_linear_projection,
466
+ only_cross_attention=only_cross_attention,
467
+ upcast_attention=upcast_attention,
468
+ use_attn_temp=use_attn_temp,
469
+ )
470
+ )
471
+
472
+ self.attentions = nn.ModuleList(attentions)
473
+ self.resnets = nn.ModuleList(resnets)
474
+
475
+ if add_upsample:
476
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
477
+ else:
478
+ self.upsamplers = None
479
+
480
+ self.gradient_checkpointing = False
481
+
482
+ def forward(
483
+ self,
484
+ hidden_states,
485
+ res_hidden_states_tuple,
486
+ temb=None,
487
+ encoder_hidden_states=None,
488
+ upsample_size=None,
489
+ attention_mask=None,
490
+ ):
491
+ for resnet, attn in zip(self.resnets, self.attentions):
492
+ # pop res hidden states
493
+ res_hidden_states = res_hidden_states_tuple[-1]
494
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
495
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
496
+
497
+ if self.training and self.gradient_checkpointing:
498
+
499
+ def create_custom_forward(module, return_dict=None):
500
+ def custom_forward(*inputs):
501
+ if return_dict is not None:
502
+ return module(*inputs, return_dict=return_dict)
503
+ else:
504
+ return module(*inputs)
505
+
506
+ return custom_forward
507
+
508
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
509
+ hidden_states = torch.utils.checkpoint.checkpoint(
510
+ create_custom_forward(attn, return_dict=False),
511
+ hidden_states,
512
+ encoder_hidden_states,
513
+ )[0]
514
+ else:
515
+ hidden_states = resnet(hidden_states, temb)
516
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
517
+
518
+ if self.upsamplers is not None:
519
+ for upsampler in self.upsamplers:
520
+ hidden_states = upsampler(hidden_states, upsample_size)
521
+
522
+ return hidden_states
523
+
524
+
525
+ class UpBlock3D(nn.Module):
526
+ def __init__(
527
+ self,
528
+ in_channels: int,
529
+ prev_output_channel: int,
530
+ out_channels: int,
531
+ temb_channels: int,
532
+ dropout: float = 0.0,
533
+ num_layers: int = 1,
534
+ resnet_eps: float = 1e-6,
535
+ resnet_time_scale_shift: str = "default",
536
+ resnet_act_fn: str = "swish",
537
+ resnet_groups: int = 32,
538
+ resnet_pre_norm: bool = True,
539
+ output_scale_factor=1.0,
540
+ add_upsample=True,
541
+ ):
542
+ super().__init__()
543
+ resnets = []
544
+
545
+ for i in range(num_layers):
546
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
547
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
548
+
549
+ resnets.append(
550
+ ResnetBlock3D(
551
+ in_channels=resnet_in_channels + res_skip_channels,
552
+ out_channels=out_channels,
553
+ temb_channels=temb_channels,
554
+ eps=resnet_eps,
555
+ groups=resnet_groups,
556
+ dropout=dropout,
557
+ time_embedding_norm=resnet_time_scale_shift,
558
+ non_linearity=resnet_act_fn,
559
+ output_scale_factor=output_scale_factor,
560
+ pre_norm=resnet_pre_norm,
561
+ )
562
+ )
563
+
564
+ self.resnets = nn.ModuleList(resnets)
565
+
566
+ if add_upsample:
567
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
568
+ else:
569
+ self.upsamplers = None
570
+
571
+ self.gradient_checkpointing = False
572
+
573
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
574
+ for resnet in self.resnets:
575
+ # pop res hidden states
576
+ res_hidden_states = res_hidden_states_tuple[-1]
577
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
578
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
579
+
580
+ if self.training and self.gradient_checkpointing:
581
+
582
+ def create_custom_forward(module):
583
+ def custom_forward(*inputs):
584
+ return module(*inputs)
585
+
586
+ return custom_forward
587
+
588
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
589
+ else:
590
+ hidden_states = resnet(hidden_states, temb)
591
+
592
+ if self.upsamplers is not None:
593
+ for upsampler in self.upsamplers:
594
+ hidden_states = upsampler(hidden_states, upsample_size)
595
+
596
+ return hidden_states
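Note: the 3D blocks in this file all rely on the same `create_custom_forward` plus `torch.utils.checkpoint.checkpoint` pattern when `gradient_checkpointing` is enabled. Below is a minimal, self-contained sketch of that pattern, not part of the commit; `TinyBlock`, the tensor shapes, and the `temb=None` placeholder are made up for illustration, and `use_reentrant=False` assumes PyTorch >= 1.11.

```python
# Illustrative-only sketch of the gradient-checkpointing pattern used above.
import torch
import torch.utils.checkpoint
import torch.nn as nn


class TinyBlock(nn.Module):
    def __init__(self, channels: int = 8):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, hidden_states, temb=None):
        # `temb` is accepted only to mirror the resnet call signature above.
        return torch.relu(self.conv(hidden_states))


def create_custom_forward(module):
    # Wrapping the module in a plain function lets `checkpoint` re-run the
    # forward pass during backward instead of storing activations.
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


block = TinyBlock()
x = torch.randn(2, 8, 16, 16, requires_grad=True)
temb = None

# Equivalent to the `self.training and self.gradient_checkpointing` branch above.
out = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block), x, temb, use_reentrant=False
)
out.sum().backward()
print(x.grad.shape)  # torch.Size([2, 8, 16, 16])
```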
canonicalize/models/unet_mv2d_blocks.py ADDED
@@ -0,0 +1,924 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional, Tuple
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.utils import is_torch_version, logging
22
+ # from diffusers.models.attention import AdaGroupNorm
23
+ from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
24
+ from diffusers.models.dual_transformer_2d import DualTransformer2DModel
25
+ from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
26
+ from canonicalize.models.transformer_mv2d import TransformerMV2DModel
27
+
28
+ from diffusers.models.unets.unet_2d_blocks import DownBlock2D, ResnetDownsampleBlock2D, AttnDownBlock2D, CrossAttnDownBlock2D, SimpleCrossAttnDownBlock2D, SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, KCrossAttnDownBlock2D
29
+ from diffusers.models.unets.unet_2d_blocks import UpBlock2D, ResnetUpsampleBlock2D, CrossAttnUpBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D
30
+
31
+
32
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
33
+
34
+
35
+ def get_down_block(
36
+ down_block_type,
37
+ num_layers,
38
+ in_channels,
39
+ out_channels,
40
+ temb_channels,
41
+ add_downsample,
42
+ resnet_eps,
43
+ resnet_act_fn,
44
+ transformer_layers_per_block=1,
45
+ num_attention_heads=None,
46
+ resnet_groups=None,
47
+ cross_attention_dim=None,
48
+ downsample_padding=None,
49
+ dual_cross_attention=False,
50
+ use_linear_projection=False,
51
+ only_cross_attention=False,
52
+ upcast_attention=False,
53
+ resnet_time_scale_shift="default",
54
+ resnet_skip_time_act=False,
55
+ resnet_out_scale_factor=1.0,
56
+ cross_attention_norm=None,
57
+ attention_head_dim=None,
58
+ downsample_type=None,
59
+ num_views=1,
60
+ joint_attention: bool = False,
61
+ joint_attention_twice: bool = False,
62
+ multiview_attention: bool = True,
63
+ cross_domain_attention: bool=False
64
+ ):
65
+ # If attn head dim is not defined, we default it to the number of heads
66
+ if attention_head_dim is None:
67
+ logger.warn(
68
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
69
+ )
70
+ attention_head_dim = num_attention_heads
71
+
72
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
73
+ if down_block_type == "DownBlock2D":
74
+ return DownBlock2D(
75
+ num_layers=num_layers,
76
+ in_channels=in_channels,
77
+ out_channels=out_channels,
78
+ temb_channels=temb_channels,
79
+ add_downsample=add_downsample,
80
+ resnet_eps=resnet_eps,
81
+ resnet_act_fn=resnet_act_fn,
82
+ resnet_groups=resnet_groups,
83
+ downsample_padding=downsample_padding,
84
+ resnet_time_scale_shift=resnet_time_scale_shift,
85
+ )
86
+ elif down_block_type == "ResnetDownsampleBlock2D":
87
+ return ResnetDownsampleBlock2D(
88
+ num_layers=num_layers,
89
+ in_channels=in_channels,
90
+ out_channels=out_channels,
91
+ temb_channels=temb_channels,
92
+ add_downsample=add_downsample,
93
+ resnet_eps=resnet_eps,
94
+ resnet_act_fn=resnet_act_fn,
95
+ resnet_groups=resnet_groups,
96
+ resnet_time_scale_shift=resnet_time_scale_shift,
97
+ skip_time_act=resnet_skip_time_act,
98
+ output_scale_factor=resnet_out_scale_factor,
99
+ )
100
+ elif down_block_type == "AttnDownBlock2D":
101
+ if add_downsample is False:
102
+ downsample_type = None
103
+ else:
104
+ downsample_type = downsample_type or "conv" # default to 'conv'
105
+ return AttnDownBlock2D(
106
+ num_layers=num_layers,
107
+ in_channels=in_channels,
108
+ out_channels=out_channels,
109
+ temb_channels=temb_channels,
110
+ resnet_eps=resnet_eps,
111
+ resnet_act_fn=resnet_act_fn,
112
+ resnet_groups=resnet_groups,
113
+ downsample_padding=downsample_padding,
114
+ attention_head_dim=attention_head_dim,
115
+ resnet_time_scale_shift=resnet_time_scale_shift,
116
+ downsample_type=downsample_type,
117
+ )
118
+ elif down_block_type == "CrossAttnDownBlock2D":
119
+ if cross_attention_dim is None:
120
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
121
+ return CrossAttnDownBlock2D(
122
+ num_layers=num_layers,
123
+ transformer_layers_per_block=transformer_layers_per_block,
124
+ in_channels=in_channels,
125
+ out_channels=out_channels,
126
+ temb_channels=temb_channels,
127
+ add_downsample=add_downsample,
128
+ resnet_eps=resnet_eps,
129
+ resnet_act_fn=resnet_act_fn,
130
+ resnet_groups=resnet_groups,
131
+ downsample_padding=downsample_padding,
132
+ cross_attention_dim=cross_attention_dim,
133
+ num_attention_heads=num_attention_heads,
134
+ dual_cross_attention=dual_cross_attention,
135
+ use_linear_projection=use_linear_projection,
136
+ only_cross_attention=only_cross_attention,
137
+ upcast_attention=upcast_attention,
138
+ resnet_time_scale_shift=resnet_time_scale_shift,
139
+ )
140
+ # custom MV2D attention block
141
+ elif down_block_type == "CrossAttnDownBlockMV2D":
142
+ if cross_attention_dim is None:
143
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMV2D")
144
+ return CrossAttnDownBlockMV2D(
145
+ num_layers=num_layers,
146
+ transformer_layers_per_block=transformer_layers_per_block,
147
+ in_channels=in_channels,
148
+ out_channels=out_channels,
149
+ temb_channels=temb_channels,
150
+ add_downsample=add_downsample,
151
+ resnet_eps=resnet_eps,
152
+ resnet_act_fn=resnet_act_fn,
153
+ resnet_groups=resnet_groups,
154
+ downsample_padding=downsample_padding,
155
+ cross_attention_dim=cross_attention_dim,
156
+ num_attention_heads=num_attention_heads,
157
+ dual_cross_attention=dual_cross_attention,
158
+ use_linear_projection=use_linear_projection,
159
+ only_cross_attention=only_cross_attention,
160
+ upcast_attention=upcast_attention,
161
+ resnet_time_scale_shift=resnet_time_scale_shift,
162
+ num_views=num_views,
163
+ joint_attention=joint_attention,
164
+ joint_attention_twice=joint_attention_twice,
165
+ multiview_attention=multiview_attention,
166
+ cross_domain_attention=cross_domain_attention
167
+ )
168
+ elif down_block_type == "SimpleCrossAttnDownBlock2D":
169
+ if cross_attention_dim is None:
170
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
171
+ return SimpleCrossAttnDownBlock2D(
172
+ num_layers=num_layers,
173
+ in_channels=in_channels,
174
+ out_channels=out_channels,
175
+ temb_channels=temb_channels,
176
+ add_downsample=add_downsample,
177
+ resnet_eps=resnet_eps,
178
+ resnet_act_fn=resnet_act_fn,
179
+ resnet_groups=resnet_groups,
180
+ cross_attention_dim=cross_attention_dim,
181
+ attention_head_dim=attention_head_dim,
182
+ resnet_time_scale_shift=resnet_time_scale_shift,
183
+ skip_time_act=resnet_skip_time_act,
184
+ output_scale_factor=resnet_out_scale_factor,
185
+ only_cross_attention=only_cross_attention,
186
+ cross_attention_norm=cross_attention_norm,
187
+ )
188
+ elif down_block_type == "SkipDownBlock2D":
189
+ return SkipDownBlock2D(
190
+ num_layers=num_layers,
191
+ in_channels=in_channels,
192
+ out_channels=out_channels,
193
+ temb_channels=temb_channels,
194
+ add_downsample=add_downsample,
195
+ resnet_eps=resnet_eps,
196
+ resnet_act_fn=resnet_act_fn,
197
+ downsample_padding=downsample_padding,
198
+ resnet_time_scale_shift=resnet_time_scale_shift,
199
+ )
200
+ elif down_block_type == "AttnSkipDownBlock2D":
201
+ return AttnSkipDownBlock2D(
202
+ num_layers=num_layers,
203
+ in_channels=in_channels,
204
+ out_channels=out_channels,
205
+ temb_channels=temb_channels,
206
+ add_downsample=add_downsample,
207
+ resnet_eps=resnet_eps,
208
+ resnet_act_fn=resnet_act_fn,
209
+ attention_head_dim=attention_head_dim,
210
+ resnet_time_scale_shift=resnet_time_scale_shift,
211
+ )
212
+ elif down_block_type == "DownEncoderBlock2D":
213
+ return DownEncoderBlock2D(
214
+ num_layers=num_layers,
215
+ in_channels=in_channels,
216
+ out_channels=out_channels,
217
+ add_downsample=add_downsample,
218
+ resnet_eps=resnet_eps,
219
+ resnet_act_fn=resnet_act_fn,
220
+ resnet_groups=resnet_groups,
221
+ downsample_padding=downsample_padding,
222
+ resnet_time_scale_shift=resnet_time_scale_shift,
223
+ )
224
+ elif down_block_type == "AttnDownEncoderBlock2D":
225
+ return AttnDownEncoderBlock2D(
226
+ num_layers=num_layers,
227
+ in_channels=in_channels,
228
+ out_channels=out_channels,
229
+ add_downsample=add_downsample,
230
+ resnet_eps=resnet_eps,
231
+ resnet_act_fn=resnet_act_fn,
232
+ resnet_groups=resnet_groups,
233
+ downsample_padding=downsample_padding,
234
+ attention_head_dim=attention_head_dim,
235
+ resnet_time_scale_shift=resnet_time_scale_shift,
236
+ )
237
+ elif down_block_type == "KDownBlock2D":
238
+ return KDownBlock2D(
239
+ num_layers=num_layers,
240
+ in_channels=in_channels,
241
+ out_channels=out_channels,
242
+ temb_channels=temb_channels,
243
+ add_downsample=add_downsample,
244
+ resnet_eps=resnet_eps,
245
+ resnet_act_fn=resnet_act_fn,
246
+ )
247
+ elif down_block_type == "KCrossAttnDownBlock2D":
248
+ return KCrossAttnDownBlock2D(
249
+ num_layers=num_layers,
250
+ in_channels=in_channels,
251
+ out_channels=out_channels,
252
+ temb_channels=temb_channels,
253
+ add_downsample=add_downsample,
254
+ resnet_eps=resnet_eps,
255
+ resnet_act_fn=resnet_act_fn,
256
+ cross_attention_dim=cross_attention_dim,
257
+ attention_head_dim=attention_head_dim,
258
+ add_self_attention=True if not add_downsample else False,
259
+ )
260
+ raise ValueError(f"{down_block_type} does not exist.")
261
+
262
+
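A hypothetical usage sketch of the `get_down_block` factory above, not part of the commit: it assumes the `canonicalize` package and `diffusers` are importable, and the channel sizes, view count, and cross-attention dimension are made-up values chosen only to exercise the `CrossAttnDownBlockMV2D` branch.

```python
# Illustrative-only sketch: building one multi-view down block via the factory above.
from canonicalize.models.unet_mv2d_blocks import get_down_block

down_block = get_down_block(
    "CrossAttnDownBlockMV2D",   # custom MV2D attention block handled above
    num_layers=2,
    in_channels=320,
    out_channels=320,
    temb_channels=1280,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    resnet_groups=32,
    cross_attention_dim=768,    # required for the CrossAttn* branches
    downsample_padding=1,
    num_attention_heads=8,
    attention_head_dim=40,      # out_channels // num_attention_heads
    num_views=4,
)
```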
263
+ def get_up_block(
264
+ up_block_type,
265
+ num_layers,
266
+ in_channels,
267
+ out_channels,
268
+ prev_output_channel,
269
+ temb_channels,
270
+ add_upsample,
271
+ resnet_eps,
272
+ resnet_act_fn,
273
+ transformer_layers_per_block=1,
274
+ num_attention_heads=None,
275
+ resnet_groups=None,
276
+ cross_attention_dim=None,
277
+ dual_cross_attention=False,
278
+ use_linear_projection=False,
279
+ only_cross_attention=False,
280
+ upcast_attention=False,
281
+ resnet_time_scale_shift="default",
282
+ resnet_skip_time_act=False,
283
+ resnet_out_scale_factor=1.0,
284
+ cross_attention_norm=None,
285
+ attention_head_dim=None,
286
+ upsample_type=None,
287
+ num_views=1,
288
+ joint_attention: bool = False,
289
+ joint_attention_twice: bool = False,
290
+ multiview_attention: bool = True,
291
+ cross_domain_attention: bool=False
292
+ ):
293
+ # If attn head dim is not defined, we default it to the number of heads
294
+ if attention_head_dim is None:
295
+ logger.warn(
296
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
297
+ )
298
+ attention_head_dim = num_attention_heads
299
+
300
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
301
+ if up_block_type == "UpBlock2D":
302
+ return UpBlock2D(
303
+ num_layers=num_layers,
304
+ in_channels=in_channels,
305
+ out_channels=out_channels,
306
+ prev_output_channel=prev_output_channel,
307
+ temb_channels=temb_channels,
308
+ add_upsample=add_upsample,
309
+ resnet_eps=resnet_eps,
310
+ resnet_act_fn=resnet_act_fn,
311
+ resnet_groups=resnet_groups,
312
+ resnet_time_scale_shift=resnet_time_scale_shift,
313
+ )
314
+ elif up_block_type == "ResnetUpsampleBlock2D":
315
+ return ResnetUpsampleBlock2D(
316
+ num_layers=num_layers,
317
+ in_channels=in_channels,
318
+ out_channels=out_channels,
319
+ prev_output_channel=prev_output_channel,
320
+ temb_channels=temb_channels,
321
+ add_upsample=add_upsample,
322
+ resnet_eps=resnet_eps,
323
+ resnet_act_fn=resnet_act_fn,
324
+ resnet_groups=resnet_groups,
325
+ resnet_time_scale_shift=resnet_time_scale_shift,
326
+ skip_time_act=resnet_skip_time_act,
327
+ output_scale_factor=resnet_out_scale_factor,
328
+ )
329
+ elif up_block_type == "CrossAttnUpBlock2D":
330
+ if cross_attention_dim is None:
331
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
332
+ return CrossAttnUpBlock2D(
333
+ num_layers=num_layers,
334
+ transformer_layers_per_block=transformer_layers_per_block,
335
+ in_channels=in_channels,
336
+ out_channels=out_channels,
337
+ prev_output_channel=prev_output_channel,
338
+ temb_channels=temb_channels,
339
+ add_upsample=add_upsample,
340
+ resnet_eps=resnet_eps,
341
+ resnet_act_fn=resnet_act_fn,
342
+ resnet_groups=resnet_groups,
343
+ cross_attention_dim=cross_attention_dim,
344
+ num_attention_heads=num_attention_heads,
345
+ dual_cross_attention=dual_cross_attention,
346
+ use_linear_projection=use_linear_projection,
347
+ only_cross_attention=only_cross_attention,
348
+ upcast_attention=upcast_attention,
349
+ resnet_time_scale_shift=resnet_time_scale_shift,
350
+ )
351
+ # custom MV2D attention block
352
+ elif up_block_type == "CrossAttnUpBlockMV2D":
353
+ if cross_attention_dim is None:
354
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMV2D")
355
+ return CrossAttnUpBlockMV2D(
356
+ num_layers=num_layers,
357
+ transformer_layers_per_block=transformer_layers_per_block,
358
+ in_channels=in_channels,
359
+ out_channels=out_channels,
360
+ prev_output_channel=prev_output_channel,
361
+ temb_channels=temb_channels,
362
+ add_upsample=add_upsample,
363
+ resnet_eps=resnet_eps,
364
+ resnet_act_fn=resnet_act_fn,
365
+ resnet_groups=resnet_groups,
366
+ cross_attention_dim=cross_attention_dim,
367
+ num_attention_heads=num_attention_heads,
368
+ dual_cross_attention=dual_cross_attention,
369
+ use_linear_projection=use_linear_projection,
370
+ only_cross_attention=only_cross_attention,
371
+ upcast_attention=upcast_attention,
372
+ resnet_time_scale_shift=resnet_time_scale_shift,
373
+ num_views=num_views,
374
+ joint_attention=joint_attention,
375
+ joint_attention_twice=joint_attention_twice,
376
+ multiview_attention=multiview_attention,
377
+ cross_domain_attention=cross_domain_attention
378
+ )
379
+ elif up_block_type == "SimpleCrossAttnUpBlock2D":
380
+ if cross_attention_dim is None:
381
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
382
+ return SimpleCrossAttnUpBlock2D(
383
+ num_layers=num_layers,
384
+ in_channels=in_channels,
385
+ out_channels=out_channels,
386
+ prev_output_channel=prev_output_channel,
387
+ temb_channels=temb_channels,
388
+ add_upsample=add_upsample,
389
+ resnet_eps=resnet_eps,
390
+ resnet_act_fn=resnet_act_fn,
391
+ resnet_groups=resnet_groups,
392
+ cross_attention_dim=cross_attention_dim,
393
+ attention_head_dim=attention_head_dim,
394
+ resnet_time_scale_shift=resnet_time_scale_shift,
395
+ skip_time_act=resnet_skip_time_act,
396
+ output_scale_factor=resnet_out_scale_factor,
397
+ only_cross_attention=only_cross_attention,
398
+ cross_attention_norm=cross_attention_norm,
399
+ )
400
+ elif up_block_type == "AttnUpBlock2D":
401
+ if add_upsample is False:
402
+ upsample_type = None
403
+ else:
404
+ upsample_type = upsample_type or "conv" # default to 'conv'
405
+
406
+ return AttnUpBlock2D(
407
+ num_layers=num_layers,
408
+ in_channels=in_channels,
409
+ out_channels=out_channels,
410
+ prev_output_channel=prev_output_channel,
411
+ temb_channels=temb_channels,
412
+ resnet_eps=resnet_eps,
413
+ resnet_act_fn=resnet_act_fn,
414
+ resnet_groups=resnet_groups,
415
+ attention_head_dim=attention_head_dim,
416
+ resnet_time_scale_shift=resnet_time_scale_shift,
417
+ upsample_type=upsample_type,
418
+ )
419
+ elif up_block_type == "SkipUpBlock2D":
420
+ return SkipUpBlock2D(
421
+ num_layers=num_layers,
422
+ in_channels=in_channels,
423
+ out_channels=out_channels,
424
+ prev_output_channel=prev_output_channel,
425
+ temb_channels=temb_channels,
426
+ add_upsample=add_upsample,
427
+ resnet_eps=resnet_eps,
428
+ resnet_act_fn=resnet_act_fn,
429
+ resnet_time_scale_shift=resnet_time_scale_shift,
430
+ )
431
+ elif up_block_type == "AttnSkipUpBlock2D":
432
+ return AttnSkipUpBlock2D(
433
+ num_layers=num_layers,
434
+ in_channels=in_channels,
435
+ out_channels=out_channels,
436
+ prev_output_channel=prev_output_channel,
437
+ temb_channels=temb_channels,
438
+ add_upsample=add_upsample,
439
+ resnet_eps=resnet_eps,
440
+ resnet_act_fn=resnet_act_fn,
441
+ attention_head_dim=attention_head_dim,
442
+ resnet_time_scale_shift=resnet_time_scale_shift,
443
+ )
444
+ elif up_block_type == "UpDecoderBlock2D":
445
+ return UpDecoderBlock2D(
446
+ num_layers=num_layers,
447
+ in_channels=in_channels,
448
+ out_channels=out_channels,
449
+ add_upsample=add_upsample,
450
+ resnet_eps=resnet_eps,
451
+ resnet_act_fn=resnet_act_fn,
452
+ resnet_groups=resnet_groups,
453
+ resnet_time_scale_shift=resnet_time_scale_shift,
454
+ temb_channels=temb_channels,
455
+ )
456
+ elif up_block_type == "AttnUpDecoderBlock2D":
457
+ return AttnUpDecoderBlock2D(
458
+ num_layers=num_layers,
459
+ in_channels=in_channels,
460
+ out_channels=out_channels,
461
+ add_upsample=add_upsample,
462
+ resnet_eps=resnet_eps,
463
+ resnet_act_fn=resnet_act_fn,
464
+ resnet_groups=resnet_groups,
465
+ attention_head_dim=attention_head_dim,
466
+ resnet_time_scale_shift=resnet_time_scale_shift,
467
+ temb_channels=temb_channels,
468
+ )
469
+ elif up_block_type == "KUpBlock2D":
470
+ return KUpBlock2D(
471
+ num_layers=num_layers,
472
+ in_channels=in_channels,
473
+ out_channels=out_channels,
474
+ temb_channels=temb_channels,
475
+ add_upsample=add_upsample,
476
+ resnet_eps=resnet_eps,
477
+ resnet_act_fn=resnet_act_fn,
478
+ )
479
+ elif up_block_type == "KCrossAttnUpBlock2D":
480
+ return KCrossAttnUpBlock2D(
481
+ num_layers=num_layers,
482
+ in_channels=in_channels,
483
+ out_channels=out_channels,
484
+ temb_channels=temb_channels,
485
+ add_upsample=add_upsample,
486
+ resnet_eps=resnet_eps,
487
+ resnet_act_fn=resnet_act_fn,
488
+ cross_attention_dim=cross_attention_dim,
489
+ attention_head_dim=attention_head_dim,
490
+ )
491
+
492
+ raise ValueError(f"{up_block_type} does not exist.")
493
+
494
+
495
+ class UNetMidBlockMV2DCrossAttn(nn.Module):
496
+ def __init__(
497
+ self,
498
+ in_channels: int,
499
+ temb_channels: int,
500
+ dropout: float = 0.0,
501
+ num_layers: int = 1,
502
+ transformer_layers_per_block: int = 1,
503
+ resnet_eps: float = 1e-6,
504
+ resnet_time_scale_shift: str = "default",
505
+ resnet_act_fn: str = "swish",
506
+ resnet_groups: int = 32,
507
+ resnet_pre_norm: bool = True,
508
+ num_attention_heads=1,
509
+ output_scale_factor=1.0,
510
+ cross_attention_dim=1280,
511
+ dual_cross_attention=False,
512
+ use_linear_projection=False,
513
+ upcast_attention=False,
514
+ num_views: int = 1,
515
+ joint_attention: bool = False,
516
+ joint_attention_twice: bool = False,
517
+ multiview_attention: bool = True,
518
+ cross_domain_attention: bool=False
519
+ ):
520
+ super().__init__()
521
+
522
+ self.has_cross_attention = True
523
+ self.num_attention_heads = num_attention_heads
524
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
525
+
526
+ # there is always at least one resnet
527
+ resnets = [
528
+ ResnetBlock2D(
529
+ in_channels=in_channels,
530
+ out_channels=in_channels,
531
+ temb_channels=temb_channels,
532
+ eps=resnet_eps,
533
+ groups=resnet_groups,
534
+ dropout=dropout,
535
+ time_embedding_norm=resnet_time_scale_shift,
536
+ non_linearity=resnet_act_fn,
537
+ output_scale_factor=output_scale_factor,
538
+ pre_norm=resnet_pre_norm,
539
+ )
540
+ ]
541
+ attentions = []
542
+
543
+ for _ in range(num_layers):
544
+ if not dual_cross_attention:
545
+ attentions.append(
546
+ TransformerMV2DModel(
547
+ num_attention_heads,
548
+ in_channels // num_attention_heads,
549
+ in_channels=in_channels,
550
+ num_layers=transformer_layers_per_block,
551
+ cross_attention_dim=cross_attention_dim,
552
+ norm_num_groups=resnet_groups,
553
+ use_linear_projection=use_linear_projection,
554
+ upcast_attention=upcast_attention,
555
+ num_views=num_views,
556
+ joint_attention=joint_attention,
557
+ joint_attention_twice=joint_attention_twice,
558
+ multiview_attention=multiview_attention,
559
+ cross_domain_attention=cross_domain_attention
560
+ )
561
+ )
562
+ else:
563
+ raise NotImplementedError
564
+ resnets.append(
565
+ ResnetBlock2D(
566
+ in_channels=in_channels,
567
+ out_channels=in_channels,
568
+ temb_channels=temb_channels,
569
+ eps=resnet_eps,
570
+ groups=resnet_groups,
571
+ dropout=dropout,
572
+ time_embedding_norm=resnet_time_scale_shift,
573
+ non_linearity=resnet_act_fn,
574
+ output_scale_factor=output_scale_factor,
575
+ pre_norm=resnet_pre_norm,
576
+ )
577
+ )
578
+
579
+ self.attentions = nn.ModuleList(attentions)
580
+ self.resnets = nn.ModuleList(resnets)
581
+
582
+ def forward(
583
+ self,
584
+ hidden_states: torch.FloatTensor,
585
+ temb: Optional[torch.FloatTensor] = None,
586
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
587
+ attention_mask: Optional[torch.FloatTensor] = None,
588
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
589
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
590
+ ) -> torch.FloatTensor:
591
+ hidden_states = self.resnets[0](hidden_states, temb)
592
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
593
+ hidden_states = attn(
594
+ hidden_states,
595
+ encoder_hidden_states=encoder_hidden_states,
596
+ cross_attention_kwargs=cross_attention_kwargs,
597
+ attention_mask=attention_mask,
598
+ encoder_attention_mask=encoder_attention_mask,
599
+ return_dict=False,
600
+ )[0]
601
+ hidden_states = resnet(hidden_states, temb)
602
+
603
+ return hidden_states
604
+
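For reference, a tiny illustrative sketch (not from the commit) of the execution order in `UNetMidBlockMV2DCrossAttn.forward` above: the first resnet runs on its own, then attention and resnet layers alternate via `zip(self.attentions, self.resnets[1:])`.

```python
# Illustrative-only: the mid-block interleaving, with strings standing in for modules.
resnets = ["res0", "res1", "res2"]
attentions = ["attn0", "attn1"]

order = [resnets[0]]
for attn, resnet in zip(attentions, resnets[1:]):
    order += [attn, resnet]

print(order)  # ['res0', 'attn0', 'res1', 'attn1', 'res2']
```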
605
+
606
+ class CrossAttnUpBlockMV2D(nn.Module):
607
+ def __init__(
608
+ self,
609
+ in_channels: int,
610
+ out_channels: int,
611
+ prev_output_channel: int,
612
+ temb_channels: int,
613
+ dropout: float = 0.0,
614
+ num_layers: int = 1,
615
+ transformer_layers_per_block: int = 1,
616
+ resnet_eps: float = 1e-6,
617
+ resnet_time_scale_shift: str = "default",
618
+ resnet_act_fn: str = "swish",
619
+ resnet_groups: int = 32,
620
+ resnet_pre_norm: bool = True,
621
+ num_attention_heads=1,
622
+ cross_attention_dim=1280,
623
+ output_scale_factor=1.0,
624
+ add_upsample=True,
625
+ dual_cross_attention=False,
626
+ use_linear_projection=False,
627
+ only_cross_attention=False,
628
+ upcast_attention=False,
629
+ num_views: int = 1,
630
+ joint_attention: bool = False,
631
+ joint_attention_twice: bool = False,
632
+ multiview_attention: bool = True,
633
+ cross_domain_attention: bool=False
634
+ ):
635
+ super().__init__()
636
+ resnets = []
637
+ attentions = []
638
+
639
+ self.has_cross_attention = True
640
+ self.num_attention_heads = num_attention_heads
641
+
642
+ for i in range(num_layers):
643
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
644
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
645
+
646
+ resnets.append(
647
+ ResnetBlock2D(
648
+ in_channels=resnet_in_channels + res_skip_channels,
649
+ out_channels=out_channels,
650
+ temb_channels=temb_channels,
651
+ eps=resnet_eps,
652
+ groups=resnet_groups,
653
+ dropout=dropout,
654
+ time_embedding_norm=resnet_time_scale_shift,
655
+ non_linearity=resnet_act_fn,
656
+ output_scale_factor=output_scale_factor,
657
+ pre_norm=resnet_pre_norm,
658
+ )
659
+ )
660
+ if not dual_cross_attention:
661
+ attentions.append(
662
+ TransformerMV2DModel(
663
+ num_attention_heads,
664
+ out_channels // num_attention_heads,
665
+ in_channels=out_channels,
666
+ num_layers=transformer_layers_per_block,
667
+ cross_attention_dim=cross_attention_dim,
668
+ norm_num_groups=resnet_groups,
669
+ use_linear_projection=use_linear_projection,
670
+ only_cross_attention=only_cross_attention,
671
+ upcast_attention=upcast_attention,
672
+ num_views=num_views,
673
+ joint_attention=joint_attention,
674
+ joint_attention_twice=joint_attention_twice,
675
+ multiview_attention=multiview_attention,
676
+ cross_domain_attention=cross_domain_attention
677
+ )
678
+ )
679
+ else:
680
+ raise NotImplementedError
681
+ self.attentions = nn.ModuleList(attentions)
682
+ self.resnets = nn.ModuleList(resnets)
683
+
684
+ if add_upsample:
685
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
686
+ else:
687
+ self.upsamplers = None
688
+ if num_views == 4:
689
+ self.gradient_checkpointing = False
690
+ else:
691
+ self.gradient_checkpointing = False
692
+
693
+ def forward(
694
+ self,
695
+ hidden_states: torch.FloatTensor,
696
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
697
+ temb: Optional[torch.FloatTensor] = None,
698
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
699
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
700
+ upsample_size: Optional[int] = None,
701
+ attention_mask: Optional[torch.FloatTensor] = None,
702
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
703
+ ):
704
+ for resnet, attn in zip(self.resnets, self.attentions):
705
+ # pop res hidden states
706
+ res_hidden_states = res_hidden_states_tuple[-1]
707
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
708
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
709
+
710
+ if self.training and self.gradient_checkpointing:
711
+
712
+ def create_custom_forward(module, return_dict=None):
713
+ def custom_forward(*inputs):
714
+ if return_dict is not None:
715
+ return module(*inputs, return_dict=return_dict)
716
+ else:
717
+ return module(*inputs)
718
+
719
+ return custom_forward
720
+
721
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
722
+ hidden_states = torch.utils.checkpoint.checkpoint(
723
+ create_custom_forward(resnet),
724
+ hidden_states,
725
+ temb,
726
+ **ckpt_kwargs,
727
+ )
728
+ hidden_states = torch.utils.checkpoint.checkpoint(
729
+ create_custom_forward(attn, return_dict=False),
730
+ hidden_states,
731
+ encoder_hidden_states,
732
+ None, # timestep
733
+ None, # class_labels
734
+ cross_attention_kwargs,
735
+ attention_mask,
736
+ encoder_attention_mask,
737
+ **ckpt_kwargs,
738
+ )[0]
739
+ # hidden_states = attn(
740
+ # hidden_states,
741
+ # encoder_hidden_states=encoder_hidden_states,
742
+ # cross_attention_kwargs=cross_attention_kwargs,
743
+ # attention_mask=attention_mask,
744
+ # encoder_attention_mask=encoder_attention_mask,
745
+ # return_dict=False,
746
+ # )[0]
747
+ else:
748
+ hidden_states = resnet(hidden_states, temb)
749
+ hidden_states = attn(
750
+ hidden_states,
751
+ encoder_hidden_states=encoder_hidden_states,
752
+ cross_attention_kwargs=cross_attention_kwargs,
753
+ attention_mask=attention_mask,
754
+ encoder_attention_mask=encoder_attention_mask,
755
+ return_dict=False,
756
+ )[0]
757
+
758
+ if self.upsamplers is not None:
759
+ for upsampler in self.upsamplers:
760
+ hidden_states = upsampler(hidden_states, upsample_size)
761
+
762
+ return hidden_states
763
+
764
+
765
+ class CrossAttnDownBlockMV2D(nn.Module):
766
+ def __init__(
767
+ self,
768
+ in_channels: int,
769
+ out_channels: int,
770
+ temb_channels: int,
771
+ dropout: float = 0.0,
772
+ num_layers: int = 1,
773
+ transformer_layers_per_block: int = 1,
774
+ resnet_eps: float = 1e-6,
775
+ resnet_time_scale_shift: str = "default",
776
+ resnet_act_fn: str = "swish",
777
+ resnet_groups: int = 32,
778
+ resnet_pre_norm: bool = True,
779
+ num_attention_heads=1,
780
+ cross_attention_dim=1280,
781
+ output_scale_factor=1.0,
782
+ downsample_padding=1,
783
+ add_downsample=True,
784
+ dual_cross_attention=False,
785
+ use_linear_projection=False,
786
+ only_cross_attention=False,
787
+ upcast_attention=False,
788
+ num_views: int = 1,
789
+ joint_attention: bool = False,
790
+ joint_attention_twice: bool = False,
791
+ multiview_attention: bool = True,
792
+ cross_domain_attention: bool=False
793
+ ):
794
+ super().__init__()
795
+ resnets = []
796
+ attentions = []
797
+
798
+ self.has_cross_attention = True
799
+ self.num_attention_heads = num_attention_heads
800
+
801
+ for i in range(num_layers):
802
+ in_channels = in_channels if i == 0 else out_channels
803
+ resnets.append(
804
+ ResnetBlock2D(
805
+ in_channels=in_channels,
806
+ out_channels=out_channels,
807
+ temb_channels=temb_channels,
808
+ eps=resnet_eps,
809
+ groups=resnet_groups,
810
+ dropout=dropout,
811
+ time_embedding_norm=resnet_time_scale_shift,
812
+ non_linearity=resnet_act_fn,
813
+ output_scale_factor=output_scale_factor,
814
+ pre_norm=resnet_pre_norm,
815
+ )
816
+ )
817
+ if not dual_cross_attention:
818
+ attentions.append(
819
+ TransformerMV2DModel(
820
+ num_attention_heads,
821
+ out_channels // num_attention_heads,
822
+ in_channels=out_channels,
823
+ num_layers=transformer_layers_per_block,
824
+ cross_attention_dim=cross_attention_dim,
825
+ norm_num_groups=resnet_groups,
826
+ use_linear_projection=use_linear_projection,
827
+ only_cross_attention=only_cross_attention,
828
+ upcast_attention=upcast_attention,
829
+ num_views=num_views,
830
+ joint_attention=joint_attention,
831
+ joint_attention_twice=joint_attention_twice,
832
+ multiview_attention=multiview_attention,
833
+ cross_domain_attention=cross_domain_attention
834
+ )
835
+ )
836
+ else:
837
+ raise NotImplementedError
838
+ self.attentions = nn.ModuleList(attentions)
839
+ self.resnets = nn.ModuleList(resnets)
840
+
841
+ if add_downsample:
842
+ self.downsamplers = nn.ModuleList(
843
+ [
844
+ Downsample2D(
845
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
846
+ )
847
+ ]
848
+ )
849
+ else:
850
+ self.downsamplers = None
851
+ if num_views == 4:
852
+ self.gradient_checkpointing = False
853
+ else:
854
+ self.gradient_checkpointing = False
855
+
856
+ def forward(
857
+ self,
858
+ hidden_states: torch.FloatTensor,
859
+ temb: Optional[torch.FloatTensor] = None,
860
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
861
+ attention_mask: Optional[torch.FloatTensor] = None,
862
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
863
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
864
+ additional_residuals=None,
865
+ ):
866
+ output_states = ()
867
+
868
+ blocks = list(zip(self.resnets, self.attentions))
869
+
870
+ for i, (resnet, attn) in enumerate(blocks):
871
+ if self.training and self.gradient_checkpointing:
872
+
873
+ def create_custom_forward(module, return_dict=None):
874
+ def custom_forward(*inputs):
875
+ if return_dict is not None:
876
+ return module(*inputs, return_dict=return_dict)
877
+ else:
878
+ return module(*inputs)
879
+
880
+ return custom_forward
881
+
882
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
883
+ hidden_states = torch.utils.checkpoint.checkpoint(
884
+ create_custom_forward(resnet),
885
+ hidden_states,
886
+ temb,
887
+ **ckpt_kwargs,
888
+ )
889
+ hidden_states = torch.utils.checkpoint.checkpoint(
890
+ create_custom_forward(attn, return_dict=False),
891
+ hidden_states,
892
+ encoder_hidden_states,
893
+ None, # timestep
894
+ None, # class_labels
895
+ cross_attention_kwargs,
896
+ attention_mask,
897
+ encoder_attention_mask,
898
+ **ckpt_kwargs,
899
+ )[0]
900
+ else:
901
+ hidden_states = resnet(hidden_states, temb)
902
+ hidden_states = attn(
903
+ hidden_states,
904
+ encoder_hidden_states=encoder_hidden_states,
905
+ cross_attention_kwargs=cross_attention_kwargs,
906
+ attention_mask=attention_mask,
907
+ encoder_attention_mask=encoder_attention_mask,
908
+ return_dict=False,
909
+ )[0]
910
+
911
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
912
+ if i == len(blocks) - 1 and additional_residuals is not None:
913
+ hidden_states = hidden_states + additional_residuals
914
+
915
+ output_states = output_states + (hidden_states,)
916
+
917
+ if self.downsamplers is not None:
918
+ for downsampler in self.downsamplers:
919
+ hidden_states = downsampler(hidden_states)
920
+
921
+ output_states = output_states + (hidden_states,)
922
+
923
+ return hidden_states, output_states
924
+
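A small illustrative sketch (not part of the commit) of the skip-connection channel bookkeeping used by `CrossAttnUpBlockMV2D` in this file: each resnet consumes the upsampled features concatenated with one popped skip tensor, so its input width depends on the layer index. The channel sizes below are made up.

```python
# Illustrative-only: per-layer resnet input channels in an MV2D up block.
num_layers = 3
in_channels = 640           # channels of the matching down block's input
out_channels = 1280         # channels this up block produces
prev_output_channel = 1280  # channels coming from the block below

for i in range(num_layers):
    res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
    resnet_in_channels = prev_output_channel if i == 0 else out_channels
    print(i, resnet_in_channels + res_skip_channels, "->", out_channels)
# 0 2560 -> 1280
# 1 2560 -> 1280
# 2 1920 -> 1280
```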
canonicalize/models/unet_mv2d_condition.py ADDED
@@ -0,0 +1,1502 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+ import os
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+ from einops import rearrange
22
+
23
+
24
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+ from diffusers.loaders import UNet2DConditionLoadersMixin
26
+ from diffusers.utils import BaseOutput, logging
27
+ from diffusers.models.activations import get_activation
28
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
29
+ from diffusers.models.embeddings import (
30
+ GaussianFourierProjection,
31
+ ImageHintTimeEmbedding,
32
+ ImageProjection,
33
+ ImageTimeEmbedding,
34
+ TextImageProjection,
35
+ TextImageTimeEmbedding,
36
+ TextTimeEmbedding,
37
+ TimestepEmbedding,
38
+ Timesteps,
39
+ )
40
+ from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model
41
+ from diffusers.models.unet_2d_blocks import (
42
+ CrossAttnDownBlock2D,
43
+ CrossAttnUpBlock2D,
44
+ DownBlock2D,
45
+ UNetMidBlock2DCrossAttn,
46
+ UNetMidBlock2DSimpleCrossAttn,
47
+ UpBlock2D,
48
+ )
49
+ from diffusers.utils import (
50
+ CONFIG_NAME,
51
+ FLAX_WEIGHTS_NAME,
52
+ SAFETENSORS_WEIGHTS_NAME,
53
+ WEIGHTS_NAME,
54
+ _add_variant,
55
+ _get_model_file,
56
+ deprecate,
57
+ is_accelerate_available,
58
+ is_torch_version,
59
+ logging,
60
+ )
61
+ from diffusers import __version__
62
+ from canonicalize.models.unet_mv2d_blocks import (
63
+ CrossAttnDownBlockMV2D,
64
+ CrossAttnUpBlockMV2D,
65
+ UNetMidBlockMV2DCrossAttn,
66
+ get_down_block,
67
+ get_up_block,
68
+ )
69
+ from diffusers.models.attention_processor import Attention, AttnProcessor
70
+ from diffusers.utils.import_utils import is_xformers_available
71
+ from canonicalize.models.transformer_mv2d import XFormersMVAttnProcessor, MVAttnProcessor
72
+ from canonicalize.models.refunet import ReferenceOnlyAttnProc
73
+
74
+ from huggingface_hub.constants import HF_HUB_CACHE
75
+ from diffusers.utils.hub_utils import HF_HUB_OFFLINE
76
+
77
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
78
+
79
+
80
+ @dataclass
81
+ class UNetMV2DConditionOutput(BaseOutput):
82
+ """
83
+ The output of [`UNet2DConditionModel`].
84
+
85
+ Args:
86
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
87
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
88
+ """
89
+
90
+ sample: torch.FloatTensor = None
91
+
92
+ class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
93
+ r"""
94
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
95
+ shaped output.
96
+
97
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
98
+ for all models (such as downloading or saving).
99
+
100
+ Parameters:
101
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
102
+ Height and width of input/output sample.
103
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
104
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
105
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
106
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
107
+ Whether to flip the sin to cos in the time embedding.
108
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
109
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
110
+ The tuple of downsample blocks to use.
111
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
112
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
113
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
114
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
115
+ The tuple of upsample blocks to use.
116
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
117
+ Whether to include self-attention in the basic transformer blocks, see
118
+ [`~models.attention.BasicTransformerBlock`].
119
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
120
+ The tuple of output channels for each block.
121
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
122
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
123
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
124
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
125
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
126
+ If `None`, normalization and activation layers is skipped in post-processing.
127
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
128
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
129
+ The dimension of the cross attention features.
130
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
131
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
132
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
133
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
134
+ encoder_hid_dim (`int`, *optional*, defaults to None):
135
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
136
+ dimension to `cross_attention_dim`.
137
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
138
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
139
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
140
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
141
+ num_attention_heads (`int`, *optional*):
142
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
143
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
144
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
145
+ class_embed_type (`str`, *optional*, defaults to `None`):
146
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
147
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
148
+ addition_embed_type (`str`, *optional*, defaults to `None`):
149
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
150
+ "text". "text" will use the `TextTimeEmbedding` layer.
151
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
152
+ Dimension for the timestep embeddings.
153
+ num_class_embeds (`int`, *optional*, defaults to `None`):
154
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
155
+ class conditioning with `class_embed_type` equal to `None`.
156
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
157
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
158
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
159
+ An optional override for the dimension of the projected time embedding.
160
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
161
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
162
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
163
+ timestep_post_act (`str`, *optional*, defaults to `None`):
164
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
165
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
166
+ The dimension of `cond_proj` layer in the timestep embedding.
167
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
168
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
169
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
170
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
171
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
172
+ embeddings with the class embeddings.
173
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
174
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
175
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
176
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
177
+ otherwise.
178
+ """
179
+
180
+ _supports_gradient_checkpointing = True
181
+
182
+ @register_to_config
183
+ def __init__(
184
+ self,
185
+ sample_size: Optional[int] = None,
186
+ in_channels: int = 4,
187
+ out_channels: int = 4,
188
+ center_input_sample: bool = False,
189
+ flip_sin_to_cos: bool = True,
190
+ freq_shift: int = 0,
191
+ down_block_types: Tuple[str] = (
192
+ "CrossAttnDownBlockMV2D",
193
+ "CrossAttnDownBlockMV2D",
194
+ "CrossAttnDownBlockMV2D",
195
+ "DownBlock2D",
196
+ ),
197
+ mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn",
198
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"),
199
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
200
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
201
+ layers_per_block: Union[int, Tuple[int]] = 2,
202
+ downsample_padding: int = 1,
203
+ mid_block_scale_factor: float = 1,
204
+ act_fn: str = "silu",
205
+ norm_num_groups: Optional[int] = 32,
206
+ norm_eps: float = 1e-5,
207
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
208
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
209
+ encoder_hid_dim: Optional[int] = None,
210
+ encoder_hid_dim_type: Optional[str] = None,
211
+ attention_head_dim: Union[int, Tuple[int]] = 8,
212
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
213
+ dual_cross_attention: bool = False,
214
+ use_linear_projection: bool = False,
215
+ class_embed_type: Optional[str] = None,
216
+ addition_embed_type: Optional[str] = None,
217
+ addition_time_embed_dim: Optional[int] = None,
218
+ num_class_embeds: Optional[int] = None,
219
+ upcast_attention: bool = False,
220
+ resnet_time_scale_shift: str = "default",
221
+ resnet_skip_time_act: bool = False,
222
+ resnet_out_scale_factor: int = 1.0,
223
+ time_embedding_type: str = "positional",
224
+ time_embedding_dim: Optional[int] = None,
225
+ time_embedding_act_fn: Optional[str] = None,
226
+ timestep_post_act: Optional[str] = None,
227
+ time_cond_proj_dim: Optional[int] = None,
228
+ conv_in_kernel: int = 3,
229
+ conv_out_kernel: int = 3,
230
+ projection_class_embeddings_input_dim: Optional[int] = None,
231
+ class_embeddings_concat: bool = False,
232
+ mid_block_only_cross_attention: Optional[bool] = None,
233
+ cross_attention_norm: Optional[str] = None,
234
+ addition_embed_type_num_heads=64,
235
+ num_views: int = 1,
236
+ joint_attention: bool = False,
237
+ joint_attention_twice: bool = False,
238
+ multiview_attention: bool = True,
239
+ cross_domain_attention: bool = False,
240
+ camera_input_dim: int = 12,
241
+ camera_hidden_dim: int = 320,
242
+ camera_output_dim: int = 1280,
243
+
244
+ ):
245
+ super().__init__()
246
+
247
+ self.sample_size = sample_size
248
+
249
+ if num_attention_heads is not None:
250
+ raise ValueError(
251
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
252
+ )
253
+
254
+ # If `num_attention_heads` is not defined (which is the case for most models)
255
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
256
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
257
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
258
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
259
+ # which is why we correct for the naming here.
260
+ num_attention_heads = num_attention_heads or attention_head_dim
261
+
262
+ # Check inputs
263
+ if len(down_block_types) != len(up_block_types):
264
+ raise ValueError(
265
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
266
+ )
267
+
268
+ if len(block_out_channels) != len(down_block_types):
269
+ raise ValueError(
270
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
271
+ )
272
+
273
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
274
+ raise ValueError(
275
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
276
+ )
277
+
278
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
279
+ raise ValueError(
280
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
281
+ )
282
+
283
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
284
+ raise ValueError(
285
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
286
+ )
287
+
288
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
289
+ raise ValueError(
290
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
291
+ )
292
+
293
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
294
+ raise ValueError(
295
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
296
+ )
297
+
298
+ # input
299
+ conv_in_padding = (conv_in_kernel - 1) // 2
300
+ self.conv_in = nn.Conv2d(
301
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
302
+ )
303
+
304
+ # time
305
+ if time_embedding_type == "fourier":
306
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
307
+ if time_embed_dim % 2 != 0:
308
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
309
+ self.time_proj = GaussianFourierProjection(
310
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
311
+ )
312
+ timestep_input_dim = time_embed_dim
313
+ elif time_embedding_type == "positional":
314
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
315
+
316
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
317
+ timestep_input_dim = block_out_channels[0]
318
+ else:
319
+ raise ValueError(
320
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
321
+ )
322
+
323
+ self.time_embedding = TimestepEmbedding(
324
+ timestep_input_dim,
325
+ time_embed_dim,
326
+ act_fn=act_fn,
327
+ post_act_fn=timestep_post_act,
328
+ cond_proj_dim=time_cond_proj_dim,
329
+ )
330
+
331
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
332
+ encoder_hid_dim_type = "text_proj"
333
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
334
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
335
+
336
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
337
+ raise ValueError(
338
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
339
+ )
340
+
341
+ if encoder_hid_dim_type == "text_proj":
342
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
343
+ elif encoder_hid_dim_type == "text_image_proj":
344
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
345
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
346
+ # case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1)
347
+ self.encoder_hid_proj = TextImageProjection(
348
+ text_embed_dim=encoder_hid_dim,
349
+ image_embed_dim=cross_attention_dim,
350
+ cross_attention_dim=cross_attention_dim,
351
+ )
352
+ elif encoder_hid_dim_type == "image_proj":
353
+ # Kandinsky 2.2
354
+ self.encoder_hid_proj = ImageProjection(
355
+ image_embed_dim=encoder_hid_dim,
356
+ cross_attention_dim=cross_attention_dim,
357
+ )
358
+ elif encoder_hid_dim_type is not None:
359
+ raise ValueError(
360
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
361
+ )
362
+ else:
363
+ self.encoder_hid_proj = None
364
+
365
+ # class embedding
366
+ if class_embed_type is None and num_class_embeds is not None:
367
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
368
+ elif class_embed_type == "timestep":
369
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
370
+ elif class_embed_type == "identity":
371
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
372
+ elif class_embed_type == "projection":
373
+ if projection_class_embeddings_input_dim is None:
374
+ raise ValueError(
375
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
376
+ )
377
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
378
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
379
+ # 2. it projects from an arbitrary input dimension.
380
+ #
381
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
382
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
383
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
384
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
385
+ elif class_embed_type == "simple_projection":
386
+ if projection_class_embeddings_input_dim is None:
387
+ raise ValueError(
388
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
389
+ )
390
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
391
+ else:
392
+ self.class_embedding = None
393
+
394
+ if addition_embed_type == "text":
395
+ if encoder_hid_dim is not None:
396
+ text_time_embedding_from_dim = encoder_hid_dim
397
+ else:
398
+ text_time_embedding_from_dim = cross_attention_dim
399
+
400
+ self.add_embedding = TextTimeEmbedding(
401
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
402
+ )
403
+ elif addition_embed_type == "text_image":
404
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
405
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
406
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
407
+ self.add_embedding = TextImageTimeEmbedding(
408
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
409
+ )
410
+ elif addition_embed_type == "text_time":
411
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
412
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
413
+ elif addition_embed_type == "image":
414
+ # Kandinsky 2.2
415
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
416
+ elif addition_embed_type == "image_hint":
417
+ # Kandinsky 2.2 ControlNet
418
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
419
+ elif addition_embed_type is not None:
420
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
421
+
422
+ if time_embedding_act_fn is None:
423
+ self.time_embed_act = None
424
+ else:
425
+ self.time_embed_act = get_activation(time_embedding_act_fn)
426
+
427
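+ # Camera conditioning: a small MLP projects the flattened per-view camera matrix (camera_input_dim values, 12 by default) into the time-embedding space so it can be added to the timestep embedding in forward().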
+ self.camera_embedding = nn.Sequential(
428
+ nn.Linear(camera_input_dim, time_embed_dim),
429
+ nn.SiLU(),
430
+ nn.Linear(time_embed_dim, time_embed_dim),
431
+ )
432
+
433
+ self.down_blocks = nn.ModuleList([])
434
+ self.up_blocks = nn.ModuleList([])
435
+
436
+ if isinstance(only_cross_attention, bool):
437
+ if mid_block_only_cross_attention is None:
438
+ mid_block_only_cross_attention = only_cross_attention
439
+
440
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
441
+
442
+ if mid_block_only_cross_attention is None:
443
+ mid_block_only_cross_attention = False
444
+
445
+ if isinstance(num_attention_heads, int):
446
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
447
+
448
+ if isinstance(attention_head_dim, int):
449
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
450
+
451
+ if isinstance(cross_attention_dim, int):
452
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
453
+
454
+ if isinstance(layers_per_block, int):
455
+ layers_per_block = [layers_per_block] * len(down_block_types)
456
+
457
+ if isinstance(transformer_layers_per_block, int):
458
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
459
+
460
+ if class_embeddings_concat:
461
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
462
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
463
+ # regular time embeddings
464
+ blocks_time_embed_dim = time_embed_dim * 2
465
+ else:
466
+ blocks_time_embed_dim = time_embed_dim
467
+
468
+ # down
469
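+ # Build the encoder: CrossAttnDownBlockMV2D blocks extend the standard cross-attention down blocks with multi-view (and optionally joint / cross-domain) attention, controlled by num_views and the *_attention flags.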
+ output_channel = block_out_channels[0]
470
+ for i, down_block_type in enumerate(down_block_types):
471
+ input_channel = output_channel
472
+ output_channel = block_out_channels[i]
473
+ is_final_block = i == len(block_out_channels) - 1
474
+
475
+ down_block = get_down_block(
476
+ down_block_type,
477
+ num_layers=layers_per_block[i],
478
+ transformer_layers_per_block=transformer_layers_per_block[i],
479
+ in_channels=input_channel,
480
+ out_channels=output_channel,
481
+ temb_channels=blocks_time_embed_dim,
482
+ add_downsample=not is_final_block,
483
+ resnet_eps=norm_eps,
484
+ resnet_act_fn=act_fn,
485
+ resnet_groups=norm_num_groups,
486
+ cross_attention_dim=cross_attention_dim[i],
487
+ num_attention_heads=num_attention_heads[i],
488
+ downsample_padding=downsample_padding,
489
+ dual_cross_attention=dual_cross_attention,
490
+ use_linear_projection=use_linear_projection,
491
+ only_cross_attention=only_cross_attention[i],
492
+ upcast_attention=upcast_attention,
493
+ resnet_time_scale_shift=resnet_time_scale_shift,
494
+ resnet_skip_time_act=resnet_skip_time_act,
495
+ resnet_out_scale_factor=resnet_out_scale_factor,
496
+ cross_attention_norm=cross_attention_norm,
497
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
498
+ num_views=num_views,
499
+ joint_attention=joint_attention,
500
+ joint_attention_twice=joint_attention_twice,
501
+ multiview_attention=multiview_attention,
502
+ cross_domain_attention=cross_domain_attention
503
+ )
504
+ self.down_blocks.append(down_block)
505
+
506
+ # mid
507
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
508
+ self.mid_block = UNetMidBlock2DCrossAttn(
509
+ transformer_layers_per_block=transformer_layers_per_block[-1],
510
+ in_channels=block_out_channels[-1],
511
+ temb_channels=blocks_time_embed_dim,
512
+ resnet_eps=norm_eps,
513
+ resnet_act_fn=act_fn,
514
+ output_scale_factor=mid_block_scale_factor,
515
+ resnet_time_scale_shift=resnet_time_scale_shift,
516
+ cross_attention_dim=cross_attention_dim[-1],
517
+ num_attention_heads=num_attention_heads[-1],
518
+ resnet_groups=norm_num_groups,
519
+ dual_cross_attention=dual_cross_attention,
520
+ use_linear_projection=use_linear_projection,
521
+ upcast_attention=upcast_attention,
522
+ )
523
+ # custom MV2D attention block
524
+ elif mid_block_type == "UNetMidBlockMV2DCrossAttn":
525
+ self.mid_block = UNetMidBlockMV2DCrossAttn(
526
+ transformer_layers_per_block=transformer_layers_per_block[-1],
527
+ in_channels=block_out_channels[-1],
528
+ temb_channels=blocks_time_embed_dim,
529
+ resnet_eps=norm_eps,
530
+ resnet_act_fn=act_fn,
531
+ output_scale_factor=mid_block_scale_factor,
532
+ resnet_time_scale_shift=resnet_time_scale_shift,
533
+ cross_attention_dim=cross_attention_dim[-1],
534
+ num_attention_heads=num_attention_heads[-1],
535
+ resnet_groups=norm_num_groups,
536
+ dual_cross_attention=dual_cross_attention,
537
+ use_linear_projection=use_linear_projection,
538
+ upcast_attention=upcast_attention,
539
+ num_views=num_views,
540
+ joint_attention=joint_attention,
541
+ joint_attention_twice=joint_attention_twice,
542
+ multiview_attention=multiview_attention,
543
+ cross_domain_attention=cross_domain_attention
544
+ )
545
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
546
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
547
+ in_channels=block_out_channels[-1],
548
+ temb_channels=blocks_time_embed_dim,
549
+ resnet_eps=norm_eps,
550
+ resnet_act_fn=act_fn,
551
+ output_scale_factor=mid_block_scale_factor,
552
+ cross_attention_dim=cross_attention_dim[-1],
553
+ attention_head_dim=attention_head_dim[-1],
554
+ resnet_groups=norm_num_groups,
555
+ resnet_time_scale_shift=resnet_time_scale_shift,
556
+ skip_time_act=resnet_skip_time_act,
557
+ only_cross_attention=mid_block_only_cross_attention,
558
+ cross_attention_norm=cross_attention_norm,
559
+ )
560
+ elif mid_block_type is None:
561
+ self.mid_block = None
562
+ else:
563
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
564
+
565
+ # count how many layers upsample the images
566
+ self.num_upsamplers = 0
567
+
568
+ # up
569
+ reversed_block_out_channels = list(reversed(block_out_channels))
570
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
571
+ reversed_layers_per_block = list(reversed(layers_per_block))
572
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
573
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
574
+ only_cross_attention = list(reversed(only_cross_attention))
575
+
576
+ output_channel = reversed_block_out_channels[0]
577
+ for i, up_block_type in enumerate(up_block_types):
578
+ is_final_block = i == len(block_out_channels) - 1
579
+
580
+ prev_output_channel = output_channel
581
+ output_channel = reversed_block_out_channels[i]
582
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
583
+
584
+ # add upsample block for all BUT final layer
585
+ if not is_final_block:
586
+ add_upsample = True
587
+ self.num_upsamplers += 1
588
+ else:
589
+ add_upsample = False
590
+
591
+ up_block = get_up_block(
592
+ up_block_type,
593
+ num_layers=reversed_layers_per_block[i] + 1,
594
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
595
+ in_channels=input_channel,
596
+ out_channels=output_channel,
597
+ prev_output_channel=prev_output_channel,
598
+ temb_channels=blocks_time_embed_dim,
599
+ add_upsample=add_upsample,
600
+ resnet_eps=norm_eps,
601
+ resnet_act_fn=act_fn,
602
+ resnet_groups=norm_num_groups,
603
+ cross_attention_dim=reversed_cross_attention_dim[i],
604
+ num_attention_heads=reversed_num_attention_heads[i],
605
+ dual_cross_attention=dual_cross_attention,
606
+ use_linear_projection=use_linear_projection,
607
+ only_cross_attention=only_cross_attention[i],
608
+ upcast_attention=upcast_attention,
609
+ resnet_time_scale_shift=resnet_time_scale_shift,
610
+ resnet_skip_time_act=resnet_skip_time_act,
611
+ resnet_out_scale_factor=resnet_out_scale_factor,
612
+ cross_attention_norm=cross_attention_norm,
613
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
614
+ num_views=num_views,
615
+ joint_attention=joint_attention,
616
+ joint_attention_twice=joint_attention_twice,
617
+ multiview_attention=multiview_attention,
618
+ cross_domain_attention=cross_domain_attention
619
+ )
620
+ self.up_blocks.append(up_block)
621
+ prev_output_channel = output_channel
622
+
623
+ # out
624
+ if norm_num_groups is not None:
625
+ self.conv_norm_out = nn.GroupNorm(
626
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
627
+ )
628
+
629
+ self.conv_act = get_activation(act_fn)
630
+
631
+ else:
632
+ self.conv_norm_out = None
633
+ self.conv_act = None
634
+
635
+ conv_out_padding = (conv_out_kernel - 1) // 2
636
+ self.conv_out = nn.Conv2d(
637
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
638
+ )
639
+
640
+ @property
641
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
642
+ r"""
643
+ Returns:
644
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
645
+ indexed by its weight name.
646
+ """
647
+ # set recursively
648
+ processors = {}
649
+
650
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
651
+ if hasattr(module, "set_processor"):
652
+ processors[f"{name}.processor"] = module.processor
653
+
654
+ for sub_name, child in module.named_children():
655
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
656
+
657
+ return processors
658
+
659
+ for name, module in self.named_children():
660
+ fn_recursive_add_processors(name, module, processors)
661
+
662
+ return processors
663
+
664
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
665
+ r"""
666
+ Sets the attention processor to use to compute attention.
667
+
668
+ Parameters:
669
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
670
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
671
+ for **all** `Attention` layers.
672
+
673
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
674
+ processor. This is strongly recommended when setting trainable attention processors.
675
+
676
+ """
677
+ count = len(self.attn_processors.keys())
678
+
679
+ if isinstance(processor, dict) and len(processor) != count:
680
+ raise ValueError(
681
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
682
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
683
+ )
684
+
685
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
686
+ if hasattr(module, "set_processor"):
687
+ if not isinstance(processor, dict):
688
+ module.set_processor(processor)
689
+ else:
690
+ module.set_processor(processor.pop(f"{name}.processor"))
691
+
692
+ for sub_name, child in module.named_children():
693
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
694
+
695
+ for name, module in self.named_children():
696
+ fn_recursive_attn_processor(name, module, processor)
697
+
698
+ def set_default_attn_processor(self):
699
+ """
700
+ Disables custom attention processors and sets the default attention implementation.
701
+ """
702
+ self.set_attn_processor(AttnProcessor())
703
+
704
+ def set_attention_slice(self, slice_size):
705
+ r"""
706
+ Enable sliced attention computation.
707
+
708
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
709
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
710
+
711
+ Args:
712
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
713
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
714
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
715
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
716
+ must be a multiple of `slice_size`.
717
+ """
718
+ sliceable_head_dims = []
719
+
720
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
721
+ if hasattr(module, "set_attention_slice"):
722
+ sliceable_head_dims.append(module.sliceable_head_dim)
723
+
724
+ for child in module.children():
725
+ fn_recursive_retrieve_sliceable_dims(child)
726
+
727
+ # retrieve number of attention layers
728
+ for module in self.children():
729
+ fn_recursive_retrieve_sliceable_dims(module)
730
+
731
+ num_sliceable_layers = len(sliceable_head_dims)
732
+
733
+ if slice_size == "auto":
734
+ # half the attention head size is usually a good trade-off between
735
+ # speed and memory
736
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
737
+ elif slice_size == "max":
738
+ # make smallest slice possible
739
+ slice_size = num_sliceable_layers * [1]
740
+
741
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
742
+
743
+ if len(slice_size) != len(sliceable_head_dims):
744
+ raise ValueError(
745
+ f"You have provided {len(slice_size)} slice sizes, but {self.config} has {len(sliceable_head_dims)} different"
746
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
747
+ )
748
+
749
+ for i in range(len(slice_size)):
750
+ size = slice_size[i]
751
+ dim = sliceable_head_dims[i]
752
+ if size is not None and size > dim:
753
+ raise ValueError(f"size {size} has to be smaller than or equal to {dim}.")
754
+
755
+ # Recursively walk through all the children.
756
+ # Any children which exposes the set_attention_slice method
757
+ # gets the message
758
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
759
+ if hasattr(module, "set_attention_slice"):
760
+ module.set_attention_slice(slice_size.pop())
761
+
762
+ for child in module.children():
763
+ fn_recursive_set_attention_slice(child, slice_size)
764
+
765
+ reversed_slice_size = list(reversed(slice_size))
766
+ for module in self.children():
767
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
768
+
769
+
770
+ def forward(
771
+ self,
772
+ sample: torch.FloatTensor,
773
+ timestep: Union[torch.Tensor, float, int],
774
+ encoder_hidden_states: torch.Tensor,
775
+ camera_matrixs: Optional[torch.Tensor] = None,
776
+ class_labels: Optional[torch.Tensor] = None,
777
+ timestep_cond: Optional[torch.Tensor] = None,
778
+ attention_mask: Optional[torch.Tensor] = None,
779
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
780
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
781
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
782
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
783
+ encoder_attention_mask: Optional[torch.Tensor] = None,
784
+ return_dict: bool = True,
785
+ ) -> Union[UNetMV2DConditionOutput, Tuple]:
786
+ r"""
787
+ The [`UNet2DConditionModel`] forward method.
788
+
789
+ Args:
790
+ sample (`torch.FloatTensor`):
791
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
792
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
793
+ encoder_hidden_states (`torch.FloatTensor`):
794
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
795
+ encoder_attention_mask (`torch.Tensor`):
796
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
797
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
798
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
799
+ return_dict (`bool`, *optional*, defaults to `True`):
800
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
801
+ tuple.
802
+ cross_attention_kwargs (`dict`, *optional*):
803
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
804
+ added_cond_kwargs: (`dict`, *optional*):
805
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
806
+ are passed along to the UNet blocks.
807
+
808
+ Returns:
809
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
810
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
811
+ a `tuple` is returned where the first element is the sample tensor.
812
+ """
813
+ # By default samples have to be at least a multiple of the overall upsampling factor.
814
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
815
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
816
+ # on the fly if necessary.
817
+ default_overall_up_factor = 2**self.num_upsamplers
818
+
819
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
820
+ forward_upsample_size = False
821
+ upsample_size = None
822
+
823
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
824
+ logger.info("Forward upsample size to force interpolation output size.")
825
+ forward_upsample_size = True
826
+
827
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
828
+ # expects mask of shape:
829
+ # [batch, key_tokens]
830
+ # adds singleton query_tokens dimension:
831
+ # [batch, 1, key_tokens]
832
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
833
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
834
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
835
+ if attention_mask is not None:
836
+ # assume that mask is expressed as:
837
+ # (1 = keep, 0 = discard)
838
+ # convert mask into a bias that can be added to attention scores:
839
+ # (keep = +0, discard = -10000.0)
840
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
841
+ attention_mask = attention_mask.unsqueeze(1)
842
+
843
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
844
+ if encoder_attention_mask is not None:
845
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
846
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
847
+
848
+ # 0. center input if necessary
849
+ if self.config.center_input_sample:
850
+ sample = 2 * sample - 1.0
851
+
852
+ # 1. time
853
+ timesteps = timestep
854
+ if not torch.is_tensor(timesteps):
855
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
856
+ # This would be a good case for the `match` statement (Python 3.10+)
857
+ is_mps = sample.device.type == "mps"
858
+ if isinstance(timestep, float):
859
+ dtype = torch.float32 if is_mps else torch.float64
860
+ else:
861
+ dtype = torch.int32 if is_mps else torch.int64
862
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
863
+ elif len(timesteps.shape) == 0:
864
+ timesteps = timesteps[None].to(sample.device)
865
+
866
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
867
+ timesteps = timesteps.expand(sample.shape[0])
868
+
869
+ t_emb = self.time_proj(timesteps)
870
+
871
+ # `Timesteps` does not contain any weights and will always return f32 tensors
872
+ # but time_embedding might actually be running in fp16. so we need to cast here.
873
+ # there might be better ways to encapsulate this.
874
+ t_emb = t_emb.to(dtype=sample.dtype)
875
+ emb = self.time_embedding(t_emb, timestep_cond)
876
+
877
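+ # Per-view conditioning: broadcast the timestep embedding over the view axis, add the camera embedding of each view, then fold the view axis back into the batch dimension.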
+ if camera_matrixs is not None:
878
+ emb = torch.unsqueeze(emb, 1)
879
+ cam_emb = self.camera_embedding(camera_matrixs)
880
+ emb = emb.repeat(1, cam_emb.shape[1], 1)  # e.g. torch.Size([32, 4, 1280])
881
+ emb = emb + cam_emb
882
+ emb = rearrange(emb, "b f c -> (b f) c", f=emb.shape[1])
883
+
884
+ aug_emb = None
885
+
886
+ if self.class_embedding is not None and class_labels is not None:
887
+ if class_labels is None:
888
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
889
+
890
+ if self.config.class_embed_type == "timestep":
891
+ class_labels = self.time_proj(class_labels)
892
+
893
+ # `Timesteps` does not contain any weights and will always return f32 tensors
894
+ # there might be better ways to encapsulate this.
895
+ class_labels = class_labels.to(dtype=sample.dtype)
896
+
897
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
898
+
899
+ if self.config.class_embeddings_concat:
900
+ emb = torch.cat([emb, class_emb], dim=-1)
901
+ else:
902
+ emb = emb + class_emb
903
+
904
+ if self.config.addition_embed_type == "text":
905
+ aug_emb = self.add_embedding(encoder_hidden_states)
906
+ elif self.config.addition_embed_type == "text_image":
907
+ # Kandinsky 2.1 - style
908
+ if "image_embeds" not in added_cond_kwargs:
909
+ raise ValueError(
910
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
911
+ )
912
+
913
+ image_embs = added_cond_kwargs.get("image_embeds")
914
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
915
+ aug_emb = self.add_embedding(text_embs, image_embs)
916
+ elif self.config.addition_embed_type == "text_time":
917
+ # SDXL - style
918
+ if "text_embeds" not in added_cond_kwargs:
919
+ raise ValueError(
920
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
921
+ )
922
+ text_embeds = added_cond_kwargs.get("text_embeds")
923
+ if "time_ids" not in added_cond_kwargs:
924
+ raise ValueError(
925
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
926
+ )
927
+ time_ids = added_cond_kwargs.get("time_ids")
928
+ time_embeds = self.add_time_proj(time_ids.flatten())
929
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
930
+
931
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
932
+ add_embeds = add_embeds.to(emb.dtype)
933
+ aug_emb = self.add_embedding(add_embeds)
934
+ elif self.config.addition_embed_type == "image":
935
+ # Kandinsky 2.2 - style
936
+ if "image_embeds" not in added_cond_kwargs:
937
+ raise ValueError(
938
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
939
+ )
940
+ image_embs = added_cond_kwargs.get("image_embeds")
941
+ aug_emb = self.add_embedding(image_embs)
942
+ elif self.config.addition_embed_type == "image_hint":
943
+ # Kandinsky 2.2 - style
944
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
945
+ raise ValueError(
946
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
947
+ )
948
+ image_embs = added_cond_kwargs.get("image_embeds")
949
+ hint = added_cond_kwargs.get("hint")
950
+ aug_emb, hint = self.add_embedding(image_embs, hint)
951
+ sample = torch.cat([sample, hint], dim=1)
952
+
953
+ emb = emb + aug_emb if aug_emb is not None else emb
954
+
955
+ if self.time_embed_act is not None:
956
+ emb = self.time_embed_act(emb)
957
+
958
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
959
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
960
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
961
+ # Kandinsky 2.1 - style
962
+ if "image_embeds" not in added_cond_kwargs:
963
+ raise ValueError(
964
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
965
+ )
966
+
967
+ image_embeds = added_cond_kwargs.get("image_embeds")
968
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
969
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
970
+ # Kandinsky 2.2 - style
971
+ if "image_embeds" not in added_cond_kwargs:
972
+ raise ValueError(
973
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
974
+ )
975
+ image_embeds = added_cond_kwargs.get("image_embeds")
976
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
977
+ # 2. pre-process
978
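+ # The sample arrives as (batch, channels, views, height, width); fold the view axis into the batch so the 2D blocks process each view independently, while the MV2D attention blocks exchange information across views.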
+ sample = rearrange(sample, "b c f h w -> (b f) c h w", f=sample.shape[2])
979
+ sample = self.conv_in(sample)
980
+ # 3. down
981
+
982
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
983
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
984
+
985
+ down_block_res_samples = (sample,)
986
+ for downsample_block in self.down_blocks:
987
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
988
+ # For t2i-adapter CrossAttnDownBlock2D
989
+ additional_residuals = {}
990
+ if is_adapter and len(down_block_additional_residuals) > 0:
991
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
992
+
993
+ sample, res_samples = downsample_block(
994
+ hidden_states=sample,
995
+ temb=emb,
996
+ encoder_hidden_states=encoder_hidden_states,
997
+ attention_mask=attention_mask,
998
+ cross_attention_kwargs=cross_attention_kwargs,
999
+ encoder_attention_mask=encoder_attention_mask,
1000
+ **additional_residuals,
1001
+ )
1002
+ else:
1003
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1004
+
1005
+ if is_adapter and len(down_block_additional_residuals) > 0:
1006
+ sample += down_block_additional_residuals.pop(0)
1007
+
1008
+ down_block_res_samples += res_samples
1009
+
1010
+ if is_controlnet:
1011
+ new_down_block_res_samples = ()
1012
+
1013
+ for down_block_res_sample, down_block_additional_residual in zip(
1014
+ down_block_res_samples, down_block_additional_residuals
1015
+ ):
1016
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1017
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1018
+
1019
+ down_block_res_samples = new_down_block_res_samples
1020
+
1021
+ # 4. mid
1022
+ if self.mid_block is not None:
1023
+ sample = self.mid_block(
1024
+ sample,
1025
+ emb,
1026
+ encoder_hidden_states=encoder_hidden_states,
1027
+ attention_mask=attention_mask,
1028
+ cross_attention_kwargs=cross_attention_kwargs,
1029
+ encoder_attention_mask=encoder_attention_mask,
1030
+ )
1031
+
1032
+ if is_controlnet:
1033
+ sample = sample + mid_block_additional_residual
1034
+
1035
+ # 5. up
1036
+ for i, upsample_block in enumerate(self.up_blocks):
1037
+ is_final_block = i == len(self.up_blocks) - 1
1038
+
1039
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1040
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1041
+
1042
+ # if we have not reached the final block and need to forward the
1043
+ # upsample size, we do it here
1044
+ if not is_final_block and forward_upsample_size:
1045
+ upsample_size = down_block_res_samples[-1].shape[2:]
1046
+
1047
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1048
+ sample = upsample_block(
1049
+ hidden_states=sample,
1050
+ temb=emb,
1051
+ res_hidden_states_tuple=res_samples,
1052
+ encoder_hidden_states=encoder_hidden_states,
1053
+ cross_attention_kwargs=cross_attention_kwargs,
1054
+ upsample_size=upsample_size,
1055
+ attention_mask=attention_mask,
1056
+ encoder_attention_mask=encoder_attention_mask,
1057
+ )
1058
+ else:
1059
+ sample = upsample_block(
1060
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1061
+ )
1062
+
1063
+ # 6. post-process
1064
+ if self.conv_norm_out:
1065
+ sample = self.conv_norm_out(sample)
1066
+ sample = self.conv_act(sample)
1067
+ sample = self.conv_out(sample)
1068
+
1069
+ if not return_dict:
1070
+ return (sample,)
1071
+
1072
+ return UNetMV2DConditionOutput(sample=sample)
1073
+
1074
+ @classmethod
1075
+ def from_pretrained_2d(
1076
+ cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
1077
+ camera_embedding_type: str, num_views: int, sample_size: int,
1078
+ zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False,
1079
+ projection_class_embeddings_input_dim: int=6, joint_attention: bool = False,
1080
+ joint_attention_twice: bool = False, multiview_attention: bool = True,
1081
+ cross_domain_attention: bool = False,
1082
+ in_channels: int = 8, out_channels: int = 4, local_crossattn=False,
1083
+ **kwargs
1084
+ ):
1085
+ r"""
1086
+ Instantiate a pretrained PyTorch model from a pretrained model configuration.
1087
+
1088
+ The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
1089
+ train the model, set it back in training mode with `model.train()`.
1090
+
1091
+ Parameters:
1092
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
1093
+ Can be either:
1094
+
1095
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
1096
+ the Hub.
1097
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
1098
+ with [`~ModelMixin.save_pretrained`].
1099
+
1100
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
1101
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
1102
+ is not used.
1103
+ torch_dtype (`str` or `torch.dtype`, *optional*):
1104
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
1105
+ dtype is automatically derived from the model's weights.
1106
+ force_download (`bool`, *optional*, defaults to `False`):
1107
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1108
+ cached versions if they exist.
1109
+ resume_download (`bool`, *optional*, defaults to `False`):
1110
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
1111
+ incompletely downloaded files are deleted.
1112
+ proxies (`Dict[str, str]`, *optional*):
1113
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
1114
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1115
+ output_loading_info (`bool`, *optional*, defaults to `False`):
1116
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
1117
+ local_files_only(`bool`, *optional*, defaults to `False`):
1118
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
1119
+ won't be downloaded from the Hub.
1120
+ use_auth_token (`str` or *bool*, *optional*):
1121
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
1122
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
1123
+ revision (`str`, *optional*, defaults to `"main"`):
1124
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
1125
+ allowed by Git.
1126
+ from_flax (`bool`, *optional*, defaults to `False`):
1127
+ Load the model weights from a Flax checkpoint save file.
1128
+ subfolder (`str`, *optional*, defaults to `""`):
1129
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
1130
+ mirror (`str`, *optional*):
1131
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
1132
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
1133
+ information.
1134
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
1135
+ A map that specifies where each submodule should go. It doesn't need to be defined for each
1136
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
1137
+ same device.
1138
+
1139
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
1140
+ more information about each option see [designing a device
1141
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
1142
+ max_memory (`Dict`, *optional*):
1143
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
1144
+ each GPU and the available CPU RAM if unset.
1145
+ offload_folder (`str` or `os.PathLike`, *optional*):
1146
+ The path to offload weights if `device_map` contains the value `"disk"`.
1147
+ offload_state_dict (`bool`, *optional*):
1148
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
1149
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
1150
+ when there is some disk offload.
1151
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
1152
+ Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
1153
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
1154
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
1155
+ argument to `True` will raise an error.
1156
+ variant (`str`, *optional*):
1157
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
1158
+ loading `from_flax`.
1159
+ use_safetensors (`bool`, *optional*, defaults to `None`):
1160
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
1161
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
1162
+ weights. If set to `False`, `safetensors` weights are not loaded.
1163
+
1164
+ <Tip>
1165
+
1166
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
1167
+ `huggingface-cli login`. You can also activate the special
1168
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
1169
+ firewalled environment.
1170
+
1171
+ </Tip>
1172
+
1173
+ Example:
1174
+
1175
+ ```py
1176
+ from diffusers import UNet2DConditionModel
1177
+
1178
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
1179
+ ```
1180
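+ 
+ For this multi-view variant, a minimal call sketch (the class name is assumed to be
+ `UNetMV2DConditionModel`; the checkpoint id matches the example above, while the view count and
+ sample size are illustrative placeholders, and only `camera_embedding_type="e_de_da_sincos"`
+ is currently implemented):
+ 
+ ```py
+ unet = UNetMV2DConditionModel.from_pretrained_2d(
+     "runwayml/stable-diffusion-v1-5",
+     subfolder="unet",
+     camera_embedding_type="e_de_da_sincos",
+     num_views=4,
+     sample_size=32,
+     in_channels=8,
+     out_channels=4,
+ )
+ ```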
+
1181
+ If you get the error message below, you need to finetune the weights for your downstream task:
1182
+
1183
+ ```bash
1184
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
1185
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
1186
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
1187
+ ```
1188
+ """
1189
+ cache_dir = kwargs.pop("cache_dir", HF_HUB_CACHE)
1190
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
1191
+ force_download = kwargs.pop("force_download", False)
1192
+ from_flax = kwargs.pop("from_flax", False)
1193
+ resume_download = kwargs.pop("resume_download", False)
1194
+ proxies = kwargs.pop("proxies", None)
1195
+ output_loading_info = kwargs.pop("output_loading_info", False)
1196
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
1197
+ use_auth_token = kwargs.pop("use_auth_token", None)
1198
+ revision = kwargs.pop("revision", None)
1199
+ torch_dtype = kwargs.pop("torch_dtype", None)
1200
+ subfolder = kwargs.pop("subfolder", None)
1201
+ device_map = kwargs.pop("device_map", None)
1202
+ max_memory = kwargs.pop("max_memory", None)
1203
+ offload_folder = kwargs.pop("offload_folder", None)
1204
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
1205
+ variant = kwargs.pop("variant", None)
1206
+ use_safetensors = kwargs.pop("use_safetensors", None)
1207
+
1208
+ # if use_safetensors and not is_safetensors_available():
1209
+ # raise ValueError(
1210
+ # "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
1211
+ # )
1212
+
1213
+ allow_pickle = False
1214
+ if use_safetensors is None:
1215
+ # use_safetensors = is_safetensors_available()
1216
+ use_safetensors = False
1217
+ allow_pickle = True
1218
+
1219
+ if device_map is not None and not is_accelerate_available():
1220
+ raise NotImplementedError(
1221
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
1222
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
1223
+ )
1224
+
1225
+ # Check if we can handle device_map and dispatching the weights
1226
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
1227
+ raise NotImplementedError(
1228
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
1229
+ " `device_map=None`."
1230
+ )
1231
+
1232
+ # Load config if we don't provide a configuration
1233
+ config_path = pretrained_model_name_or_path
1234
+
1235
+ user_agent = {
1236
+ "diffusers": __version__,
1237
+ "file_type": "model",
1238
+ "framework": "pytorch",
1239
+ }
1240
+
1241
+ # load config
1242
+ config, unused_kwargs, commit_hash = cls.load_config(
1243
+ config_path,
1244
+ cache_dir=cache_dir,
1245
+ return_unused_kwargs=True,
1246
+ return_commit_hash=True,
1247
+ force_download=force_download,
1248
+ resume_download=resume_download,
1249
+ proxies=proxies,
1250
+ local_files_only=local_files_only,
1251
+ use_auth_token=use_auth_token,
1252
+ revision=revision,
1253
+ subfolder=subfolder,
1254
+ device_map=device_map,
1255
+ max_memory=max_memory,
1256
+ offload_folder=offload_folder,
1257
+ offload_state_dict=offload_state_dict,
1258
+ user_agent=user_agent,
1259
+ **kwargs,
1260
+ )
1261
+
1262
+ # modify config
1263
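+ # (swap the standard 2D blocks for their MV2D counterparts and record the multi-view options before instantiating the model)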
+ config["_class_name"] = cls.__name__
1264
+ config['in_channels'] = in_channels
1265
+ config['out_channels'] = out_channels
1266
+ config['sample_size'] = sample_size # training resolution
1267
+ config['num_views'] = num_views
1268
+ config['joint_attention'] = joint_attention
1269
+ config['joint_attention_twice'] = joint_attention_twice
1270
+ config['multiview_attention'] = multiview_attention
1271
+ config['cross_domain_attention'] = cross_domain_attention
1272
+ config["down_block_types"] = [
1273
+ "CrossAttnDownBlockMV2D",
1274
+ "CrossAttnDownBlockMV2D",
1275
+ "CrossAttnDownBlockMV2D",
1276
+ "DownBlock2D"
1277
+ ]
1278
+ config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn"
1279
+ config["up_block_types"] = [
1280
+ "UpBlock2D",
1281
+ "CrossAttnUpBlockMV2D",
1282
+ "CrossAttnUpBlockMV2D",
1283
+ "CrossAttnUpBlockMV2D"
1284
+ ]
1285
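+ # Reuse the class-embedding pathway as a camera-pose embedding: the pose code (projection_class_embeddings_input_dim values, 6 by default, presumably a sin/cos encoding of elevation/azimuth for 'e_de_da_sincos') is projected to the time-embedding dimension.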
+ config['class_embed_type'] = 'projection'
1286
+ if camera_embedding_type == 'e_de_da_sincos':
1287
+ config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim # default 6
1288
+ else:
1289
+ raise NotImplementedError
1290
+
1291
+ # load model
1292
+ model_file = None
1293
+ if from_flax:
1294
+ raise NotImplementedError
1295
+ else:
1296
+ if use_safetensors:
1297
+ try:
1298
+ model_file = _get_model_file(
1299
+ pretrained_model_name_or_path,
1300
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
1301
+ cache_dir=cache_dir,
1302
+ force_download=force_download,
1303
+ resume_download=resume_download,
1304
+ proxies=proxies,
1305
+ local_files_only=local_files_only,
1306
+ use_auth_token=use_auth_token,
1307
+ revision=revision,
1308
+ subfolder=subfolder,
1309
+ user_agent=user_agent,
1310
+ commit_hash=commit_hash,
1311
+ )
1312
+ except IOError as e:
1313
+ if not allow_pickle:
1314
+ raise e
1315
+ pass
1316
+ if model_file is None:
1317
+ model_file = _get_model_file(
1318
+ pretrained_model_name_or_path,
1319
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
1320
+ cache_dir=cache_dir,
1321
+ force_download=force_download,
1322
+ resume_download=resume_download,
1323
+ proxies=proxies,
1324
+ local_files_only=local_files_only,
1325
+ use_auth_token=use_auth_token,
1326
+ revision=revision,
1327
+ subfolder=subfolder,
1328
+ user_agent=user_agent,
1329
+ commit_hash=commit_hash,
1330
+ )
1331
+
1332
+ model = cls.from_config(config, **unused_kwargs)
1333
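+ # local_crossattn: wrap every attention processor in ReferenceOnlyAttnProc. Self-attention ("attn1") layers get a multi-view processor (xformers-backed when available) with the wrapper enabled; all other layers keep a plain AttnProcessor with the wrapper disabled.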
+ if local_crossattn:
1334
+ unet_lora_attn_procs = dict()
1335
+ for name, _ in model.attn_processors.items():
1336
+ if not name.endswith("attn1.processor"):
1337
+ default_attn_proc = AttnProcessor()
1338
+ elif is_xformers_available():
1339
+ default_attn_proc = XFormersMVAttnProcessor()
1340
+ else:
1341
+ default_attn_proc = MVAttnProcessor()
1342
+ unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(
1343
+ default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
1344
+ )
1345
+ model.set_attn_processor(unet_lora_attn_procs)
1346
+ state_dict = load_state_dict(model_file, variant=variant)
1347
+ model._convert_deprecated_attention_blocks(state_dict)
1348
+
1349
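+ # If the checkpoint's conv_in/conv_out shapes do not match the new in/out channel counts (e.g. in_channels=8), copy the pretrained weights into the first 4 channels below and optionally zero-init the extra input channels so the model starts out equivalent to the base UNet.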
+ conv_in_weight = state_dict['conv_in.weight']
1350
+ conv_out_weight = state_dict['conv_out.weight']
1351
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d(
1352
+ model,
1353
+ state_dict,
1354
+ model_file,
1355
+ pretrained_model_name_or_path,
1356
+ ignore_mismatched_sizes=True,
1357
+ )
1358
+ if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]):
1359
+ # initialize from the original SD structure
1360
+ model.conv_in.weight.data[:,:4] = conv_in_weight
1361
+
1362
+ # whether to place all zero to new layers?
1363
+ if zero_init_conv_in:
1364
+ model.conv_in.weight.data[:,4:] = 0.
1365
+
1366
+ if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]):
1367
+ # initialize from the original SD structure
1368
+ model.conv_out.weight.data[:,:4] = conv_out_weight
1369
+ if out_channels == 8: # copy for the last 4 channels
1370
+ model.conv_out.weight.data[:, 4:] = conv_out_weight
1371
+
1372
+ if zero_init_camera_projection:
1373
+ for p in model.class_embedding.parameters():
1374
+ torch.nn.init.zeros_(p)
1375
+
1376
+ loading_info = {
1377
+ "missing_keys": missing_keys,
1378
+ "unexpected_keys": unexpected_keys,
1379
+ "mismatched_keys": mismatched_keys,
1380
+ "error_msgs": error_msgs,
1381
+ }
1382
+
1383
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
1384
+ raise ValueError(
1385
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
1386
+ )
1387
+ elif torch_dtype is not None:
1388
+ model = model.to(torch_dtype)
1389
+
1390
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
1391
+
1392
+ # Set model in evaluation mode to deactivate DropOut modules by default
1393
+ model.eval()
1394
+ if output_loading_info:
1395
+ return model, loading_info
1396
+
1397
+ return model
1398
+
1399
+ @classmethod
1400
+ def _load_pretrained_model_2d(
1401
+ cls,
1402
+ model,
1403
+ state_dict,
1404
+ resolved_archive_file,
1405
+ pretrained_model_name_or_path,
1406
+ ignore_mismatched_sizes=False,
1407
+ ):
1408
+ # Retrieve missing & unexpected_keys
1409
+ model_state_dict = model.state_dict()
1410
+ loaded_keys = list(state_dict.keys())
1411
+
1412
+ expected_keys = list(model_state_dict.keys())
1413
+
1414
+ original_loaded_keys = loaded_keys
1415
+
1416
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
1417
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
1418
+
1419
+ # Make sure we are able to load base models as well as derived models (with heads)
1420
+ model_to_load = model
1421
+
1422
+ def _find_mismatched_keys(
1423
+ state_dict,
1424
+ model_state_dict,
1425
+ loaded_keys,
1426
+ ignore_mismatched_sizes,
1427
+ ):
1428
+ mismatched_keys = []
1429
+ if ignore_mismatched_sizes:
1430
+ for checkpoint_key in loaded_keys:
1431
+ model_key = checkpoint_key
1432
+
1433
+ if (
1434
+ model_key in model_state_dict
1435
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
1436
+ ):
1437
+ mismatched_keys.append(
1438
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
1439
+ )
1440
+ del state_dict[checkpoint_key]
1441
+ return mismatched_keys
1442
+
1443
+ if state_dict is not None:
1444
+ # Whole checkpoint
1445
+ mismatched_keys = _find_mismatched_keys(
1446
+ state_dict,
1447
+ model_state_dict,
1448
+ original_loaded_keys,
1449
+ ignore_mismatched_sizes,
1450
+ )
1451
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
1452
+
1453
+ if len(error_msgs) > 0:
1454
+ error_msg = "\n\t".join(error_msgs)
1455
+ if "size mismatch" in error_msg:
1456
+ error_msg += (
1457
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
1458
+ )
1459
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
1460
+
1461
+ if len(unexpected_keys) > 0:
1462
+ logger.warning(
1463
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
1464
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
1465
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
1466
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
1467
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
1468
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
1469
+ " identical (initializing a BertForSequenceClassification model from a"
1470
+ " BertForSequenceClassification model)."
1471
+ )
1472
+ else:
1473
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
1474
+ if len(missing_keys) > 0:
1475
+ logger.warning(
1476
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1477
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
1478
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
1479
+ )
1480
+ elif len(mismatched_keys) == 0:
1481
+ logger.info(
1482
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
1483
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
1484
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
1485
+ " without further training."
1486
+ )
1487
+ if len(mismatched_keys) > 0:
1488
+ mismatched_warning = "\n".join(
1489
+ [
1490
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
1491
+ for key, shape1, shape2 in mismatched_keys
1492
+ ]
1493
+ )
1494
+ logger.warning(
1495
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1496
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
1497
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
1498
+ " able to use it for predictions and inference."
1499
+ )
1500
+
1501
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
1502
+
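The loader above widens `conv_in`/`conv_out` while reusing the pretrained 4-channel weights and, when `zero_init_conv_in` is set, zero-initializes the extra input channels, so the widened UNet initially behaves like the 2D base model on the original latents. A minimal standalone sketch of that initialization trick (layer sizes follow the SD defaults; nothing here is loaded from a real checkpoint):

```python
import torch
import torch.nn as nn

sd_conv_in = nn.Conv2d(4, 320, kernel_size=3, padding=1)   # original SD input conv
mv_conv_in = nn.Conv2d(8, 320, kernel_size=3, padding=1)   # widened multi-view input conv

with torch.no_grad():
    mv_conv_in.weight[:, :4] = sd_conv_in.weight           # reuse pretrained weights
    mv_conv_in.weight[:, 4:] = 0.0                         # zero_init_conv_in=True
    mv_conv_in.bias.copy_(sd_conv_in.bias)

x = torch.randn(1, 4, 64, 64)
extra = torch.randn(1, 4, 64, 64)
# The extra channels are inert at initialization, so both layers agree:
print(torch.allclose(sd_conv_in(x), mv_conv_in(torch.cat([x, extra], dim=1)), atol=1e-6))
```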
canonicalize/models/unet_mv2d_ref.py ADDED
@@ -0,0 +1,1543 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+ import os
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+ from einops import rearrange
22
+
23
+
24
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+ from diffusers.loaders import UNet2DConditionLoadersMixin
26
+ from diffusers.utils import BaseOutput, logging
27
+ from diffusers.models.activations import get_activation
28
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
29
+ from diffusers.models.embeddings import (
30
+ GaussianFourierProjection,
31
+ ImageHintTimeEmbedding,
32
+ ImageProjection,
33
+ ImageTimeEmbedding,
34
+ TextImageProjection,
35
+ TextImageTimeEmbedding,
36
+ TextTimeEmbedding,
37
+ TimestepEmbedding,
38
+ Timesteps,
39
+ )
40
+ from diffusers.models.lora import LoRALinearLayer
41
+
42
+ from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model
43
+ from diffusers.models.unet_2d_blocks import (
44
+ CrossAttnDownBlock2D,
45
+ CrossAttnUpBlock2D,
46
+ DownBlock2D,
47
+ UNetMidBlock2DCrossAttn,
48
+ UNetMidBlock2DSimpleCrossAttn,
49
+ UpBlock2D,
50
+ )
51
+ from diffusers.utils import (
52
+ CONFIG_NAME,
53
+ FLAX_WEIGHTS_NAME,
54
+ SAFETENSORS_WEIGHTS_NAME,
55
+ WEIGHTS_NAME,
56
+ _add_variant,
57
+ _get_model_file,
58
+ deprecate,
59
+ is_accelerate_available,
60
+ is_torch_version,
61
+ logging,
62
+ )
63
+ from diffusers import __version__
64
+ from canonicalize.models.unet_mv2d_blocks import (
65
+ CrossAttnDownBlockMV2D,
66
+ CrossAttnUpBlockMV2D,
67
+ UNetMidBlockMV2DCrossAttn,
68
+ get_down_block,
69
+ get_up_block,
70
+ )
71
+ from diffusers.models.attention_processor import Attention, AttnProcessor
72
+ from diffusers.utils.import_utils import is_xformers_available
73
+ from canonicalize.models.transformer_mv2d import XFormersMVAttnProcessor, MVAttnProcessor
74
+ from canonicalize.models.refunet import ReferenceOnlyAttnProc
75
+
76
+ from huggingface_hub.constants import HF_HUB_CACHE
77
+ from diffusers.utils.hub_utils import HF_HUB_OFFLINE
78
+
79
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
80
+
81
+
82
+ @dataclass
83
+ class UNetMV2DRefOutput(BaseOutput):
84
+ """
85
+ The output of [`UNet2DConditionModel`].
86
+
87
+ Args:
88
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
89
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
90
+ """
91
+
92
+ sample: torch.FloatTensor = None
93
+
94
+ class Identity(torch.nn.Module):
95
+ r"""A placeholder identity operator that is argument-insensitive.
96
+
97
+ Args:
98
+ args: any argument (unused)
99
+ kwargs: any keyword argument (unused)
100
+
101
+ Shape:
102
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
103
+ - Output: :math:`(*)`, same shape as the input.
104
+
105
+ Examples::
106
+
107
+ >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
108
+ >>> input = torch.randn(128, 20)
109
+ >>> output = m(input)
110
+ >>> print(output.size())
111
+ torch.Size([128, 20])
112
+
113
+ """
114
+ def __init__(self, scale=None, *args, **kwargs) -> None:
115
+ super(Identity, self).__init__()
116
+
117
+ def forward(self, input, *args, **kwargs):
118
+ return input
119
+
120
+
121
+
122
+ class _LoRACompatibleLinear(nn.Module):
123
+ """
124
+ A Linear layer that can be used with LoRA.
125
+ """
126
+
127
+ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
128
+ super().__init__(*args, **kwargs)
129
+ self.lora_layer = lora_layer
130
+
131
+ def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
132
+ self.lora_layer = lora_layer
133
+
134
+ def _fuse_lora(self):
135
+ pass
136
+
137
+ def _unfuse_lora(self):
138
+ pass
139
+
140
+ def forward(self, hidden_states, scale=None, lora_scale: int = 1):
141
+ return hidden_states
142
+
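Both `Identity` and `_LoRACompatibleLinear` are argument-insensitive pass-throughs; they are used later in `__init__` to replace the projections of `up_blocks[3].attentions[2]` with no-ops. A quick self-contained check of that behaviour, re-declaring simplified versions of the two stubs rather than importing this module:

```python
import torch
import torch.nn as nn

class Identity(nn.Module):                       # mirrors the placeholder above
    def __init__(self, scale=None, *args, **kwargs):
        super().__init__()
    def forward(self, x, *args, **kwargs):
        return x

class _LoRACompatibleLinear(nn.Module):          # simplified: only the forward matters
    def forward(self, hidden_states, scale=None, lora_scale: int = 1):
        return hidden_states

x = torch.randn(2, 77, 320)
assert torch.equal(Identity(scale=0.5)(x), x)    # extra args are ignored
assert torch.equal(_LoRACompatibleLinear()(x), x)
```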
143
+ class UNetMV2DRefModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
144
+ r"""
145
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
146
+ shaped output.
147
+
148
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
149
+ for all models (such as downloading or saving).
150
+
151
+ Parameters:
152
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
153
+ Height and width of input/output sample.
154
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
155
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
156
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
157
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
158
+ Whether to flip the sin to cos in the time embedding.
159
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
160
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
161
+ The tuple of downsample blocks to use.
162
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
163
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
164
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
165
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
166
+ The tuple of upsample blocks to use.
167
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
168
+ Whether to include self-attention in the basic transformer blocks, see
169
+ [`~models.attention.BasicTransformerBlock`].
170
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
171
+ The tuple of output channels for each block.
172
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
173
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
174
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
175
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
176
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
177
+ If `None`, normalization and activation layers are skipped in post-processing.
178
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
179
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
180
+ The dimension of the cross attention features.
181
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
182
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
183
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
184
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
185
+ encoder_hid_dim (`int`, *optional*, defaults to None):
186
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
187
+ dimension to `cross_attention_dim`.
188
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
189
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
190
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
191
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
192
+ num_attention_heads (`int`, *optional*):
193
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
194
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
195
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
196
+ class_embed_type (`str`, *optional*, defaults to `None`):
197
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
198
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
199
+ addition_embed_type (`str`, *optional*, defaults to `None`):
200
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
201
+ "text". "text" will use the `TextTimeEmbedding` layer.
202
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
203
+ Dimension for the timestep embeddings.
204
+ num_class_embeds (`int`, *optional*, defaults to `None`):
205
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
206
+ class conditioning with `class_embed_type` equal to `None`.
207
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
208
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
209
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
210
+ An optional override for the dimension of the projected time embedding.
211
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
212
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
213
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
214
+ timestep_post_act (`str`, *optional*, defaults to `None`):
215
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
216
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
217
+ The dimension of `cond_proj` layer in the timestep embedding.
218
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
219
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
220
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
221
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
222
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
223
+ embeddings with the class embeddings.
224
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
225
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
226
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
227
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
228
+ otherwise.
229
+ """
230
+
231
+ _supports_gradient_checkpointing = True
232
+
233
+ @register_to_config
234
+ def __init__(
235
+ self,
236
+ sample_size: Optional[int] = None,
237
+ in_channels: int = 4,
238
+ out_channels: int = 4,
239
+ center_input_sample: bool = False,
240
+ flip_sin_to_cos: bool = True,
241
+ freq_shift: int = 0,
242
+ down_block_types: Tuple[str] = (
243
+ "CrossAttnDownBlockMV2D",
244
+ "CrossAttnDownBlockMV2D",
245
+ "CrossAttnDownBlockMV2D",
246
+ "DownBlock2D",
247
+ ),
248
+ mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn",
249
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"),
250
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
251
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
252
+ layers_per_block: Union[int, Tuple[int]] = 2,
253
+ downsample_padding: int = 1,
254
+ mid_block_scale_factor: float = 1,
255
+ act_fn: str = "silu",
256
+ norm_num_groups: Optional[int] = 32,
257
+ norm_eps: float = 1e-5,
258
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
259
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
260
+ encoder_hid_dim: Optional[int] = None,
261
+ encoder_hid_dim_type: Optional[str] = None,
262
+ attention_head_dim: Union[int, Tuple[int]] = 8,
263
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
264
+ dual_cross_attention: bool = False,
265
+ use_linear_projection: bool = False,
266
+ class_embed_type: Optional[str] = None,
267
+ addition_embed_type: Optional[str] = None,
268
+ addition_time_embed_dim: Optional[int] = None,
269
+ num_class_embeds: Optional[int] = None,
270
+ upcast_attention: bool = False,
271
+ resnet_time_scale_shift: str = "default",
272
+ resnet_skip_time_act: bool = False,
273
+ resnet_out_scale_factor: int = 1.0,
274
+ time_embedding_type: str = "positional",
275
+ time_embedding_dim: Optional[int] = None,
276
+ time_embedding_act_fn: Optional[str] = None,
277
+ timestep_post_act: Optional[str] = None,
278
+ time_cond_proj_dim: Optional[int] = None,
279
+ conv_in_kernel: int = 3,
280
+ conv_out_kernel: int = 3,
281
+ projection_class_embeddings_input_dim: Optional[int] = None,
282
+ class_embeddings_concat: bool = False,
283
+ mid_block_only_cross_attention: Optional[bool] = None,
284
+ cross_attention_norm: Optional[str] = None,
285
+ addition_embed_type_num_heads=64,
286
+ num_views: int = 1,
287
+ joint_attention: bool = False,
288
+ joint_attention_twice: bool = False,
289
+ multiview_attention: bool = True,
290
+ cross_domain_attention: bool = False,
291
+ camera_input_dim: int = 12,
292
+ camera_hidden_dim: int = 320,
293
+ camera_output_dim: int = 1280,
294
+
295
+ ):
296
+ super().__init__()
297
+
298
+ self.sample_size = sample_size
299
+
300
+ if num_attention_heads is not None:
301
+ raise ValueError(
302
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
303
+ )
304
+
305
+ # If `num_attention_heads` is not defined (which is the case for most models)
306
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
307
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
308
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
309
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
310
+ # which is why we correct for the naming here.
311
+ num_attention_heads = num_attention_heads or attention_head_dim
312
+
313
+ # Check inputs
314
+ if len(down_block_types) != len(up_block_types):
315
+ raise ValueError(
316
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
317
+ )
318
+
319
+ if len(block_out_channels) != len(down_block_types):
320
+ raise ValueError(
321
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
322
+ )
323
+
324
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
325
+ raise ValueError(
326
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
327
+ )
328
+
329
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
330
+ raise ValueError(
331
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
332
+ )
333
+
334
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
335
+ raise ValueError(
336
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
337
+ )
338
+
339
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
340
+ raise ValueError(
341
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
342
+ )
343
+
344
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
345
+ raise ValueError(
346
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
347
+ )
348
+
349
+ # input
350
+ conv_in_padding = (conv_in_kernel - 1) // 2
351
+ self.conv_in = nn.Conv2d(
352
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
353
+ )
354
+
355
+ # time
356
+ if time_embedding_type == "fourier":
357
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
358
+ if time_embed_dim % 2 != 0:
359
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
360
+ self.time_proj = GaussianFourierProjection(
361
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
362
+ )
363
+ timestep_input_dim = time_embed_dim
364
+ elif time_embedding_type == "positional":
365
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
366
+
367
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
368
+ timestep_input_dim = block_out_channels[0]
369
+ else:
370
+ raise ValueError(
371
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
372
+ )
373
+
374
+ self.time_embedding = TimestepEmbedding(
375
+ timestep_input_dim,
376
+ time_embed_dim,
377
+ act_fn=act_fn,
378
+ post_act_fn=timestep_post_act,
379
+ cond_proj_dim=time_cond_proj_dim,
380
+ )
381
+
382
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
383
+ encoder_hid_dim_type = "text_proj"
384
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
385
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
386
+
387
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
388
+ raise ValueError(
389
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
390
+ )
391
+
392
+ if encoder_hid_dim_type == "text_proj":
393
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
394
+ elif encoder_hid_dim_type == "text_image_proj":
395
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
396
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
397
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
398
+ self.encoder_hid_proj = TextImageProjection(
399
+ text_embed_dim=encoder_hid_dim,
400
+ image_embed_dim=cross_attention_dim,
401
+ cross_attention_dim=cross_attention_dim,
402
+ )
403
+ elif encoder_hid_dim_type == "image_proj":
404
+ # Kandinsky 2.2
405
+ self.encoder_hid_proj = ImageProjection(
406
+ image_embed_dim=encoder_hid_dim,
407
+ cross_attention_dim=cross_attention_dim,
408
+ )
409
+ elif encoder_hid_dim_type is not None:
410
+ raise ValueError(
411
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
412
+ )
413
+ else:
414
+ self.encoder_hid_proj = None
415
+
416
+ # class embedding
417
+ if class_embed_type is None and num_class_embeds is not None:
418
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
419
+ elif class_embed_type == "timestep":
420
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
421
+ elif class_embed_type == "identity":
422
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
423
+ elif class_embed_type == "projection":
424
+ if projection_class_embeddings_input_dim is None:
425
+ raise ValueError(
426
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
427
+ )
428
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
429
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
430
+ # 2. it projects from an arbitrary input dimension.
431
+ #
432
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
433
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
434
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
435
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
436
+ elif class_embed_type == "simple_projection":
437
+ if projection_class_embeddings_input_dim is None:
438
+ raise ValueError(
439
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
440
+ )
441
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
442
+ else:
443
+ self.class_embedding = None
444
+
445
+ if addition_embed_type == "text":
446
+ if encoder_hid_dim is not None:
447
+ text_time_embedding_from_dim = encoder_hid_dim
448
+ else:
449
+ text_time_embedding_from_dim = cross_attention_dim
450
+
451
+ self.add_embedding = TextTimeEmbedding(
452
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
453
+ )
454
+ elif addition_embed_type == "text_image":
455
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
456
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
457
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
458
+ self.add_embedding = TextImageTimeEmbedding(
459
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
460
+ )
461
+ elif addition_embed_type == "text_time":
462
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
463
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
464
+ elif addition_embed_type == "image":
465
+ # Kandinsky 2.2
466
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
467
+ elif addition_embed_type == "image_hint":
468
+ # Kandinsky 2.2 ControlNet
469
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
470
+ elif addition_embed_type is not None:
471
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
472
+
473
+ if time_embedding_act_fn is None:
474
+ self.time_embed_act = None
475
+ else:
476
+ self.time_embed_act = get_activation(time_embedding_act_fn)
477
+
478
+ self.camera_embedding = nn.Sequential(
479
+ nn.Linear(camera_input_dim, time_embed_dim),
480
+ nn.SiLU(),
481
+ nn.Linear(time_embed_dim, time_embed_dim),
482
+ )
483
+
484
+ self.down_blocks = nn.ModuleList([])
485
+ self.up_blocks = nn.ModuleList([])
486
+
487
+ if isinstance(only_cross_attention, bool):
488
+ if mid_block_only_cross_attention is None:
489
+ mid_block_only_cross_attention = only_cross_attention
490
+
491
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
492
+
493
+ if mid_block_only_cross_attention is None:
494
+ mid_block_only_cross_attention = False
495
+
496
+ if isinstance(num_attention_heads, int):
497
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
498
+
499
+ if isinstance(attention_head_dim, int):
500
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
501
+
502
+ if isinstance(cross_attention_dim, int):
503
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
504
+
505
+ if isinstance(layers_per_block, int):
506
+ layers_per_block = [layers_per_block] * len(down_block_types)
507
+
508
+ if isinstance(transformer_layers_per_block, int):
509
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
510
+
511
+ if class_embeddings_concat:
512
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
513
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
514
+ # regular time embeddings
515
+ blocks_time_embed_dim = time_embed_dim * 2
516
+ else:
517
+ blocks_time_embed_dim = time_embed_dim
518
+
519
+ # down
520
+ output_channel = block_out_channels[0]
521
+ for i, down_block_type in enumerate(down_block_types):
522
+ input_channel = output_channel
523
+ output_channel = block_out_channels[i]
524
+ is_final_block = i == len(block_out_channels) - 1
525
+
526
+ down_block = get_down_block(
527
+ down_block_type,
528
+ num_layers=layers_per_block[i],
529
+ transformer_layers_per_block=transformer_layers_per_block[i],
530
+ in_channels=input_channel,
531
+ out_channels=output_channel,
532
+ temb_channels=blocks_time_embed_dim,
533
+ add_downsample=not is_final_block,
534
+ resnet_eps=norm_eps,
535
+ resnet_act_fn=act_fn,
536
+ resnet_groups=norm_num_groups,
537
+ cross_attention_dim=cross_attention_dim[i],
538
+ num_attention_heads=num_attention_heads[i],
539
+ downsample_padding=downsample_padding,
540
+ dual_cross_attention=dual_cross_attention,
541
+ use_linear_projection=use_linear_projection,
542
+ only_cross_attention=only_cross_attention[i],
543
+ upcast_attention=upcast_attention,
544
+ resnet_time_scale_shift=resnet_time_scale_shift,
545
+ resnet_skip_time_act=resnet_skip_time_act,
546
+ resnet_out_scale_factor=resnet_out_scale_factor,
547
+ cross_attention_norm=cross_attention_norm,
548
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
549
+ num_views=num_views,
550
+ joint_attention=joint_attention,
551
+ joint_attention_twice=joint_attention_twice,
552
+ multiview_attention=multiview_attention,
553
+ cross_domain_attention=cross_domain_attention
554
+ )
555
+ self.down_blocks.append(down_block)
556
+
557
+ # mid
558
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
559
+ self.mid_block = UNetMidBlock2DCrossAttn(
560
+ transformer_layers_per_block=transformer_layers_per_block[-1],
561
+ in_channels=block_out_channels[-1],
562
+ temb_channels=blocks_time_embed_dim,
563
+ resnet_eps=norm_eps,
564
+ resnet_act_fn=act_fn,
565
+ output_scale_factor=mid_block_scale_factor,
566
+ resnet_time_scale_shift=resnet_time_scale_shift,
567
+ cross_attention_dim=cross_attention_dim[-1],
568
+ num_attention_heads=num_attention_heads[-1],
569
+ resnet_groups=norm_num_groups,
570
+ dual_cross_attention=dual_cross_attention,
571
+ use_linear_projection=use_linear_projection,
572
+ upcast_attention=upcast_attention,
573
+ )
574
+ # custom MV2D attention block
575
+ elif mid_block_type == "UNetMidBlockMV2DCrossAttn":
576
+ self.mid_block = UNetMidBlockMV2DCrossAttn(
577
+ transformer_layers_per_block=transformer_layers_per_block[-1],
578
+ in_channels=block_out_channels[-1],
579
+ temb_channels=blocks_time_embed_dim,
580
+ resnet_eps=norm_eps,
581
+ resnet_act_fn=act_fn,
582
+ output_scale_factor=mid_block_scale_factor,
583
+ resnet_time_scale_shift=resnet_time_scale_shift,
584
+ cross_attention_dim=cross_attention_dim[-1],
585
+ num_attention_heads=num_attention_heads[-1],
586
+ resnet_groups=norm_num_groups,
587
+ dual_cross_attention=dual_cross_attention,
588
+ use_linear_projection=use_linear_projection,
589
+ upcast_attention=upcast_attention,
590
+ num_views=num_views,
591
+ joint_attention=joint_attention,
592
+ joint_attention_twice=joint_attention_twice,
593
+ multiview_attention=multiview_attention,
594
+ cross_domain_attention=cross_domain_attention
595
+ )
596
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
597
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
598
+ in_channels=block_out_channels[-1],
599
+ temb_channels=blocks_time_embed_dim,
600
+ resnet_eps=norm_eps,
601
+ resnet_act_fn=act_fn,
602
+ output_scale_factor=mid_block_scale_factor,
603
+ cross_attention_dim=cross_attention_dim[-1],
604
+ attention_head_dim=attention_head_dim[-1],
605
+ resnet_groups=norm_num_groups,
606
+ resnet_time_scale_shift=resnet_time_scale_shift,
607
+ skip_time_act=resnet_skip_time_act,
608
+ only_cross_attention=mid_block_only_cross_attention,
609
+ cross_attention_norm=cross_attention_norm,
610
+ )
611
+ elif mid_block_type is None:
612
+ self.mid_block = None
613
+ else:
614
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
615
+
616
+ # count how many layers upsample the images
617
+ self.num_upsamplers = 0
618
+
619
+ # up
620
+ reversed_block_out_channels = list(reversed(block_out_channels))
621
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
622
+ reversed_layers_per_block = list(reversed(layers_per_block))
623
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
624
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
625
+ only_cross_attention = list(reversed(only_cross_attention))
626
+
627
+ output_channel = reversed_block_out_channels[0]
628
+ for i, up_block_type in enumerate(up_block_types):
629
+ is_final_block = i == len(block_out_channels) - 1
630
+
631
+ prev_output_channel = output_channel
632
+ output_channel = reversed_block_out_channels[i]
633
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
634
+
635
+ # add upsample block for all BUT final layer
636
+ if not is_final_block:
637
+ add_upsample = True
638
+ self.num_upsamplers += 1
639
+ else:
640
+ add_upsample = False
641
+
642
+ up_block = get_up_block(
643
+ up_block_type,
644
+ num_layers=reversed_layers_per_block[i] + 1,
645
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
646
+ in_channels=input_channel,
647
+ out_channels=output_channel,
648
+ prev_output_channel=prev_output_channel,
649
+ temb_channels=blocks_time_embed_dim,
650
+ add_upsample=add_upsample,
651
+ resnet_eps=norm_eps,
652
+ resnet_act_fn=act_fn,
653
+ resnet_groups=norm_num_groups,
654
+ cross_attention_dim=reversed_cross_attention_dim[i],
655
+ num_attention_heads=reversed_num_attention_heads[i],
656
+ dual_cross_attention=dual_cross_attention,
657
+ use_linear_projection=use_linear_projection,
658
+ only_cross_attention=only_cross_attention[i],
659
+ upcast_attention=upcast_attention,
660
+ resnet_time_scale_shift=resnet_time_scale_shift,
661
+ resnet_skip_time_act=resnet_skip_time_act,
662
+ resnet_out_scale_factor=resnet_out_scale_factor,
663
+ cross_attention_norm=cross_attention_norm,
664
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
665
+ num_views=num_views,
666
+ joint_attention=joint_attention,
667
+ joint_attention_twice=joint_attention_twice,
668
+ multiview_attention=multiview_attention,
669
+ cross_domain_attention=cross_domain_attention
670
+ )
671
+ self.up_blocks.append(up_block)
672
+ prev_output_channel = output_channel
673
+
674
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_q = _LoRACompatibleLinear()
675
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_k = _LoRACompatibleLinear()
676
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_v = _LoRACompatibleLinear()
677
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_out = nn.ModuleList([Identity(), Identity()])
678
+ self.up_blocks[3].attentions[2].transformer_blocks[0].norm2 = Identity()
679
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn2 = None
680
+ self.up_blocks[3].attentions[2].transformer_blocks[0].norm3 = Identity()
681
+ self.up_blocks[3].attentions[2].transformer_blocks[0].ff = Identity()
682
+ self.up_blocks[3].attentions[2].proj_out = Identity()
683
+
684
+ @property
685
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
686
+ r"""
687
+ Returns:
688
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
689
+ indexed by its weight name.
690
+ """
691
+ # set recursively
692
+ processors = {}
693
+
694
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
695
+ if hasattr(module, "set_processor"):
696
+ processors[f"{name}.processor"] = module.processor
697
+
698
+ for sub_name, child in module.named_children():
699
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
700
+
701
+ return processors
702
+
703
+ for name, module in self.named_children():
704
+ fn_recursive_add_processors(name, module, processors)
705
+
706
+ return processors
707
+
708
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
709
+ r"""
710
+ Sets the attention processor to use to compute attention.
711
+
712
+ Parameters:
713
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
714
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
715
+ for **all** `Attention` layers.
716
+
717
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
718
+ processor. This is strongly recommended when setting trainable attention processors.
719
+
720
+ """
721
+ count = len(self.attn_processors.keys())
722
+
723
+ if isinstance(processor, dict) and len(processor) != count:
724
+ raise ValueError(
725
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
726
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
727
+ )
728
+
729
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
730
+ if hasattr(module, "set_processor"):
731
+ if not isinstance(processor, dict):
732
+ module.set_processor(processor)
733
+ else:
734
+ module.set_processor(processor.pop(f"{name}.processor"))
735
+
736
+ for sub_name, child in module.named_children():
737
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
738
+
739
+ for name, module in self.named_children():
740
+ fn_recursive_attn_processor(name, module, processor)
741
+
742
+ def set_default_attn_processor(self):
743
+ """
744
+ Disables custom attention processors and sets the default attention implementation.
745
+ """
746
+ self.set_attn_processor(AttnProcessor())
747
+
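A hedged usage sketch of the processor API above; `model` stands for an already-instantiated `UNetMV2DRefModel` (any diffusers model exposing `attn_processors` works the same way):

```python
from diffusers.models.attention_processor import AttnProcessor

procs = model.attn_processors                    # weight-name -> processor mapping
print(f"{len(procs)} attention layers")

# a dict must contain exactly one entry per attention layer ...
model.set_attn_processor({name: AttnProcessor() for name in procs})

# ... or reset everything to the default implementation in one call
model.set_default_attn_processor()
```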
748
+ def set_attention_slice(self, slice_size):
749
+ r"""
750
+ Enable sliced attention computation.
751
+
752
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
753
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
754
+
755
+ Args:
756
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
757
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
758
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
759
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
760
+ must be a multiple of `slice_size`.
761
+ """
762
+ sliceable_head_dims = []
763
+
764
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
765
+ if hasattr(module, "set_attention_slice"):
766
+ sliceable_head_dims.append(module.sliceable_head_dim)
767
+
768
+ for child in module.children():
769
+ fn_recursive_retrieve_sliceable_dims(child)
770
+
771
+ # retrieve number of attention layers
772
+ for module in self.children():
773
+ fn_recursive_retrieve_sliceable_dims(module)
774
+
775
+ num_sliceable_layers = len(sliceable_head_dims)
776
+
777
+ if slice_size == "auto":
778
+ # half the attention head size is usually a good trade-off between
779
+ # speed and memory
780
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
781
+ elif slice_size == "max":
782
+ # make smallest slice possible
783
+ slice_size = num_sliceable_layers * [1]
784
+
785
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
786
+
787
+ if len(slice_size) != len(sliceable_head_dims):
788
+ raise ValueError(
789
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
790
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
791
+ )
792
+
793
+ for i in range(len(slice_size)):
794
+ size = slice_size[i]
795
+ dim = sliceable_head_dims[i]
796
+ if size is not None and size > dim:
797
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
798
+
799
+ # Recursively walk through all the children.
800
+ # Any children which exposes the set_attention_slice method
801
+ # gets the message
802
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
803
+ if hasattr(module, "set_attention_slice"):
804
+ module.set_attention_slice(slice_size.pop())
805
+
806
+ for child in module.children():
807
+ fn_recursive_set_attention_slice(child, slice_size)
808
+
809
+ reversed_slice_size = list(reversed(slice_size))
810
+ for module in self.children():
811
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
812
+
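For reference, the three accepted forms of `slice_size` described in the docstring above (again assuming an instantiated `model`):

```python
model.set_attention_slice("auto")   # half of each sliceable head dim, i.e. two slices per layer
model.set_attention_slice("max")    # slice size 1 everywhere: lowest memory, slowest
model.set_attention_slice(4)        # attention_head_dim // 4 slices; head dim must be a multiple of 4
```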
813
+ def _set_gradient_checkpointing(self, module, value=False):
814
+ if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)):
815
+ module.gradient_checkpointing = value
816
+
817
+ def forward(
818
+ self,
819
+ sample: torch.FloatTensor,
820
+ timestep: Union[torch.Tensor, float, int],
821
+ encoder_hidden_states: torch.Tensor,
822
+ camera_matrixs: Optional[torch.Tensor] = None,
823
+ class_labels: Optional[torch.Tensor] = None,
824
+ timestep_cond: Optional[torch.Tensor] = None,
825
+ attention_mask: Optional[torch.Tensor] = None,
826
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
827
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
828
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
829
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
830
+ encoder_attention_mask: Optional[torch.Tensor] = None,
831
+ return_dict: bool = True,
832
+ ) -> Union[UNetMV2DRefOutput, Tuple]:
833
+ r"""
834
+ The [`UNet2DConditionModel`] forward method.
835
+
836
+ Args:
837
+ sample (`torch.FloatTensor`):
838
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
839
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
840
+ encoder_hidden_states (`torch.FloatTensor`):
841
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
842
+ encoder_attention_mask (`torch.Tensor`):
843
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
844
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
845
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
846
+ return_dict (`bool`, *optional*, defaults to `True`):
847
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
848
+ tuple.
849
+ cross_attention_kwargs (`dict`, *optional*):
850
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
851
+ added_cond_kwargs: (`dict`, *optional*):
852
+ A kwargs dictionary containing additional embeddings that, if specified, are added to the embeddings that
853
+ are passed along to the UNet blocks.
854
+
855
+ Returns:
856
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
857
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
858
+ a `tuple` is returned where the first element is the sample tensor.
859
+ """
860
+ # By default samples have to be at least a multiple of the overall upsampling factor.
861
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
862
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
863
+ # on the fly if necessary.
864
+ default_overall_up_factor = 2**self.num_upsamplers
865
+
866
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
867
+ forward_upsample_size = False
868
+ upsample_size = None
869
+
870
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
871
+ logger.info("Forward upsample size to force interpolation output size.")
872
+ forward_upsample_size = True
873
+
874
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
875
+ # expects mask of shape:
876
+ # [batch, key_tokens]
877
+ # adds singleton query_tokens dimension:
878
+ # [batch, 1, key_tokens]
879
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
880
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
881
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
882
+ if attention_mask is not None:
883
+ # assume that mask is expressed as:
884
+ # (1 = keep, 0 = discard)
885
+ # convert mask into a bias that can be added to attention scores:
886
+ # (keep = +0, discard = -10000.0)
887
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
888
+ attention_mask = attention_mask.unsqueeze(1)
889
+
890
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
891
+ if encoder_attention_mask is not None:
892
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
893
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
894
+
895
+ # 0. center input if necessary
896
+ if self.config.center_input_sample:
897
+ sample = 2 * sample - 1.0
898
+
899
+ # 1. time
900
+ timesteps = timestep
901
+ if not torch.is_tensor(timesteps):
902
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
903
+ # This would be a good case for the `match` statement (Python 3.10+)
904
+ is_mps = sample.device.type == "mps"
905
+ if isinstance(timestep, float):
906
+ dtype = torch.float32 if is_mps else torch.float64
907
+ else:
908
+ dtype = torch.int32 if is_mps else torch.int64
909
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
910
+ elif len(timesteps.shape) == 0:
911
+ timesteps = timesteps[None].to(sample.device)
912
+
913
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
914
+ timesteps = timesteps.expand(sample.shape[0])
915
+
916
+ t_emb = self.time_proj(timesteps)
917
+
918
+ # `Timesteps` does not contain any weights and will always return f32 tensors
919
+ # but time_embedding might actually be running in fp16. so we need to cast here.
920
+ # there might be better ways to encapsulate this.
921
+ t_emb = t_emb.to(dtype=sample.dtype)
922
+ emb = self.time_embedding(t_emb, timestep_cond)
923
+
924
+ if camera_matrixs is not None:
925
+ emb = torch.unsqueeze(emb, 1)
926
+ cam_emb = self.camera_embedding(camera_matrixs)
927
+ emb = emb.repeat(1,cam_emb.shape[1],1)
928
+ emb = emb + cam_emb
929
+ emb = rearrange(emb, "b f c -> (b f) c", f=emb.shape[1])
930
+
931
+ aug_emb = None
932
+
933
+ if self.class_embedding is not None and class_labels is not None:
934
+ if class_labels is None:
935
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
936
+
937
+ if self.config.class_embed_type == "timestep":
938
+ class_labels = self.time_proj(class_labels)
939
+
940
+ # `Timesteps` does not contain any weights and will always return f32 tensors
941
+ # there might be better ways to encapsulate this.
942
+ class_labels = class_labels.to(dtype=sample.dtype)
943
+
944
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
945
+
946
+ if self.config.class_embeddings_concat:
947
+ emb = torch.cat([emb, class_emb], dim=-1)
948
+ else:
949
+ emb = emb + class_emb
950
+
951
+ if self.config.addition_embed_type == "text":
952
+ aug_emb = self.add_embedding(encoder_hidden_states)
953
+ elif self.config.addition_embed_type == "text_image":
954
+ # Kandinsky 2.1 - style
955
+ if "image_embeds" not in added_cond_kwargs:
956
+ raise ValueError(
957
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
958
+ )
959
+
960
+ image_embs = added_cond_kwargs.get("image_embeds")
961
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
962
+ aug_emb = self.add_embedding(text_embs, image_embs)
963
+ elif self.config.addition_embed_type == "text_time":
964
+ # SDXL - style
965
+ if "text_embeds" not in added_cond_kwargs:
966
+ raise ValueError(
967
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
968
+ )
969
+ text_embeds = added_cond_kwargs.get("text_embeds")
970
+ if "time_ids" not in added_cond_kwargs:
971
+ raise ValueError(
972
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
973
+ )
974
+ time_ids = added_cond_kwargs.get("time_ids")
975
+ time_embeds = self.add_time_proj(time_ids.flatten())
976
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
977
+
978
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
979
+ add_embeds = add_embeds.to(emb.dtype)
980
+ aug_emb = self.add_embedding(add_embeds)
981
+ elif self.config.addition_embed_type == "image":
982
+ # Kandinsky 2.2 - style
983
+ if "image_embeds" not in added_cond_kwargs:
984
+ raise ValueError(
985
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
986
+ )
987
+ image_embs = added_cond_kwargs.get("image_embeds")
988
+ aug_emb = self.add_embedding(image_embs)
989
+ elif self.config.addition_embed_type == "image_hint":
990
+ # Kandinsky 2.2 - style
991
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
992
+ raise ValueError(
993
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
994
+ )
995
+ image_embs = added_cond_kwargs.get("image_embeds")
996
+ hint = added_cond_kwargs.get("hint")
997
+ aug_emb, hint = self.add_embedding(image_embs, hint)
998
+ sample = torch.cat([sample, hint], dim=1)
999
+
1000
+ emb = emb + aug_emb if aug_emb is not None else emb
1001
+
1002
+ if self.time_embed_act is not None:
1003
+ emb = self.time_embed_act(emb)
1004
+
1005
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1006
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1007
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1008
+ # Kandinsky 2.1 - style
1009
+ if "image_embeds" not in added_cond_kwargs:
1010
+ raise ValueError(
1011
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1012
+ )
1013
+
1014
+ image_embeds = added_cond_kwargs.get("image_embeds")
1015
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1016
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1017
+ # Kandinsky 2.2 - style
1018
+ if "image_embeds" not in added_cond_kwargs:
1019
+ raise ValueError(
1020
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1021
+ )
1022
+ image_embeds = added_cond_kwargs.get("image_embeds")
1023
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1024
+ # 2. pre-process
1025
+ sample = rearrange(sample, "b c f h w -> (b f) c h w", f=sample.shape[2])
1026
+ sample = self.conv_in(sample)
1027
+ # 3. down
1028
+
1029
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1030
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
1031
+
1032
+ down_block_res_samples = (sample,)
1033
+ for downsample_block in self.down_blocks:
1034
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1035
+ # For t2i-adapter CrossAttnDownBlock2D
1036
+ additional_residuals = {}
1037
+ if is_adapter and len(down_block_additional_residuals) > 0:
1038
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
1039
+
1040
+ sample, res_samples = downsample_block(
1041
+ hidden_states=sample,
1042
+ temb=emb,
1043
+ encoder_hidden_states=encoder_hidden_states,
1044
+ attention_mask=attention_mask,
1045
+ cross_attention_kwargs=cross_attention_kwargs,
1046
+ encoder_attention_mask=encoder_attention_mask,
1047
+ **additional_residuals,
1048
+ )
1049
+ else:
1050
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1051
+
1052
+ if is_adapter and len(down_block_additional_residuals) > 0:
1053
+ sample += down_block_additional_residuals.pop(0)
1054
+
1055
+ down_block_res_samples += res_samples
1056
+
1057
+ if is_controlnet:
1058
+ new_down_block_res_samples = ()
1059
+
1060
+ for down_block_res_sample, down_block_additional_residual in zip(
1061
+ down_block_res_samples, down_block_additional_residuals
1062
+ ):
1063
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1064
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1065
+
1066
+ down_block_res_samples = new_down_block_res_samples
1067
+
1068
+ # 4. mid
1069
+ if self.mid_block is not None:
1070
+ sample = self.mid_block(
1071
+ sample,
1072
+ emb,
1073
+ encoder_hidden_states=encoder_hidden_states,
1074
+ attention_mask=attention_mask,
1075
+ cross_attention_kwargs=cross_attention_kwargs,
1076
+ encoder_attention_mask=encoder_attention_mask,
1077
+ )
1078
+
1079
+ if is_controlnet:
1080
+ sample = sample + mid_block_additional_residual
1081
+
1082
+ # 5. up
1083
+ for i, upsample_block in enumerate(self.up_blocks):
1084
+ is_final_block = i == len(self.up_blocks) - 1
1085
+
1086
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1087
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1088
+
1089
+ # if we have not reached the final block and need to forward the
1090
+ # upsample size, we do it here
1091
+ if not is_final_block and forward_upsample_size:
1092
+ upsample_size = down_block_res_samples[-1].shape[2:]
1093
+
1094
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1095
+ sample = upsample_block(
1096
+ hidden_states=sample,
1097
+ temb=emb,
1098
+ res_hidden_states_tuple=res_samples,
1099
+ encoder_hidden_states=encoder_hidden_states,
1100
+ cross_attention_kwargs=cross_attention_kwargs,
1101
+ upsample_size=upsample_size,
1102
+ attention_mask=attention_mask,
1103
+ encoder_attention_mask=encoder_attention_mask,
1104
+ )
1105
+ else:
1106
+ sample = upsample_block(
1107
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1108
+ )
1109
+
1110
+ if not return_dict:
1111
+ return (sample,)
1112
+
1113
+ return UNetMV2DRefOutput(sample=sample)
1114
+
1115
+ @classmethod
1116
+ def from_pretrained_2d(
1117
+ cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
1118
+ camera_embedding_type: str, num_views: int, sample_size: int,
1119
+ zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False,
1120
+ projection_class_embeddings_input_dim: int=6, joint_attention: bool = False,
1121
+ joint_attention_twice: bool = False, multiview_attention: bool = True,
1122
+ cross_domain_attention: bool = False,
1123
+ in_channels: int = 8, out_channels: int = 4, local_crossattn=False,
1124
+ **kwargs
1125
+ ):
1126
+ r"""
1127
+ Instantiate a pretrained PyTorch model from a pretrained model configuration.
1128
+
1129
+ The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
1130
+ train the model, set it back in training mode with `model.train()`.
1131
+
1132
+ Parameters:
1133
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
1134
+ Can be either:
1135
+
1136
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
1137
+ the Hub.
1138
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
1139
+ with [`~ModelMixin.save_pretrained`].
1140
+
1141
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
1142
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
1143
+ is not used.
1144
+ torch_dtype (`str` or `torch.dtype`, *optional*):
1145
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
1146
+ dtype is automatically derived from the model's weights.
1147
+ force_download (`bool`, *optional*, defaults to `False`):
1148
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1149
+ cached versions if they exist.
1150
+ resume_download (`bool`, *optional*, defaults to `False`):
1151
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
1152
+ incompletely downloaded files are deleted.
1153
+ proxies (`Dict[str, str]`, *optional*):
1154
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
1155
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1156
+ output_loading_info (`bool`, *optional*, defaults to `False`):
1157
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
1158
+ local_files_only(`bool`, *optional*, defaults to `False`):
1159
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
1160
+ won't be downloaded from the Hub.
1161
+ use_auth_token (`str` or *bool*, *optional*):
1162
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
1163
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
1164
+ revision (`str`, *optional*, defaults to `"main"`):
1165
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
1166
+ allowed by Git.
1167
+ from_flax (`bool`, *optional*, defaults to `False`):
1168
+ Load the model weights from a Flax checkpoint save file.
1169
+ subfolder (`str`, *optional*, defaults to `""`):
1170
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
1171
+ mirror (`str`, *optional*):
1172
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
1173
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
1174
+ information.
1175
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
1176
+ A map that specifies where each submodule should go. It doesn't need to be defined for each
1177
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
1178
+ same device.
1179
+
1180
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
1181
+ more information about each option see [designing a device
1182
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
1183
+ max_memory (`Dict`, *optional*):
1184
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
1185
+ each GPU and the available CPU RAM if unset.
1186
+ offload_folder (`str` or `os.PathLike`, *optional*):
1187
+ The path to offload weights if `device_map` contains the value `"disk"`.
1188
+ offload_state_dict (`bool`, *optional*):
1189
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
1190
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
1191
+ when there is some disk offload.
1192
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
1193
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
1194
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
1195
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
1196
+ argument to `True` will raise an error.
1197
+ variant (`str`, *optional*):
1198
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
1199
+ loading `from_flax`.
1200
+ use_safetensors (`bool`, *optional*, defaults to `None`):
1201
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
1202
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
1203
+ weights. If set to `False`, `safetensors` weights are not loaded.
1204
+
1205
+ <Tip>
1206
+
1207
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
1208
+ `huggingface-cli login`. You can also activate the special
1209
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
1210
+ firewalled environment.
1211
+
1212
+ </Tip>
1213
+
1214
+ Example:
1215
+
1216
+ ```py
1217
+ from diffusers import UNet2DConditionModel
1218
+
1219
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
1220
+ ```
1221
+
1222
+ If you get the error message below, you need to finetune the weights for your downstream task:
1223
+
1224
+ ```bash
1225
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
1226
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
1227
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
1228
+ ```
1229
+ """
1230
+ cache_dir = kwargs.pop("cache_dir", HF_HUB_CACHE)
1231
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
1232
+ force_download = kwargs.pop("force_download", False)
1233
+ from_flax = kwargs.pop("from_flax", False)
1234
+ resume_download = kwargs.pop("resume_download", False)
1235
+ proxies = kwargs.pop("proxies", None)
1236
+ output_loading_info = kwargs.pop("output_loading_info", False)
1237
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
1238
+ use_auth_token = kwargs.pop("use_auth_token", None)
1239
+ revision = kwargs.pop("revision", None)
1240
+ torch_dtype = kwargs.pop("torch_dtype", None)
1241
+ subfolder = kwargs.pop("subfolder", None)
1242
+ device_map = kwargs.pop("device_map", None)
1243
+ max_memory = kwargs.pop("max_memory", None)
1244
+ offload_folder = kwargs.pop("offload_folder", None)
1245
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
1246
+ variant = kwargs.pop("variant", None)
1247
+ use_safetensors = kwargs.pop("use_safetensors", None)
1248
+
1249
+ # if use_safetensors and not is_safetensors_available():
1250
+ # raise ValueError(
1251
+ # "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
1252
+ # )
1253
+
1254
+ allow_pickle = False
1255
+ if use_safetensors is None:
1256
+ # use_safetensors = is_safetensors_available()
1257
+ use_safetensors = False
1258
+ allow_pickle = True
1259
+
1260
+ if device_map is not None and not is_accelerate_available():
1261
+ raise NotImplementedError(
1262
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
1263
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
1264
+ )
1265
+
1266
+ # Check if we can handle device_map and dispatching the weights
1267
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
1268
+ raise NotImplementedError(
1269
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
1270
+ " `device_map=None`."
1271
+ )
1272
+
1273
+ # Load config if we don't provide a configuration
1274
+ config_path = pretrained_model_name_or_path
1275
+
1276
+ user_agent = {
1277
+ "diffusers": __version__,
1278
+ "file_type": "model",
1279
+ "framework": "pytorch",
1280
+ }
1281
+
1282
+ # load config
1283
+ config, unused_kwargs, commit_hash = cls.load_config(
1284
+ config_path,
1285
+ cache_dir=cache_dir,
1286
+ return_unused_kwargs=True,
1287
+ return_commit_hash=True,
1288
+ force_download=force_download,
1289
+ resume_download=resume_download,
1290
+ proxies=proxies,
1291
+ local_files_only=local_files_only,
1292
+ use_auth_token=use_auth_token,
1293
+ revision=revision,
1294
+ subfolder=subfolder,
1295
+ device_map=device_map,
1296
+ max_memory=max_memory,
1297
+ offload_folder=offload_folder,
1298
+ offload_state_dict=offload_state_dict,
1299
+ user_agent=user_agent,
1300
+ **kwargs,
1301
+ )
1302
+
1303
+ # modify config
1304
+ config["_class_name"] = cls.__name__
1305
+ config['in_channels'] = in_channels
1306
+ config['out_channels'] = out_channels
1307
+ config['sample_size'] = sample_size # training resolution
1308
+ config['num_views'] = num_views
1309
+ config['joint_attention'] = joint_attention
1310
+ config['joint_attention_twice'] = joint_attention_twice
1311
+ config['multiview_attention'] = multiview_attention
1312
+ config['cross_domain_attention'] = cross_domain_attention
1313
+ config["down_block_types"] = [
1314
+ "CrossAttnDownBlockMV2D",
1315
+ "CrossAttnDownBlockMV2D",
1316
+ "CrossAttnDownBlockMV2D",
1317
+ "DownBlock2D"
1318
+ ]
1319
+ config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn"
1320
+ config["up_block_types"] = [
1321
+ "UpBlock2D",
1322
+ "CrossAttnUpBlockMV2D",
1323
+ "CrossAttnUpBlockMV2D",
1324
+ "CrossAttnUpBlockMV2D"
1325
+ ]
1326
+ config['class_embed_type'] = 'projection'
1327
+ if camera_embedding_type == 'e_de_da_sincos':
1328
+ config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim # default 6
1329
+ else:
1330
+ raise NotImplementedError
1331
+
1332
+ # load model
1333
+ model_file = None
1334
+ if from_flax:
1335
+ raise NotImplementedError
1336
+ else:
1337
+ if use_safetensors:
1338
+ try:
1339
+ model_file = _get_model_file(
1340
+ pretrained_model_name_or_path,
1341
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
1342
+ cache_dir=cache_dir,
1343
+ force_download=force_download,
1344
+ resume_download=resume_download,
1345
+ proxies=proxies,
1346
+ local_files_only=local_files_only,
1347
+ use_auth_token=use_auth_token,
1348
+ revision=revision,
1349
+ subfolder=subfolder,
1350
+ user_agent=user_agent,
1351
+ commit_hash=commit_hash,
1352
+ )
1353
+ except IOError as e:
1354
+ if not allow_pickle:
1355
+ raise e
1356
+ pass
1357
+ if model_file is None:
1358
+ model_file = _get_model_file(
1359
+ pretrained_model_name_or_path,
1360
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
1361
+ cache_dir=cache_dir,
1362
+ force_download=force_download,
1363
+ resume_download=resume_download,
1364
+ proxies=proxies,
1365
+ local_files_only=local_files_only,
1366
+ use_auth_token=use_auth_token,
1367
+ revision=revision,
1368
+ subfolder=subfolder,
1369
+ user_agent=user_agent,
1370
+ commit_hash=commit_hash,
1371
+ )
1372
+
1373
+ model = cls.from_config(config, **unused_kwargs)
1374
+ if local_crossattn:
1375
+ unet_lora_attn_procs = dict()
1376
+ for name, _ in model.attn_processors.items():
1377
+ if not name.endswith("attn1.processor"):
1378
+ default_attn_proc = AttnProcessor()
1379
+ elif is_xformers_available():
1380
+ default_attn_proc = XFormersMVAttnProcessor()
1381
+ else:
1382
+ default_attn_proc = MVAttnProcessor()
1383
+ unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(
1384
+ default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
1385
+ )
1386
+ model.set_attn_processor(unet_lora_attn_procs)
1387
+ state_dict = load_state_dict(model_file, variant=variant)
1388
+ model._convert_deprecated_attention_blocks(state_dict)
1389
+
1390
+ conv_in_weight = state_dict['conv_in.weight']
1391
+ conv_out_weight = state_dict['conv_out.weight']
1392
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d(
1393
+ model,
1394
+ state_dict,
1395
+ model_file,
1396
+ pretrained_model_name_or_path,
1397
+ ignore_mismatched_sizes=True,
1398
+ )
1399
+ if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]):
1400
+ # initialize from the original SD structure
1401
+ model.conv_in.weight.data[:,:4] = conv_in_weight
1402
+
1403
+ # whether to place all zero to new layers?
1404
+ if zero_init_conv_in:
1405
+ model.conv_in.weight.data[:,4:] = 0.
1406
+
1407
+ if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]):
1408
+ # initialize from the original SD structure
1409
+ model.conv_out.weight.data[:,:4] = conv_out_weight
1410
+ if out_channels == 8: # copy for the last 4 channels
1411
+ model.conv_out.weight.data[:, 4:] = conv_out_weight
1412
+
1413
+ if zero_init_camera_projection:
1414
+ for p in model.class_embedding.parameters():
1415
+ torch.nn.init.zeros_(p)
1416
+
1417
+ loading_info = {
1418
+ "missing_keys": missing_keys,
1419
+ "unexpected_keys": unexpected_keys,
1420
+ "mismatched_keys": mismatched_keys,
1421
+ "error_msgs": error_msgs,
1422
+ }
1423
+
1424
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
1425
+ raise ValueError(
1426
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
1427
+ )
1428
+ elif torch_dtype is not None:
1429
+ model = model.to(torch_dtype)
1430
+
1431
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
1432
+
1433
+ # Set model in evaluation mode to deactivate DropOut modules by default
1434
+ model.eval()
1435
+ if output_loading_info:
1436
+ return model, loading_info
1437
+
1438
+ return model
1439
+
1440
+ @classmethod
1441
+ def _load_pretrained_model_2d(
1442
+ cls,
1443
+ model,
1444
+ state_dict,
1445
+ resolved_archive_file,
1446
+ pretrained_model_name_or_path,
1447
+ ignore_mismatched_sizes=False,
1448
+ ):
1449
+ # Retrieve missing & unexpected_keys
1450
+ model_state_dict = model.state_dict()
1451
+ loaded_keys = list(state_dict.keys())
1452
+
1453
+ expected_keys = list(model_state_dict.keys())
1454
+
1455
+ original_loaded_keys = loaded_keys
1456
+
1457
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
1458
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
1459
+
1460
+ # Make sure we are able to load base models as well as derived models (with heads)
1461
+ model_to_load = model
1462
+
1463
+ def _find_mismatched_keys(
1464
+ state_dict,
1465
+ model_state_dict,
1466
+ loaded_keys,
1467
+ ignore_mismatched_sizes,
1468
+ ):
1469
+ mismatched_keys = []
1470
+ if ignore_mismatched_sizes:
1471
+ for checkpoint_key in loaded_keys:
1472
+ model_key = checkpoint_key
1473
+
1474
+ if (
1475
+ model_key in model_state_dict
1476
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
1477
+ ):
1478
+ mismatched_keys.append(
1479
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
1480
+ )
1481
+ del state_dict[checkpoint_key]
1482
+ return mismatched_keys
1483
+
1484
+ if state_dict is not None:
1485
+ # Whole checkpoint
1486
+ mismatched_keys = _find_mismatched_keys(
1487
+ state_dict,
1488
+ model_state_dict,
1489
+ original_loaded_keys,
1490
+ ignore_mismatched_sizes,
1491
+ )
1492
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
1493
+
1494
+ if len(error_msgs) > 0:
1495
+ error_msg = "\n\t".join(error_msgs)
1496
+ if "size mismatch" in error_msg:
1497
+ error_msg += (
1498
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
1499
+ )
1500
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
1501
+
1502
+ if len(unexpected_keys) > 0:
1503
+ logger.warning(
1504
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
1505
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
1506
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
1507
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
1508
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
1509
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
1510
+ " identical (initializing a BertForSequenceClassification model from a"
1511
+ " BertForSequenceClassification model)."
1512
+ )
1513
+ else:
1514
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
1515
+ if len(missing_keys) > 0:
1516
+ logger.warning(
1517
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1518
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
1519
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
1520
+ )
1521
+ elif len(mismatched_keys) == 0:
1522
+ logger.info(
1523
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
1524
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
1525
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
1526
+ " without further training."
1527
+ )
1528
+ if len(mismatched_keys) > 0:
1529
+ mismatched_warning = "\n".join(
1530
+ [
1531
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
1532
+ for key, shape1, shape2 in mismatched_keys
1533
+ ]
1534
+ )
1535
+ logger.warning(
1536
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1537
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
1538
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
1539
+ " able to use it for predictions and inference."
1540
+ )
1541
+
1542
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
1543
+
canonicalize/pipeline_canonicalize.py ADDED
@@ -0,0 +1,518 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
2
+
3
+ import tqdm
4
+
5
+ import inspect
6
+ from typing import Callable, List, Optional, Union
7
+ from dataclasses import dataclass
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+ from diffusers.utils import is_accelerate_available
13
+ from packaging import version
14
+ from transformers import CLIPTextModel, CLIPTokenizer
15
+ import torchvision.transforms.functional as TF
16
+
17
+ from diffusers.configuration_utils import FrozenDict
18
+ from diffusers.models import AutoencoderKL
19
+ from diffusers import DiffusionPipeline
20
+ from diffusers.schedulers import (
21
+ DDIMScheduler,
22
+ DPMSolverMultistepScheduler,
23
+ EulerAncestralDiscreteScheduler,
24
+ EulerDiscreteScheduler,
25
+ LMSDiscreteScheduler,
26
+ PNDMScheduler,
27
+ )
28
+ from diffusers.utils import deprecate, logging, BaseOutput
29
+
30
+ from einops import rearrange
31
+
32
+ from canonicalize.models.unet import UNet3DConditionModel
33
+ from torchvision.transforms import InterpolationMode
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+ class CanonicalizationPipeline(DiffusionPipeline):
38
+ _optional_components = []
39
+
40
+ def __init__(
41
+ self,
42
+ vae: AutoencoderKL,
43
+ text_encoder: CLIPTextModel,
44
+ tokenizer: CLIPTokenizer,
45
+ unet: UNet3DConditionModel,
46
+
47
+ scheduler: Union[
48
+ DDIMScheduler,
49
+ PNDMScheduler,
50
+ LMSDiscreteScheduler,
51
+ EulerDiscreteScheduler,
52
+ EulerAncestralDiscreteScheduler,
53
+ DPMSolverMultistepScheduler,
54
+ ],
55
+ ref_unet = None,
56
+ feature_extractor=None,
57
+ image_encoder=None
58
+ ):
59
+ super().__init__()
60
+ self.ref_unet = ref_unet
61
+ self.feature_extractor = feature_extractor
62
+ self.image_encoder = image_encoder
63
+
64
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
65
+ deprecation_message = (
66
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
67
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
68
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
69
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
70
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
71
+ " file"
72
+ )
73
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
74
+ new_config = dict(scheduler.config)
75
+ new_config["steps_offset"] = 1
76
+ scheduler._internal_dict = FrozenDict(new_config)
77
+
78
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
79
+ deprecation_message = (
80
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
81
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
82
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
83
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
84
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
85
+ )
86
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
87
+ new_config = dict(scheduler.config)
88
+ new_config["clip_sample"] = False
89
+ scheduler._internal_dict = FrozenDict(new_config)
90
+
91
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
92
+ version.parse(unet.config._diffusers_version).base_version
93
+ ) < version.parse("0.9.0.dev0")
94
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
95
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
96
+ deprecation_message = (
97
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
98
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
99
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
100
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
101
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
102
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
103
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
104
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
105
+ " the `unet/config.json` file"
106
+ )
107
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
108
+ new_config = dict(unet.config)
109
+ new_config["sample_size"] = 64
110
+ unet._internal_dict = FrozenDict(new_config)
111
+
112
+ self.register_modules(
113
+ vae=vae,
114
+ text_encoder=text_encoder,
115
+ tokenizer=tokenizer,
116
+ unet=unet,
117
+ scheduler=scheduler,
118
+ )
119
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
120
+
121
+ def enable_vae_slicing(self):
122
+ self.vae.enable_slicing()
123
+
124
+ def disable_vae_slicing(self):
125
+ self.vae.disable_slicing()
126
+
127
+ def enable_sequential_cpu_offload(self, gpu_id=0):
128
+ if is_accelerate_available():
129
+ from accelerate import cpu_offload
130
+ else:
131
+ raise ImportError("Please install accelerate via `pip install accelerate`")
132
+
133
+ device = torch.device(f"cuda:{gpu_id}")
134
+
135
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
136
+ if cpu_offloaded_model is not None:
137
+ cpu_offload(cpu_offloaded_model, device)
138
+
139
+
140
+ @property
141
+ def _execution_device(self):
142
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
143
+ return self.device
144
+ for module in self.unet.modules():
145
+ if (
146
+ hasattr(module, "_hf_hook")
147
+ and hasattr(module._hf_hook, "execution_device")
148
+ and module._hf_hook.execution_device is not None
149
+ ):
150
+ return torch.device(module._hf_hook.execution_device)
151
+ return self.device
152
+
153
+ def _encode_image(self, image_pil, device, num_images_per_prompt, do_classifier_free_guidance, img_proj=None):
154
+ dtype = next(self.image_encoder.parameters()).dtype
155
+
156
+ # image encoding
157
+ clip_image_mean = torch.as_tensor(self.feature_extractor.image_mean)[:,None,None].to(device, dtype=torch.float32)
158
+ clip_image_std = torch.as_tensor(self.feature_extractor.image_std)[:,None,None].to(device, dtype=torch.float32)
159
+ imgs_in_proc = TF.resize(image_pil, (self.feature_extractor.crop_size['height'], self.feature_extractor.crop_size['width']), interpolation=InterpolationMode.BICUBIC)
160
+ # do the normalization in float32 to preserve precision
161
+ imgs_in_proc = ((imgs_in_proc.float() - clip_image_mean) / clip_image_std).to(dtype)
162
+ if img_proj is None:
163
+ # (B*Nv, 1, 768)
164
+ image_embeddings = self.image_encoder(imgs_in_proc).image_embeds.unsqueeze(1)
165
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
166
+ # Note: repeat differently from official pipelines
167
+ # B1B2B3B4 -> B1B2B3B4B1B2B3B4
168
+ bs_embed, seq_len, _ = image_embeddings.shape
169
+ image_embeddings = image_embeddings.repeat(num_images_per_prompt, 1, 1)
170
+ if do_classifier_free_guidance:
171
+ negative_prompt_embeds = torch.zeros_like(image_embeddings)
172
+
173
+ # For classifier free guidance, we need to do two forward passes.
174
+ # Here we concatenate the unconditional and text embeddings into a single batch
175
+ # to avoid doing two forward passes
176
+ image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
177
+ else:
178
+ if do_classifier_free_guidance:
179
+ negative_image_proc = torch.zeros_like(imgs_in_proc)
180
+
181
+ # For classifier free guidance, we need to do two forward passes.
182
+ # Here we concatenate the unconditional and text embeddings into a single batch
183
+ # to avoid doing two forward passes
184
+ imgs_in_proc = torch.cat([negative_image_proc, imgs_in_proc])
185
+
186
+ image_embeds = self.image_encoder(imgs_in_proc, output_hidden_states=True).hidden_states[-2]
187
+ image_embeddings = img_proj(image_embeds)
188
+
189
+ image_latents = self.vae.encode(image_pil* 2.0 - 1.0).latent_dist.mode() * self.vae.config.scaling_factor
190
+
191
+ # Note: repeat differently from official pipelines
192
+ # B1B2B3B4 -> B1B2B3B4B1B2B3B4
193
+ image_latents = image_latents.repeat(num_images_per_prompt, 1, 1, 1)
194
+ return image_embeddings, image_latents
195
+
196
+ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
197
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
198
+
199
+ text_inputs = self.tokenizer(
200
+ prompt,
201
+ padding="max_length",
202
+ max_length=self.tokenizer.model_max_length,
203
+ truncation=True,
204
+ return_tensors="pt",
205
+ )
206
+ text_input_ids = text_inputs.input_ids
207
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
208
+
209
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
210
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
211
+ logger.warning(
212
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
213
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
214
+ )
215
+
216
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
217
+ attention_mask = text_inputs.attention_mask.to(device)
218
+ else:
219
+ attention_mask = None
220
+
221
+ text_embeddings = self.text_encoder(
222
+ text_input_ids.to(device),
223
+ attention_mask=attention_mask,
224
+ )
225
+ text_embeddings = text_embeddings[0]
226
+
227
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
228
+ bs_embed, seq_len, _ = text_embeddings.shape
229
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
230
+ text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
231
+
232
+ # get unconditional embeddings for classifier free guidance
233
+ if do_classifier_free_guidance:
234
+ uncond_tokens: List[str]
235
+ if negative_prompt is None:
236
+ uncond_tokens = [""] * batch_size
237
+ elif type(prompt) is not type(negative_prompt):
238
+ raise TypeError(
239
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
240
+ f" {type(prompt)}."
241
+ )
242
+ elif isinstance(negative_prompt, str):
243
+ uncond_tokens = [negative_prompt]
244
+ elif batch_size != len(negative_prompt):
245
+ raise ValueError(
246
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
247
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
248
+ " the batch size of `prompt`."
249
+ )
250
+ else:
251
+ uncond_tokens = negative_prompt
252
+
253
+ max_length = text_input_ids.shape[-1]
254
+ uncond_input = self.tokenizer(
255
+ uncond_tokens,
256
+ padding="max_length",
257
+ max_length=max_length,
258
+ truncation=True,
259
+ return_tensors="pt",
260
+ )
261
+
262
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
263
+ attention_mask = uncond_input.attention_mask.to(device)
264
+ else:
265
+ attention_mask = None
266
+
267
+ uncond_embeddings = self.text_encoder(
268
+ uncond_input.input_ids.to(device),
269
+ attention_mask=attention_mask,
270
+ )
271
+ uncond_embeddings = uncond_embeddings[0]
272
+
273
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
274
+ seq_len = uncond_embeddings.shape[1]
275
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
276
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
277
+
278
+ # For classifier free guidance, we need to do two forward passes.
279
+ # Here we concatenate the unconditional and text embeddings into a single batch
280
+ # to avoid doing two forward passes
281
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
282
+
283
+ return text_embeddings
284
+
285
+ def decode_latents(self, latents):
286
+ video_length = latents.shape[2]
287
+ latents = 1 / 0.18215 * latents
288
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
289
+ video = self.vae.decode(latents).sample
290
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
291
+ video = (video / 2 + 0.5).clamp(0, 1)
292
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
293
+ video = video.cpu().float().numpy()
294
+ return video
295
+
296
+ def prepare_extra_step_kwargs(self, generator, eta):
297
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
298
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
299
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
300
+ # and should be between [0, 1]
301
+
302
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
303
+ extra_step_kwargs = {}
304
+ if accepts_eta:
305
+ extra_step_kwargs["eta"] = eta
306
+
307
+ # check if the scheduler accepts generator
308
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
309
+ if accepts_generator:
310
+ extra_step_kwargs["generator"] = generator
311
+ return extra_step_kwargs
312
+
313
+ def check_inputs(self, prompt, height, width, callback_steps):
314
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
315
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
316
+
317
+ if height % 8 != 0 or width % 8 != 0:
318
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
319
+
320
+ if (callback_steps is None) or (
321
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
322
+ ):
323
+ raise ValueError(
324
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
325
+ f" {type(callback_steps)}."
326
+ )
327
+
328
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
329
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
330
+ if isinstance(generator, list) and len(generator) != batch_size:
331
+ raise ValueError(
332
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
333
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
334
+ )
335
+
336
+ if latents is None:
337
+ rand_device = "cpu" if device.type == "mps" else device
338
+
339
+ if isinstance(generator, list):
340
+ shape = (1,) + shape[1:]
341
+ latents = [
342
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
343
+ for i in range(batch_size)
344
+ ]
345
+ latents = torch.cat(latents, dim=0).to(device)
346
+ else:
347
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
348
+ else:
349
+ if latents.shape != shape:
350
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
351
+ latents = latents.to(device)
352
+
353
+ # scale the initial noise by the standard deviation required by the scheduler
354
+ latents = latents * self.scheduler.init_noise_sigma
355
+ return latents
356
+
357
+ @torch.no_grad()
358
+ def __call__(
359
+ self,
360
+ prompt: Union[str, List[str]],
361
+ image: Union[str, List[str]],
362
+ height: Optional[int] = None,
363
+ width: Optional[int] = None,
364
+ num_inference_steps: int = 50,
365
+ guidance_scale: float = 7.5,
366
+ negative_prompt: Optional[Union[str, List[str]]] = None,
367
+ num_videos_per_prompt: Optional[int] = 1,
368
+ eta: float = 0.0,
369
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
370
+ latents: Optional[torch.FloatTensor] = None,
371
+ output_type: Optional[str] = "tensor",
372
+ return_dict: bool = True,
373
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
374
+ callback_steps: Optional[int] = 1,
375
+ class_labels = None,
376
+ prompt_ids = None,
377
+ unet_condition_type = None,
378
+ img_proj=None,
379
+ use_noise=True,
380
+ use_shifted_noise=False,
381
+ rescale = 0.7,
382
+ **kwargs,
383
+ ):
384
+ # Default height and width to unet
385
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
386
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
387
+ video_length = 1
388
+
389
+ # Check inputs. Raise error if not correct
390
+ self.check_inputs(prompt, height, width, callback_steps)
391
+ if isinstance(image, list):
392
+ batch_size = len(image)
393
+ else:
394
+ batch_size = image.shape[0]
395
+ # Define call parameters
396
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
397
+ device = self._execution_device
398
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
399
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
400
+ # corresponds to doing no classifier free guidance.
401
+ do_classifier_free_guidance = guidance_scale > 1.0
402
+
403
+ # 3. Encode input image
404
+ image_embeddings, image_latents = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance, img_proj=img_proj) #torch.Size([64, 1, 768]) torch.Size([64, 4, 32, 32])
405
+ image_latents = rearrange(image_latents, "(b f) c h w -> b c f h w", f=1) #torch.Size([64, 4, 1, 32, 32])
406
+
407
+ # Encode input prompt
408
+ text_embeddings = self._encode_prompt( #torch.Size([64, 77, 768])
409
+ prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
410
+ )
411
+
412
+ # Prepare timesteps
413
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
414
+ timesteps = self.scheduler.timesteps
415
+
416
+ # Prepare latent variables
417
+ num_channels_latents = self.unet.in_channels
418
+ latents = self.prepare_latents(
419
+ batch_size * num_videos_per_prompt,
420
+ num_channels_latents,
421
+ video_length,
422
+ height,
423
+ width,
424
+ text_embeddings.dtype,
425
+ device,
426
+ generator,
427
+ latents,
428
+ )
429
+ latents_dtype = latents.dtype
430
+
431
+ # Prepare extra step kwargs.
432
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
433
+
434
+ # Denoising loop
435
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
436
+
437
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
438
+ for i, t in enumerate(tqdm.tqdm(timesteps)):
439
+ # expand the latents if we are doing classifier free guidance
440
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
441
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
442
+
443
+ noise_cond = torch.randn_like(image_latents)
444
+ if use_noise:
445
+ cond_latents = self.scheduler.add_noise(image_latents, noise_cond, t)
446
+ else:
447
+ cond_latents = image_latents
448
+ cond_latent_model_input = torch.cat([cond_latents] * 2) if do_classifier_free_guidance else cond_latents
449
+ cond_latent_model_input = self.scheduler.scale_model_input(cond_latent_model_input, t)
450
+
451
+ # predict the noise residual
452
+ # ref text condition
453
+ ref_dict = {}
454
+ if self.ref_unet is not None:
455
+ noise_pred_cond = self.ref_unet(
456
+ cond_latent_model_input,
457
+ t,
458
+ encoder_hidden_states=text_embeddings.to(torch.float32),
459
+ cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict)
460
+ ).sample.to(dtype=latents_dtype)
461
+
462
+ # text condition for unet
463
+ text_embeddings_unet = text_embeddings.unsqueeze(1).repeat(1,latents.shape[2],1,1)
464
+ text_embeddings_unet = rearrange(text_embeddings_unet, 'B Nv d c -> (B Nv) d c')
465
+ # image condition for unet
466
+ image_embeddings_unet = image_embeddings.unsqueeze(1).repeat(1,latents.shape[2],1, 1)
467
+ image_embeddings_unet = rearrange(image_embeddings_unet, 'B Nv d c -> (B Nv) d c')
468
+
469
+ encoder_hidden_states_unet_cond = image_embeddings_unet
470
+
471
+ if self.ref_unet is not None:
472
+ noise_pred = self.unet(
473
+ latent_model_input.to(torch.float32),
474
+ t,
475
+ encoder_hidden_states=encoder_hidden_states_unet_cond.to(torch.float32),
476
+ cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=do_classifier_free_guidance)
477
+ ).sample.to(dtype=latents_dtype)
478
+ else:
479
+ noise_pred = self.unet(
480
+ latent_model_input.to(torch.float32),
481
+ t,
482
+ encoder_hidden_states=encoder_hidden_states_unet_cond.to(torch.float32),
483
+ cross_attention_kwargs=dict(mode="n", ref_dict=ref_dict, is_cfg_guidance=do_classifier_free_guidance)
484
+ ).sample.to(dtype=latents_dtype)
485
+ # perform guidance
486
+ if do_classifier_free_guidance:
487
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
488
+ if use_shifted_noise:
489
+ # Apply regular classifier-free guidance.
490
+ cfg = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
491
+ # Calculate standard deviations.
492
+ std_pos = noise_pred_text.std([1,2,3], keepdim=True)
493
+ std_cfg = cfg.std([1,2,3], keepdim=True)
494
+ # Apply guidance rescale with fused operations.
495
+ factor = std_pos / std_cfg
496
+ factor = rescale * factor + (1 - rescale)
497
+ noise_pred = cfg * factor
498
+ else:
499
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
500
+
501
+ # compute the previous noisy sample x_t -> x_t-1
502
+ noise_pred = rearrange(noise_pred, "(b f) c h w -> b c f h w", f=video_length)
503
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
504
+
505
+ # call the callback, if provided
506
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
507
+ progress_bar.update()
508
+ if callback is not None and i % callback_steps == 0:
509
+ callback(i, t, latents)
510
+
511
+ # Post-processing
512
+ video = self.decode_latents(latents)
513
+
514
+ # Convert to tensor
515
+ if output_type == "tensor":
516
+ video = torch.from_numpy(video)
517
+
518
+ return video
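In the denoising loop above, the `use_shifted_noise` branch rescales classifier-free guidance by blending the guided prediction's standard deviation back toward that of the conditional prediction. A self-contained sketch of that step; the function name is illustrative and only `torch` is assumed.

```py
import torch

def rescale_cfg(noise_pred_uncond, noise_pred_text, guidance_scale=7.5, rescale=0.7):
    # Regular classifier-free guidance.
    cfg = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    # Blend the guided prediction's std toward the conditional prediction's std
    # to reduce over-saturation at large guidance scales.
    std_pos = noise_pred_text.std(dim=[1, 2, 3], keepdim=True)
    std_cfg = cfg.std(dim=[1, 2, 3], keepdim=True)
    factor = rescale * (std_pos / std_cfg) + (1.0 - rescale)
    return cfg * factor
```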
canonicalize/util.py ADDED
@@ -0,0 +1,128 @@
1
+ import os
2
+ import imageio
3
+ import numpy as np
4
+ from typing import Union
5
+ import cv2
6
+ import torch
7
+ import torchvision
8
+
9
+ from tqdm import tqdm
10
+ from einops import rearrange
11
+
12
+ def shifted_noise(betas, image_d=512, noise_d=256, shifted_noise=True):
13
+ alphas = 1 - betas
14
+ alphas_bar = torch.cumprod(alphas, dim=0)
15
+ d = (image_d / noise_d) ** 2
16
+ if shifted_noise:
17
+ alphas_bar = alphas_bar / (d - (d - 1) * alphas_bar)
18
+ alphas_bar_sqrt = torch.sqrt(alphas_bar)
19
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
20
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
21
+ # Shift so last timestep is zero.
22
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
23
+ # Scale so first timestep is back to old value.
24
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (
25
+ alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
26
+
27
+ # Convert alphas_bar_sqrt to betas
28
+ alphas_bar = alphas_bar_sqrt ** 2
29
+ alphas = alphas_bar[1:] / alphas_bar[:-1]
30
+ alphas = torch.cat([alphas_bar[0:1], alphas])
31
+ betas = 1 - alphas
32
+ return betas
33
+
34
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
35
+ videos = rearrange(videos, "b c t h w -> t b c h w")
36
+ outputs = []
37
+ for x in videos:
38
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
39
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
40
+ if rescale:
41
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
42
+ x = (x * 255).numpy().astype(np.uint8)
43
+ outputs.append(x)
44
+
45
+ os.makedirs(os.path.dirname(path), exist_ok=True)
46
+ imageio.mimsave(path, outputs, duration=1000/fps)
47
+
48
+ def save_imgs_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
49
+ videos = rearrange(videos, "b c t h w -> t b c h w")
50
+ for i, x in enumerate(videos):
51
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
52
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
53
+ if rescale:
54
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
55
+ x = (x * 255).numpy().astype(np.uint8)
56
+ os.makedirs(os.path.dirname(path), exist_ok=True)
57
+ cv2.imwrite(os.path.join(path, f'view_{i}.png'), x[:,:,::-1])
58
+
59
+ def imgs_grid(videos: torch.Tensor, rescale=False, n_rows=4, fps=8):
60
+ videos = rearrange(videos, "b c t h w -> t b c h w")
61
+ image_list = []
62
+ for i, x in enumerate(videos):
63
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
64
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
65
+ if rescale:
66
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
67
+ x = (x * 255).numpy().astype(np.uint8)
68
+ # image_list.append(x[:,:,::-1])
69
+ image_list.append(x)
70
+ return image_list
71
+
72
+ # DDIM Inversion
73
+ @torch.no_grad()
74
+ def init_prompt(prompt, pipeline):
75
+ uncond_input = pipeline.tokenizer(
76
+ [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
77
+ return_tensors="pt"
78
+ )
79
+ uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
80
+ text_input = pipeline.tokenizer(
81
+ [prompt],
82
+ padding="max_length",
83
+ max_length=pipeline.tokenizer.model_max_length,
84
+ truncation=True,
85
+ return_tensors="pt",
86
+ )
87
+ text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
88
+ context = torch.cat([uncond_embeddings, text_embeddings])
89
+
90
+ return context
91
+
92
+
93
+ def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
94
+ sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
95
+ timestep, next_timestep = min(
96
+ timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
97
+ alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
98
+ alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
99
+ beta_prod_t = 1 - alpha_prod_t
100
+ next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
101
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
102
+ next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
103
+ return next_sample
104
+
105
+
106
+ def get_noise_pred_single(latents, t, context, unet):
107
+ noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
108
+ return noise_pred
109
+
110
+
111
+ @torch.no_grad()
112
+ def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
113
+ context = init_prompt(prompt, pipeline)
114
+ uncond_embeddings, cond_embeddings = context.chunk(2)
115
+ all_latent = [latent]
116
+ latent = latent.clone().detach()
117
+ for i in tqdm(range(num_inv_steps)):
118
+ t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
119
+ noise_pred = get_noise_pred_single(latent.to(torch.float32), t, cond_embeddings.to(torch.float32), pipeline.unet)
120
+ latent = next_step(noise_pred, t, latent, ddim_scheduler)
121
+ all_latent.append(latent)
122
+ return all_latent
123
+
124
+
125
+ @torch.no_grad()
126
+ def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
127
+ ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
128
+ return ddim_latents
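`shifted_noise` above shifts the cumulative alphas of a beta schedule for a larger image resolution and rescales them so the final timestep reaches zero. A hedged usage sketch; the scheduler checkpoint and resolutions are illustrative assumptions.

```py
from diffusers import DDIMScheduler
from canonicalize.util import shifted_noise

scheduler = DDIMScheduler.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="scheduler"  # assumed checkpoint
)
betas = scheduler.betas                               # (num_train_timesteps,)
new_betas = shifted_noise(betas, image_d=1024, noise_d=256)
print(betas[-1].item(), new_betas[-1].item())         # compare terminal betas before/after
```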
configs/canonicalization-infer.yaml ADDED
@@ -0,0 +1,22 @@
1
+ pretrained_model_path: "./ckpt/StdGEN-canonicalize-1024"
2
+
3
+ validation:
4
+ guidance_scale: 5.0
5
+ timestep: 40
6
+ width_input: 640
7
+ height_input: 1024
8
+ use_inv_latent: False
9
+
10
+ use_noise: False
11
+ unet_condition_type: image
12
+
13
+ unet_from_pretrained_kwargs:
14
+ camera_embedding_type: 'e_de_da_sincos'
15
+ projection_class_embeddings_input_dim: 10 # modify
16
+ joint_attention: false # modify
17
+ num_views: 1
18
+ sample_size: 96
19
+ zero_init_conv_in: false
20
+ zero_init_camera_projection: false
21
+ in_channels: 4
22
+ use_safetensors: true
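A hedged sketch of consuming this config with OmegaConf; the exact wiring in the inference scripts may differ.

```py
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/canonicalization-infer.yaml")
print(cfg.pretrained_model_path)         # ./ckpt/StdGEN-canonicalize-1024
print(cfg.validation.guidance_scale)     # 5.0
unet_kwargs = OmegaConf.to_container(cfg.unet_from_pretrained_kwargs)
```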
configs/mesh-slrm-infer.yaml ADDED
@@ -0,0 +1,25 @@
1
+ model_config:
2
+ target: slrm.models.lrm_mesh.MeshSLRM
3
+ params:
4
+ encoder_feat_dim: 768
5
+ encoder_freeze: false
6
+ encoder_model_name: facebook/dino-vitb16
7
+ transformer_dim: 1024
8
+ transformer_layers: 16
9
+ transformer_heads: 16
10
+ triplane_low_res: 32
11
+ triplane_high_res: 64
12
+ triplane_dim: 80
13
+ rendering_samples_per_ray: 128
14
+ grid_res_xy: 100
15
+ grid_res_z: 150
16
+ grid_scale_xy: 1.4
17
+ grid_scale_z: 2.1
18
+ is_ortho: false
19
+ lora_rank: 128
20
+
21
+
22
+ infer_config:
23
+ model_path: ckpt/StdGEN-mesh-slrm.pth
24
+ texture_resolution: 1024
25
+ render_resolution: 512
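This config follows the common `target`/`params` convention. The helper below is an assumed sketch of how such a block is typically instantiated; the repository may ship its own equivalent.

```py
import importlib
from omegaconf import OmegaConf

def instantiate_from_config(config: dict):
    # Resolve e.g. "slrm.models.lrm_mesh.MeshSLRM" into a class and build it from params.
    module_name, class_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**config.get("params", {}))

cfg = OmegaConf.load("configs/mesh-slrm-infer.yaml")
model = instantiate_from_config(OmegaConf.to_container(cfg.model_config))
```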
data/test_list.json ADDED
@@ -0,0 +1,111 @@
1
+ [
2
+ "7/3439809555813808357",
3
+ "2/6732152415572359482",
4
+ "6/6198244732386977066",
5
+ "7/7008911571585236777",
6
+ "8/8155498832525298838",
7
+ "1/2204149645140259881",
8
+ "0/1323933330222715340",
9
+ "7/1098644675621653787",
10
+ "9/6777209416978605329",
11
+ "1/1542224037528704351",
12
+ "0/8703316823295014690",
13
+ "3/5204013134706272913",
14
+ "0/6457137167414843850",
15
+ "2/6617574843151473382",
16
+ "8/7981152186026608038",
17
+ "1/4344590844740564561",
18
+ "2/7649110201056191442",
19
+ "2/1146977392849123402",
20
+ "2/2426517581512337892",
21
+ "7/2824689386300465357",
22
+ "6/2270010410433478366",
23
+ "3/3814323604952041013",
24
+ "9/8728960448674306769",
25
+ "7/1506365063811110387",
26
+ "5/5718924742282692475",
27
+ "1/1633099290949034671",
28
+ "5/8999640709832005845",
29
+ "5/720254657332917065",
30
+ "7/4357384925726277837",
31
+ "3/4227726538279421493",
32
+ "2/4382303856103217892",
33
+ "8/6632593566609006548",
34
+ "7/3749944138508065767",
35
+ "2/878764636138223992",
36
+ "5/8170908340955840135",
37
+ "6/4845695357833755236",
38
+ "1/2743140748471131991",
39
+ "1/5803218296084123071",
40
+ "6/9182882771353803536",
41
+ "5/5872666540206860925",
42
+ "4/9212223181352426964",
43
+ "5/3899312551169605935",
44
+ "0/7695929267562496220",
45
+ "7/3104109662674926717",
46
+ "8/2319063723115019838",
47
+ "6/8112121852475729956",
48
+ "9/5705939742315993109",
49
+ "1/6952166826280123421",
50
+ "0/6830091751476954110",
51
+ "2/8891263394100940152",
52
+ "3/8287958311266406833",
53
+ "9/8934151403263879299",
54
+ "7/730625960893750417",
55
+ "8/2007959965099676308",
56
+ "7/7110997111250638537",
57
+ "1/1910258394089325361",
58
+ "6/7538221091944098366",
59
+ "9/8509393563940760269",
60
+ "3/1981376850787241243",
61
+ "4/821179359686508964",
62
+ "6/2359248447840976906",
63
+ "2/5396219174677320232",
64
+ "7/4683457172478674257",
65
+ "8/1863701953709398218",
66
+ "9/910003033484940229",
67
+ "3/880320695540753593",
68
+ "0/990769530404275120",
69
+ "2/4551500513185396552",
70
+ "5/5015097855418058995",
71
+ "7/4896074338113329997",
72
+ "5/7306978321405535555",
73
+ "9/7776834385265136719",
74
+ "6/6631395994048613416",
75
+ "8/3757051138516476638",
76
+ "3/3283421712821668743",
77
+ "1/8144010044536474571",
78
+ "2/7876180780086370752",
79
+ "6/1647234603582341626",
80
+ "6/1341337037707864016",
81
+ "2/6302505551505574612",
82
+ "0/3465024955374919620",
83
+ "5/7900060151297927765",
84
+ "1/4675194210589373061",
85
+ "0/3282208207844657250",
86
+ "4/3240020585468727994",
87
+ "2/7833064532316643952",
88
+ "6/4790345485250053216",
89
+ "7/2935339105576984837",
90
+ "8/2599602859354916028",
91
+ "2/4769742243183930282",
92
+ "6/604217236327738596",
93
+ "4/5117485835686648194",
94
+ "0/1487097526635566140",
95
+ "4/3484530361677579674",
96
+ "3/8530544536064633943",
97
+ "7/4144922250519743927",
98
+ "9/2413192196654279969",
99
+ "2/1350971297625987822",
100
+ "5/6433334135280042785",
101
+ "7/6692827166906062907",
102
+ "8/4678213844371676838",
103
+ "9/262140445129918559",
104
+ "5/4188635875053572005",
105
+ "9/6950138434143075689",
106
+ "4/6953579337597168824",
107
+ "6/16762222989681526",
108
+ "0/8704380013906593380",
109
+ "0/6734578480501157450",
110
+ "1/8562961060475858791"
111
+ ]
data/train_list.json ADDED
The diff for this file is too large to render. See raw diff
 
infer_api.py ADDED
@@ -0,0 +1,881 @@
1
+ from PIL import Image
2
+ import glob
3
+
4
+ import io
5
+ import argparse
6
+ import inspect
7
+ import os
8
+ import random
9
+ import tempfile
10
+ from typing import Dict, Optional, Tuple
11
+ from omegaconf import OmegaConf
12
+ import numpy as np
13
+
14
+ import torch
15
+
16
+ from diffusers import AutoencoderKL, DDIMScheduler
17
+ from diffusers.utils import check_min_version
18
+ from tqdm.auto import tqdm
19
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection
20
+ from torchvision import transforms
21
+
22
+ from canonicalize.models.unet_mv2d_condition import UNetMV2DConditionModel
23
+ from canonicalize.models.unet_mv2d_ref import UNetMV2DRefModel
24
+ from canonicalize.pipeline_canonicalize import CanonicalizationPipeline
25
+ from einops import rearrange
26
+ from torchvision.utils import save_image
27
+ import json
28
+ import cv2
29
+
30
+ import onnxruntime as rt
31
+ from huggingface_hub.file_download import hf_hub_download
32
+ from huggingface_hub import list_repo_files
33
+ from rm_anime_bg.cli import get_mask, SCALE
34
+
35
+ import argparse
36
+ import os
37
+ import cv2
38
+ import glob
39
+ import numpy as np
40
+ import matplotlib.pyplot as plt
41
+ from typing import Dict, Optional, List
42
+ from omegaconf import OmegaConf, DictConfig
43
+ from PIL import Image
44
+ from pathlib import Path
45
+ from dataclasses import dataclass
46
+ from typing import Dict
47
+ import torch
48
+ import torch.nn.functional as F
49
+ import torch.utils.checkpoint
50
+ import torchvision.transforms.functional as TF
51
+ from torch.utils.data import Dataset, DataLoader
52
+ from torchvision import transforms
53
+ from torchvision.utils import make_grid, save_image
54
+ from accelerate.utils import set_seed
55
+ from tqdm.auto import tqdm
56
+ from einops import rearrange, repeat
57
+ from multiview.pipeline_multiclass import StableUnCLIPImg2ImgPipeline
58
+
59
+ import os
60
+ import imageio
61
+ import numpy as np
62
+ import torch
63
+ import cv2
64
+ import glob
65
+ import matplotlib.pyplot as plt
66
+ from PIL import Image
67
+ from torchvision.transforms import v2
68
+ from pytorch_lightning import seed_everything
69
+ from omegaconf import OmegaConf
70
+ from tqdm import tqdm
71
+
72
+ from slrm.utils.train_util import instantiate_from_config
73
+ from slrm.utils.camera_util import (
74
+ FOV_to_intrinsics,
75
+ get_circular_camera_poses,
76
+ )
77
+ from slrm.utils.mesh_util import save_obj, save_glb
78
+ from slrm.utils.infer_util import images_to_video
79
+
80
+ import cv2
81
+ import numpy as np
82
+ import os
83
+ import trimesh
84
+ import argparse
85
+ import torch
86
+ import scipy
87
+ from PIL import Image
88
+
89
+ from refine.mesh_refine import geo_refine
90
+ from refine.func import make_star_cameras_orthographic
91
+ from refine.render import NormalsRenderer, calc_vertex_normals
92
+
93
+ import pytorch3d
94
+ from pytorch3d.structures import Meshes
95
+ from sklearn.neighbors import KDTree
96
+
97
+ from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
98
+
99
+ check_min_version("0.24.0")
100
+ weight_dtype = torch.float16
101
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
102
+ VIEWS = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
103
+
104
+
105
+ def set_seed(seed):
106
+ random.seed(seed)
107
+ np.random.seed(seed)
108
+ torch.manual_seed(seed)
109
+ torch.cuda.manual_seed_all(seed)
110
+
111
+ class BkgRemover:
112
+ def __init__(self, force_cpu: Optional[bool] = True):
113
+ session_infer_path = hf_hub_download(
114
+ repo_id="skytnt/anime-seg", filename="isnetis.onnx",
115
+ )
116
+ providers: list[str] = ["CPUExecutionProvider"]
117
+ if not force_cpu and "CUDAExecutionProvider" in rt.get_available_providers():
118
+ providers = ["CUDAExecutionProvider"]
119
+
120
+ self.session_infer = rt.InferenceSession(
121
+ session_infer_path, providers=providers,
122
+ )
123
+
124
+ def remove_background(
125
+ self,
126
+ img: np.ndarray,
127
+ alpha_min: float,
128
+ alpha_max: float,
129
+ ) -> list:
130
+ img = np.array(img)
131
+ mask = get_mask(self.session_infer, img)
132
+ mask[mask < alpha_min] = 0.0
133
+ mask[mask > alpha_max] = 1.0
134
+ img_after = (mask * img).astype(np.uint8)
135
+ mask = (mask * SCALE).astype(np.uint8)
136
+ img_after = np.concatenate([img_after, mask], axis=2, dtype=np.uint8)
137
+ return Image.fromarray(img_after)
138
+
139
+
140
+ def process_image(image, totensor, width, height):
141
+ assert image.mode == "RGBA"
142
+
143
+ # Find non-transparent pixels
144
+ non_transparent = np.nonzero(np.array(image)[..., 3])
145
+ min_x, max_x = non_transparent[1].min(), non_transparent[1].max()
146
+ min_y, max_y = non_transparent[0].min(), non_transparent[0].max()
147
+ image = image.crop((min_x, min_y, max_x, max_y))
148
+
149
+ # paste to center
150
+ max_dim = max(image.width, image.height)
151
+ max_height = int(max_dim * 1.2)
152
+ max_width = int(max_dim / (height/width) * 1.2)
153
+ new_image = Image.new("RGBA", (max_width, max_height))
154
+ left = (max_width - image.width) // 2
155
+ top = (max_height - image.height) // 2
156
+ new_image.paste(image, (left, top))
157
+
158
+ image = new_image.resize((width, height), resample=Image.BICUBIC)
159
+ image = np.array(image)
160
+ image = image.astype(np.float32) / 255.
161
+ assert image.shape[-1] == 4 # RGBA
162
+ alpha = image[..., 3:4]
163
+ bg_color = np.array([1., 1., 1.], dtype=np.float32)
164
+ image = image[..., :3] * alpha + bg_color * (1 - alpha)
165
+ return totensor(image)
166
+
167
+
168
+ @torch.no_grad()
169
+ def inference(validation_pipeline, bkg_remover, input_image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer,
170
+ text_encoder, pretrained_model_path, generator, validation, val_width, val_height, unet_condition_type,
171
+ use_noise=True, noise_d=256, crop=False, seed=100, timestep=20):
172
+ set_seed(seed)
173
+
174
+ totensor = transforms.ToTensor()
175
+
176
+ prompts = "high quality, best quality"
177
+ prompt_ids = tokenizer(
178
+ prompts, max_length=tokenizer.model_max_length, padding="max_length", truncation=True,
179
+ return_tensors="pt"
180
+ ).input_ids[0]
181
+
182
+ # (B*Nv, 3, H, W)
183
+ B = 1
184
+ if input_image.mode != "RGBA":
185
+ # remove background
186
+ input_image = bkg_remover.remove_background(input_image, 0.1, 0.9)
187
+ imgs_in = process_image(input_image, totensor, val_width, val_height)
188
+ imgs_in = rearrange(imgs_in.unsqueeze(0).unsqueeze(0), "B Nv C H W -> (B Nv) C H W")
189
+
190
+ with torch.autocast('cuda' if torch.cuda.is_available() else 'cpu', dtype=weight_dtype):
191
+ imgs_in = imgs_in.to(device=device)
192
+ # B*Nv images
193
+ out = validation_pipeline(prompt=prompts, image=imgs_in.to(weight_dtype), generator=generator,
194
+ num_inference_steps=timestep, prompt_ids=prompt_ids,
195
+ height=val_height, width=val_width, unet_condition_type=unet_condition_type,
196
+ use_noise=use_noise, **validation,)
197
+ out = rearrange(out, "B C f H W -> (B f) C H W", f=1)
198
+
199
+ img_buf = io.BytesIO()
200
+ save_image(out[0], img_buf, format='PNG')
201
+ img_buf.seek(0)
202
+ img = Image.open(img_buf)
203
+
204
+ torch.cuda.empty_cache()
205
+ return img
206
+
207
+
208
+ ######### Multi View Part #############
209
+ weight_dtype = torch.float16
210
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
211
+
212
+ def tensor_to_numpy(tensor):
213
+ return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
214
+
215
+
216
+ @dataclass
217
+ class TestConfig:
218
+ pretrained_model_name_or_path: str
219
+ pretrained_unet_path:Optional[str]
220
+ revision: Optional[str]
221
+ validation_dataset: Dict
222
+ save_dir: str
223
+ seed: Optional[int]
224
+ validation_batch_size: int
225
+ dataloader_num_workers: int
226
+ save_mode: str
227
+ local_rank: int
228
+
229
+ pipe_kwargs: Dict
230
+ pipe_validation_kwargs: Dict
231
+ unet_from_pretrained_kwargs: Dict
232
+ validation_grid_nrow: int
233
+ camera_embedding_lr_mult: float
234
+
235
+ num_views: int
236
+ camera_embedding_type: str
237
+
238
+ pred_type: str
239
+ regress_elevation: bool
240
+ enable_xformers_memory_efficient_attention: bool
241
+
242
+ cond_on_normals: bool
243
+ cond_on_colors: bool
244
+
245
+ regress_elevation: bool
246
+ regress_focal_length: bool
247
+
248
+
249
+
250
+ def convert_to_numpy(tensor):
251
+ return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
252
+
253
+ def save_image(tensor):
254
+ ndarr = convert_to_numpy(tensor)
255
+ return save_image_numpy(ndarr)
256
+
257
+ def save_image_numpy(ndarr):
258
+ im = Image.fromarray(ndarr)
259
+ # pad to square
260
+ if im.size[0] != im.size[1]:
261
+ size = max(im.size)
262
+ new_im = Image.new("RGB", (size, size))
263
+ # set to white
264
+ new_im.paste((255, 255, 255), (0, 0, size, size))
265
+ new_im.paste(im, ((size - im.size[0]) // 2, (size - im.size[1]) // 2))
266
+ im = new_im
267
+ # resize to 1024x1024
268
+ im = im.resize((1024, 1024), Image.LANCZOS)
269
+ return im
270
+
271
+ def run_multiview_infer(data, pipeline, cfg: TestConfig, num_levels=3):
272
+ if cfg.seed is None:
273
+ generator = None
274
+ else:
275
+ generator = torch.Generator(device=pipeline.unet.device).manual_seed(cfg.seed)
276
+
277
+ images_cond = []
278
+ results = {}
279
+
280
+ torch.cuda.empty_cache()
281
+ images_cond.append(data['image_cond_rgb'][:, 0].cuda())
282
+ imgs_in = torch.cat([data['image_cond_rgb']]*2, dim=0).cuda()
283
+ num_views = imgs_in.shape[1]
284
+ imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")# (B*Nv, 3, H, W)
285
+
286
+ target_h, target_w = imgs_in.shape[-2], imgs_in.shape[-1]
287
+
288
+ normal_prompt_embeddings, clr_prompt_embeddings = data['normal_prompt_embeddings'].cuda(), data['color_prompt_embeddings'].cuda()
289
+ prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0)
290
+ prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C")
291
+
292
+ # B*Nv images
293
+ unet_out = pipeline(
294
+ imgs_in, None, prompt_embeds=prompt_embeddings,
295
+ generator=generator, guidance_scale=3.0, output_type='pt', num_images_per_prompt=1,
296
+ height=cfg.height, width=cfg.width,
297
+ num_inference_steps=40, eta=1.0,
298
+ num_levels=num_levels,
299
+ )
300
+
301
+ for level in range(num_levels):
302
+ out = unet_out[level].images
303
+ bsz = out.shape[0] // 2
304
+
305
+ normals_pred = out[:bsz]
306
+ images_pred = out[bsz:]
307
+
308
+ if num_levels == 2:
309
+ results[level+1] = {'normals': [], 'images': []}
310
+ else:
311
+ results[level] = {'normals': [], 'images': []}
312
+
313
+ for i in range(bsz//num_views):
314
+ img_in_ = images_cond[-1][i].to(out.device)
315
+ for j in range(num_views):
316
+ view = VIEWS[j]
317
+ idx = i*num_views + j
318
+ normal = normals_pred[idx]
319
+ color = images_pred[idx]
320
+
321
+ ## save color and normal---------------------
322
+ new_normal = save_image(normal)
323
+ new_color = save_image(color)
324
+
325
+ if num_levels == 2:
326
+ results[level+1]['normals'].append(new_normal)
327
+ results[level+1]['images'].append(new_color)
328
+ else:
329
+ results[level]['normals'].append(new_normal)
330
+ results[level]['images'].append(new_color)
331
+
332
+ torch.cuda.empty_cache()
333
+ return results
334
+
335
+
336
+ def load_multiview_pipeline(cfg):
337
+ pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
338
+ cfg.pretrained_path,
339
+ torch_dtype=torch.float16,)
340
+ pipeline.unet.enable_xformers_memory_efficient_attention()
341
+ if torch.cuda.is_available():
342
+ pipeline.to(device)
343
+ return pipeline
344
+
345
+
346
+ class InferAPI:
347
+ def __init__(self,
348
+ canonical_configs,
349
+ multiview_configs,
350
+ slrm_configs,
351
+ refine_configs):
352
+ self.canonical_configs = canonical_configs
353
+ self.multiview_configs = multiview_configs
354
+ self.slrm_configs = slrm_configs
355
+ self.refine_configs = refine_configs
356
+
357
+ repo_id = "hyz317/StdGEN"
358
+ all_files = list_repo_files(repo_id, revision="main")
359
+ for file in all_files:
360
+ if os.path.exists(file):
361
+ continue
362
+ hf_hub_download(repo_id, file, local_dir="./ckpt")
363
+
364
+ self.canonical_infer = InferCanonicalAPI(self.canonical_configs)
365
+ self.multiview_infer = InferMultiviewAPI(self.multiview_configs)
366
+ self.slrm_infer = InferSlrmAPI(self.slrm_configs)
367
+ self.refine_infer = InferRefineAPI(self.refine_configs)
368
+
369
+ def genStage1(self, img, seed):
370
+ return self.canonical_infer.gen(img, seed)
371
+
372
+ def genStage2(self, img, seed, num_levels):
373
+ return self.multiview_infer.gen(img, seed, num_levels)
374
+
375
+ def genStage3(self, img):
376
+ return self.slrm_infer.gen(img)
377
+
378
+ def genStage4(self, meshes, imgs):
379
+ return self.refine_infer.refine(meshes, imgs)
380
+
381
+
382
+ ############## Refine ##############
383
+ def fix_vert_color_glb(mesh_path):
384
+ from pygltflib import GLTF2, Material, PbrMetallicRoughness
385
+ obj1 = GLTF2().load(mesh_path)
386
+ obj1.meshes[0].primitives[0].material = 0
387
+ obj1.materials.append(Material(
388
+ pbrMetallicRoughness = PbrMetallicRoughness(
389
+ baseColorFactor = [1.0, 1.0, 1.0, 1.0],
390
+ metallicFactor = 0.,
391
+ roughnessFactor = 1.0,
392
+ ),
393
+ emissiveFactor = [0.0, 0.0, 0.0],
394
+ doubleSided = True,
395
+ ))
396
+ obj1.save(mesh_path)
397
+
398
+
399
+ def srgb_to_linear(c_srgb):
400
+ c_linear = np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4)
401
+ return c_linear.clip(0, 1.)
402
+
403
+
404
+ def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True):
405
+ # convert from pytorch3d meshes to trimesh mesh
406
+ vertices = meshes.verts_packed().cpu().float().numpy()
407
+ triangles = meshes.faces_packed().cpu().long().numpy()
408
+ np_color = meshes.textures.verts_features_packed().cpu().float().numpy()
409
+ if save_glb_path.endswith(".glb"):
410
+ # rotate 180 along +Y
411
+ vertices[:, [0, 2]] = -vertices[:, [0, 2]]
412
+
413
+ if apply_sRGB_to_LinearRGB:
414
+ np_color = srgb_to_linear(np_color)
415
+ assert vertices.shape[0] == np_color.shape[0]
416
+ assert np_color.shape[1] == 3
417
+ assert 0 <= np_color.min() and np_color.max() <= 1.001, f"min={np_color.min()}, max={np_color.max()}"
418
+ np_color = np.clip(np_color, 0, 1)
419
+ mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color)
420
+ mesh.remove_unreferenced_vertices()
421
+ # save mesh
422
+ mesh.export(save_glb_path)
423
+ if save_glb_path.endswith(".glb"):
424
+ fix_vert_color_glb(save_glb_path)
425
+ print(f"saving to {save_glb_path}")
426
+
427
+
428
+ def calc_horizontal_offset(target_img, source_img):
429
+ target_mask = target_img.astype(np.float32).sum(axis=-1) > 750
430
+ source_mask = source_img.astype(np.float32).sum(axis=-1) > 750
431
+ best_offset = -114514
432
+ for offset in range(-200, 200):
433
+ offset_mask = np.roll(source_mask, offset, axis=1)
434
+ overlap = (target_mask & offset_mask).sum()
435
+ if overlap > best_offset:
436
+ best_offset = overlap
437
+ best_offset_value = offset
438
+ return best_offset_value
439
+
440
+
441
+ def calc_horizontal_offset2(target_mask, source_img):
442
+ source_mask = source_img.astype(np.float32).sum(axis=-1) > 750
443
+ best_offset = -114514
444
+ for offset in range(-200, 200):
445
+ offset_mask = np.roll(source_mask, offset, axis=1)
446
+ overlap = (target_mask & offset_mask).sum()
447
+ if overlap > best_offset:
448
+ best_offset = overlap
449
+ best_offset_value = offset
450
+ return best_offset_value
451
+
452
+
453
+ def get_distract_mask(generator, color_0, color_1, normal_0=None, normal_1=None, thres=0.25, ratio=0.50, outside_thres=0.10, outside_ratio=0.20):
454
+ distract_area = np.abs(color_0 - color_1).sum(axis=-1) > thres
455
+ if normal_0 is not None and normal_1 is not None:
456
+ distract_area |= np.abs(normal_0 - normal_1).sum(axis=-1) > thres
457
+ labeled_array, num_features = scipy.ndimage.label(distract_area)
458
+ results = []
459
+
460
+ random_sampled_points = []
461
+
462
+ for i in range(num_features + 1):
463
+ if np.sum(labeled_array == i) > 1000 and np.sum(labeled_array == i) < 100000:
464
+ results.append((i, np.sum(labeled_array == i)))
465
+ # random sample a point in the area
466
+ points = np.argwhere(labeled_array == i)
467
+ random_sampled_points.append(points[np.random.randint(0, points.shape[0])])
468
+
469
+ results = sorted(results, key=lambda x: x[1], reverse=True) # [1:]
470
+ distract_mask = np.zeros_like(distract_area)
471
+ distract_bbox = np.zeros_like(distract_area)
472
+ for i, _ in results:
473
+ distract_mask |= labeled_array == i
474
+ bbox = np.argwhere(labeled_array == i)
475
+ min_x, min_y = bbox.min(axis=0)
476
+ max_x, max_y = bbox.max(axis=0)
477
+ distract_bbox[min_x:max_x, min_y:max_y] = 1
478
+
479
+ points = np.array(random_sampled_points)[:, ::-1]
480
+ labels = np.ones(len(points), dtype=np.int32)
481
+
482
+ masks = generator.generate((color_1 * 255).astype(np.uint8))
483
+
484
+ outside_area = np.abs(color_0 - color_1).sum(axis=-1) < outside_thres
485
+
486
+ final_mask = np.zeros_like(distract_mask)
487
+ for iii, mask in enumerate(masks):
488
+ mask['segmentation'] = cv2.resize(mask['segmentation'].astype(np.float32), (1024, 1024)) > 0.5
489
+ intersection = np.logical_and(mask['segmentation'], distract_mask).sum()
490
+ total = mask['segmentation'].sum()
491
+ iou = intersection / total
492
+ outside_intersection = np.logical_and(mask['segmentation'], outside_area).sum()
493
+ outside_total = mask['segmentation'].sum()
494
+ outside_iou = outside_intersection / outside_total
495
+ if iou > ratio and outside_iou < outside_ratio:
496
+ final_mask |= mask['segmentation']
497
+
498
+ # calculate coverage
499
+ intersection = np.logical_and(final_mask, distract_mask).sum()
500
+ total = distract_mask.sum()
501
+ coverage = intersection / total
502
+
503
+ if coverage < 0.8:
504
+ # use original distract mask
505
+ final_mask = (distract_mask.copy() * 255).astype(np.uint8)
506
+ final_mask = cv2.dilate(final_mask, np.ones((3, 3), np.uint8), iterations=3)
507
+ labeled_array_dilate, num_features_dilate = scipy.ndimage.label(final_mask)
508
+ for i in range(num_features_dilate + 1):
509
+ if np.sum(labeled_array_dilate == i) < 200:
510
+ final_mask[labeled_array_dilate == i] = 255
511
+
512
+ final_mask = cv2.erode(final_mask, np.ones((3, 3), np.uint8), iterations=3)
513
+ final_mask = final_mask > 127
514
+
515
+ return distract_mask, distract_bbox, random_sampled_points, final_mask
516
+
517
+
518
+ class InferRefineAPI:
519
+ def __init__(self, config):
520
+ self.sam = sam_model_registry["vit_h"](checkpoint="./ckpt/sam_vit_h_4b8939.pth").cuda()
521
+ self.generator = SamAutomaticMaskGenerator(
522
+ model=self.sam,
523
+ points_per_side=64,
524
+ pred_iou_thresh=0.80,
525
+ stability_score_thresh=0.92,
526
+ crop_n_layers=1,
527
+ crop_n_points_downscale_factor=2,
528
+ min_mask_region_area=100,
529
+ )
530
+ self.outside_ratio = 0.20
531
+
532
+ def refine(self, meshes, imgs):
533
+ fixed_v, fixed_f, fixed_t = None, None, None
534
+ flow_vert, flow_vector = None, None
535
+ last_colors, last_normals = None, None
536
+ last_front_color, last_front_normal = None, None
537
+ distract_mask = None
538
+
539
+ mv, proj = make_star_cameras_orthographic(8, 1, r=1.2)
540
+ mv = mv[[4, 3, 2, 0, 6, 5]]
541
+ renderer = NormalsRenderer(mv,proj,(1024,1024))
542
+
543
+ results = []
544
+
545
+ for name_idx, level in zip([2, 0, 1], [2, 1, 0]):
546
+ mesh = trimesh.load(meshes[name_idx])
547
+ new_mesh = mesh.split(only_watertight=False)
548
+ new_mesh = [ j for j in new_mesh if len(j.vertices) >= 300 ]
549
+ mesh = trimesh.Scene(new_mesh).dump(concatenate=True)
550
+ mesh_v, mesh_f = mesh.vertices, mesh.faces
551
+
552
+ if last_colors is None:
553
+ images = renderer.render(
554
+ torch.tensor(mesh_v, device='cuda').float(),
555
+ torch.ones_like(torch.from_numpy(mesh_v), device='cuda').float(),
556
+ torch.tensor(mesh_f, device='cuda'),
557
+ )
558
+ mask = (images[..., 3] < 0.9).cpu().numpy()
559
+
560
+ colors, normals = [], []
561
+ for i in range(6):
562
+ color = np.array(imgs[level]['images'][i])
563
+ normal = np.array(imgs[level]['normals'][i])
564
+
565
+ if last_colors is not None:
566
+ offset = calc_horizontal_offset(np.array(last_colors[i]), color)
567
+ # print('offset', i, offset)
568
+ else:
569
+ offset = calc_horizontal_offset2(mask[i], color)
570
+ # print('init offset', i, offset)
571
+
572
+ if offset != 0:
573
+ color = np.roll(color, offset, axis=1)
574
+ normal = np.roll(normal, offset, axis=1)
575
+
576
+ color = Image.fromarray(color)
577
+ normal = Image.fromarray(normal)
578
+ colors.append(color)
579
+ normals.append(normal)
580
+
581
+ if last_front_color is not None and level == 0:
582
+ original_mask, distract_bbox, _, distract_mask = get_distract_mask(self.generator, last_front_color, np.array(colors[0]).astype(np.float32) / 255.0, outside_ratio=self.outside_ratio)
583
+ else:
584
+ distract_mask = None
585
+ distract_bbox = None
586
+
587
+ last_front_color = np.array(colors[0]).astype(np.float32) / 255.0
588
+ last_front_normal = np.array(normals[0]).astype(np.float32) / 255.0
589
+
590
+ if last_colors is None:
591
+ from copy import deepcopy
592
+ last_colors, last_normals = deepcopy(colors), deepcopy(normals)
593
+
594
+ # my mesh flow weight by nearest vertexs
595
+ if fixed_v is not None and fixed_f is not None and level == 1:
596
+ t = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f)
597
+
598
+ fixed_v_cpu = fixed_v.cpu().numpy()
599
+ kdtree_anchor = KDTree(fixed_v_cpu)
600
+ kdtree_mesh_v = KDTree(mesh_v)
601
+ _, idx_anchor = kdtree_anchor.query(mesh_v, k=1)
602
+ _, idx_mesh_v = kdtree_mesh_v.query(mesh_v, k=25)
603
+ idx_anchor = idx_anchor.squeeze()
604
+ neighbors = torch.tensor(mesh_v).cuda()[idx_mesh_v] # V, 25, 3
605
+ # calculate the distances neighbors [V, 25, 3]; mesh_v [V, 3] -> [V, 25]
606
+ neighbor_dists = torch.norm(neighbors - torch.tensor(mesh_v).cuda()[:, None], dim=-1)
607
+ neighbor_dists[neighbor_dists > 0.06] = 114514.
608
+ neighbor_weights = torch.exp(-neighbor_dists * 1.)
609
+ neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True)
610
+ anchors = fixed_v[idx_anchor] # V, 3
611
+ anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3
612
+ dis_anchor = torch.clamp(((anchors - torch.tensor(mesh_v).cuda()) * anchor_normals).sum(-1), min=0) + 0.01
613
+ vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3
614
+ vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3
615
+ weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3
616
+ mesh_v += weighted_vec_anchor.cpu().numpy()
617
+
618
+ t = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f)
619
+
620
+ mesh_v = torch.tensor(mesh_v, device='cuda', dtype=torch.float32)
621
+ mesh_f = torch.tensor(mesh_f, device='cuda')
622
+
623
+ new_mesh, simp_v, simp_f = geo_refine(mesh_v, mesh_f, colors, normals, fixed_v=fixed_v, fixed_f=fixed_f, distract_mask=distract_mask, distract_bbox=distract_bbox)
624
+
625
+ # my mesh flow weight by nearest vertexs
626
+ try:
627
+ if fixed_v is not None and fixed_f is not None and level != 0:
628
+ new_mesh_v = new_mesh.verts_packed().cpu().numpy()
629
+
630
+ fixed_v_cpu = fixed_v.cpu().numpy()
631
+ kdtree_anchor = KDTree(fixed_v_cpu)
632
+ kdtree_mesh_v = KDTree(new_mesh_v)
633
+ _, idx_anchor = kdtree_anchor.query(new_mesh_v, k=1)
634
+ _, idx_mesh_v = kdtree_mesh_v.query(new_mesh_v, k=25)
635
+ idx_anchor = idx_anchor.squeeze()
636
+ neighbors = torch.tensor(new_mesh_v).cuda()[idx_mesh_v] # V, 25, 3
637
+ # calculate the distances neighbors [V, 25, 3]; new_mesh_v [V, 3] -> [V, 25]
638
+ neighbor_dists = torch.norm(neighbors - torch.tensor(new_mesh_v).cuda()[:, None], dim=-1)
639
+ neighbor_dists[neighbor_dists > 0.06] = 114514.
640
+ neighbor_weights = torch.exp(-neighbor_dists * 1.)
641
+ neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True)
642
+ anchors = fixed_v[idx_anchor] # V, 3
643
+ anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3
644
+ dis_anchor = torch.clamp(((anchors - torch.tensor(new_mesh_v).cuda()) * anchor_normals).sum(-1), min=0) + 0.01
645
+ vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3
646
+ vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3
647
+ weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3
648
+ new_mesh_v += weighted_vec_anchor.cpu().numpy()
649
+
650
+ # replace new_mesh verts with new_mesh_v
651
+ new_mesh = Meshes(verts=[torch.tensor(new_mesh_v, device='cuda')], faces=new_mesh.faces_list(), textures=new_mesh.textures)
652
+
653
+ except Exception as e:
654
+ pass
655
+
656
+ notsimp_v, notsimp_f, notsimp_t = new_mesh.verts_packed(), new_mesh.faces_packed(), new_mesh.textures.verts_features_packed()
657
+
658
+ if fixed_v is None:
659
+ fixed_v, fixed_f = simp_v, simp_f
660
+ complete_v, complete_f, complete_t = notsimp_v, notsimp_f, notsimp_t
661
+ else:
662
+ fixed_f = torch.cat([fixed_f, simp_f + fixed_v.shape[0]], dim=0)
663
+ fixed_v = torch.cat([fixed_v, simp_v], dim=0)
664
+
665
+ complete_f = torch.cat([complete_f, notsimp_f + complete_v.shape[0]], dim=0)
666
+ complete_v = torch.cat([complete_v, notsimp_v], dim=0)
667
+ complete_t = torch.cat([complete_t, notsimp_t], dim=0)
668
+
669
+ if level == 2:
670
+ new_mesh = Meshes(verts=[new_mesh.verts_packed()], faces=[new_mesh.faces_packed()], textures=pytorch3d.renderer.mesh.textures.TexturesVertex(verts_features=[torch.ones_like(new_mesh.textures.verts_features_packed(), device=new_mesh.verts_packed().device)*0.5]))
671
+
672
+ save_py3dmesh_with_trimesh_fast(new_mesh, meshes[name_idx].replace('.obj', '_refined.obj'), apply_sRGB_to_LinearRGB=False)
673
+ results.append(meshes[name_idx].replace('.obj', '_refined.obj'))
674
+
675
+ # save whole mesh
676
+ save_py3dmesh_with_trimesh_fast(Meshes(verts=[complete_v], faces=[complete_f], textures=pytorch3d.renderer.mesh.textures.TexturesVertex(verts_features=[complete_t])), meshes[name_idx].replace('.obj', '_refined_whole.obj'), apply_sRGB_to_LinearRGB=False)
677
+ results.append(meshes[name_idx].replace('.obj', '_refined_whole.obj'))
678
+
679
+ return results
680
+
681
+
682
+ class InferSlrmAPI:
683
+ def __init__(self, config):
684
+ self.config_path = config['config_path']
685
+ self.config = OmegaConf.load(self.config_path)
686
+ self.config_name = os.path.basename(self.config_path).replace('.yaml', '')
687
+ self.model_config = self.config.model_config
688
+ self.infer_config = self.config.infer_config
689
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
690
+ self.model = instantiate_from_config(self.model_config)
691
+ state_dict = torch.load(self.infer_config.model_path, map_location='cpu')
692
+ self.model.load_state_dict(state_dict, strict=False)
693
+ self.model = self.model.to(self.device)
694
+ self.model.init_flexicubes_geometry(self.device, fovy=30.0, is_ortho=self.model.is_ortho)
695
+ self.model = self.model.eval()
696
+
697
+ def gen(self, imgs):
698
+ imgs = [ cv2.imread(img[0])[:, :, ::-1] for img in imgs ]
699
+ imgs = np.stack(imgs, axis=0).astype(np.float32) / 255.0
700
+ imgs = torch.from_numpy(np.array(imgs)).permute(0, 3, 1, 2).contiguous().float() # (6, 3, 1024, 1024)
701
+ mesh_glb_fpaths = self.make3d(imgs)
702
+ return mesh_glb_fpaths[1:4] + mesh_glb_fpaths[0:1]
703
+
704
+ def make3d(self, images):
705
+ input_cameras = torch.tensor(np.load('slrm/cameras.npy')).to(device)
706
+
707
+ images = images.unsqueeze(0).to(device)
708
+ images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
709
+
710
+ mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
711
+ print(mesh_fpath)
712
+ mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
713
+ mesh_dirname = os.path.dirname(mesh_fpath)
714
+
715
+ with torch.no_grad():
716
+ # get triplane
717
+ planes = self.model.forward_planes(images, input_cameras.float())
718
+
719
+ # get mesh
720
+ mesh_glb_fpaths = []
721
+ for j in range(4):
722
+ mesh_glb_fpath = self.make_mesh(mesh_fpath.replace(mesh_fpath[-4:], f'_{j}{mesh_fpath[-4:]}'), planes, level=[0, 3, 4, 2][j])
723
+ mesh_glb_fpaths.append(mesh_glb_fpath)
724
+
725
+ return mesh_glb_fpaths
726
+
727
+ def make_mesh(self, mesh_fpath, planes, level=None):
728
+ mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
729
+ mesh_dirname = os.path.dirname(mesh_fpath)
730
+ mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
731
+
732
+ with torch.no_grad():
733
+ # get mesh
734
+ mesh_out = self.model.extract_mesh(
735
+ planes,
736
+ use_texture_map=False,
737
+ levels=torch.tensor([level]).to(device),
738
+ **self.infer_config,
739
+ )
740
+
741
+ vertices, faces, vertex_colors = mesh_out
742
+ vertices = vertices[:, [1, 2, 0]]
743
+
744
+ if level == 2:
745
+ # fill all vertex_colors with 127
746
+ vertex_colors = np.ones_like(vertex_colors) * 127
747
+
748
+ save_obj(vertices, faces, vertex_colors, mesh_fpath)
749
+
750
+ return mesh_fpath
751
+
752
+
753
+ class InferMultiviewAPI:
754
+ def __init__(self, config):
755
+ parser = argparse.ArgumentParser()
756
+ parser.add_argument("--seed", type=int, default=42)
757
+ parser.add_argument("--num_views", type=int, default=6)
758
+ parser.add_argument("--num_levels", type=int, default=3)
759
+ parser.add_argument("--pretrained_path", type=str, default='./ckpt/StdGEN-multiview-1024')
760
+ parser.add_argument("--height", type=int, default=1024)
761
+ parser.add_argument("--width", type=int, default=576)
762
+ self.cfg = parser.parse_args()
763
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
764
+ self.pipeline = load_multiview_pipeline(self.cfg)
765
+ self.results = {}
766
+ if torch.cuda.is_available():
767
+ self.pipeline.to(device)
768
+
769
+ self.image_transforms = [transforms.Resize(int(max(self.cfg.height, self.cfg.width))),
770
+ transforms.CenterCrop((self.cfg.height, self.cfg.width)),
771
+ transforms.ToTensor(),
772
+ transforms.Lambda(lambda x: x * 2. - 1),
773
+ ]
774
+ self.image_transforms = transforms.Compose(self.image_transforms)
775
+
776
+ prompt_embeds_path = './multiview/fixed_prompt_embeds_6view'
777
+ self.normal_text_embeds = torch.load(f'{prompt_embeds_path}/normal_embeds.pt')
778
+ self.color_text_embeds = torch.load(f'{prompt_embeds_path}/clr_embeds.pt')
779
+ self.total_views = self.cfg.num_views
780
+
781
+
782
+ def process_im(self, im):
783
+ im = self.image_transforms(im)
784
+ return im
785
+
786
+
787
+ def gen(self, img, seed, num_levels):
788
+ set_seed(seed)
789
+ data = {}
790
+
791
+ cond_im_rgb = self.process_im(img)
792
+ cond_im_rgb = torch.stack([cond_im_rgb] * self.total_views, dim=0)
793
+ data["image_cond_rgb"] = cond_im_rgb[None, ...]
794
+ data["normal_prompt_embeddings"] = self.normal_text_embeds[None, ...]
795
+ data["color_prompt_embeddings"] = self.color_text_embeds[None, ...]
796
+
797
+ results = run_multiview_infer(data, self.pipeline, self.cfg, num_levels=num_levels)
798
+ for k in results:
799
+ self.results[k] = results[k]
800
+ return results
801
+
802
+
803
+ class InferCanonicalAPI:
804
+ def __init__(self, config):
805
+ self.config = config
806
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
807
+
808
+ self.config_path = config['config_path']
809
+ self.loaded_config = OmegaConf.load(self.config_path)
810
+
811
+ self.setup(**self.loaded_config)
812
+
813
+ def setup(self,
814
+ validation: Dict,
815
+ pretrained_model_path: str,
816
+ local_crossattn: bool = True,
817
+ unet_from_pretrained_kwargs=None,
818
+ unet_condition_type=None,
819
+ use_noise=True,
820
+ noise_d=256,
821
+ timestep: int = 40,
822
+ width_input: int = 640,
823
+ height_input: int = 1024,
824
+ ):
825
+ self.width_input = width_input
826
+ self.height_input = height_input
827
+ self.timestep = timestep
828
+ self.use_noise = use_noise
829
+ self.noise_d = noise_d
830
+ self.validation = validation
831
+ self.unet_condition_type = unet_condition_type
832
+ self.pretrained_model_path = pretrained_model_path
833
+
834
+ self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
835
+ self.text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
836
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(pretrained_model_path, subfolder="image_encoder")
837
+ self.feature_extractor = CLIPImageProcessor()
838
+ self.vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
839
+ self.unet = UNetMV2DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs)
840
+ self.ref_unet = UNetMV2DRefModel.from_pretrained_2d(pretrained_model_path, subfolder="ref_unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs)
841
+
842
+ self.text_encoder.to(device, dtype=weight_dtype)
843
+ self.image_encoder.to(device, dtype=weight_dtype)
844
+ self.vae.to(device, dtype=weight_dtype)
845
+ self.ref_unet.to(device, dtype=weight_dtype)
846
+ self.unet.to(device, dtype=weight_dtype)
847
+
848
+ self.vae.requires_grad_(False)
849
+ self.ref_unet.requires_grad_(False)
850
+ self.unet.requires_grad_(False)
851
+
852
+ self.noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler-zerosnr")
853
+ self.validation_pipeline = CanonicalizationPipeline(
854
+ vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet, ref_unet=self.ref_unet,feature_extractor=self.feature_extractor,image_encoder=self.image_encoder,
855
+ scheduler=self.noise_scheduler
856
+ )
857
+ self.validation_pipeline.set_progress_bar_config(disable=True)
858
+
859
+ self.bkg_remover = BkgRemover()
860
+
861
+ def canonicalize(self, image, seed):
862
+ generator = torch.Generator(device=device).manual_seed(seed)
863
+ return inference(
864
+ self.validation_pipeline, self.bkg_remover, image, self.vae, self.feature_extractor, self.image_encoder, self.unet, self.ref_unet, self.tokenizer, self.text_encoder,
865
+ self.pretrained_model_path, generator, self.validation, self.width_input, self.height_input, self.unet_condition_type,
866
+ use_noise=self.use_noise, noise_d=self.noise_d, crop=True, seed=seed, timestep=self.timestep
867
+ )
868
+
869
+ def gen(self, img_input, seed=0):
870
+ if np.array(img_input).shape[-1] == 4 and np.array(img_input)[..., 3].min() == 255:
871
+ # convert to RGB
872
+ img_input = img_input.convert("RGB")
873
+ img_output = self.canonicalize(img_input, seed)
874
+
875
+ max_dim = max(img_output.width, img_output.height)
876
+ new_image = Image.new("RGBA", (max_dim, max_dim))
877
+ left = (max_dim - img_output.width) // 2
878
+ top = (max_dim - img_output.height) // 2
879
+ new_image.paste(img_output, (left, top))
880
+
881
+ return new_image
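Since the stages above are only wired together elsewhere, a minimal driver sketch may help; the `config_path` keys are the ones `InferCanonicalAPI` and `InferSlrmAPI` actually read, while the input image path, the seed, and stopping after stage 2 are illustrative assumptions rather than part of this file.

```python
# Minimal sketch, assuming the configs/ YAMLs exist and checkpoints download into ./ckpt.
from PIL import Image
from infer_api import InferAPI

api = InferAPI(
    {'config_path': './configs/canonicalization-infer.yaml'},  # read by InferCanonicalAPI
    {},                                                        # InferMultiviewAPI uses its argparse defaults
    {'config_path': './configs/mesh-slrm-infer.yaml'},         # read by InferSlrmAPI
    {},                                                        # InferRefineAPI ignores its config dict
)

ref = Image.open('./input_cases/1.png')        # any reference image; non-RGBA inputs get auto-matted
apose = api.genStage1(ref, seed=42)            # canonical A-pose image on a square RGBA canvas
views = api.genStage2(apose, seed=42, num_levels=3)
# views[level]['images'] and views[level]['normals'] each hold six 1024x1024 PIL views per level,
# ordered front, front_right, right, back, left, front_left (see VIEWS above).
# Note: InferMultiviewAPI parses sys.argv internally, so run this driver without extra CLI flags.
```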
infer_canonicalize.py ADDED
@@ -0,0 +1,215 @@
1
+ from PIL import Image
2
+ import glob
3
+
4
+ import io
5
+ import argparse
6
+ import inspect
7
+ import os
8
+ import random
9
+ from typing import Dict, Optional, Tuple
10
+ from omegaconf import OmegaConf
11
+ import numpy as np
12
+
13
+ import torch
14
+
15
+ from diffusers import AutoencoderKL, DDIMScheduler
16
+ from diffusers.utils import check_min_version
17
+ from tqdm.auto import tqdm
18
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection
19
+ from torchvision import transforms
20
+
21
+ from canonicalize.models.unet_mv2d_condition import UNetMV2DConditionModel
22
+ from canonicalize.models.unet_mv2d_ref import UNetMV2DRefModel
23
+ from canonicalize.pipeline_canonicalize import CanonicalizationPipeline
24
+ from einops import rearrange
25
+ from torchvision.utils import save_image
26
+ import json
27
+ import cv2
28
+
29
+ import onnxruntime as rt
30
+ from huggingface_hub.file_download import hf_hub_download
31
+ from rm_anime_bg.cli import get_mask, SCALE
32
+
33
+ check_min_version("0.24.0")
34
+ weight_dtype = torch.float16
35
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
36
+
37
+
38
+ class BkgRemover:
39
+ def __init__(self, force_cpu: Optional[bool] = True):
40
+ session_infer_path = hf_hub_download(
41
+ repo_id="skytnt/anime-seg", filename="isnetis.onnx",
42
+ )
43
+ providers: list[str] = ["CPUExecutionProvider"]
44
+ if not force_cpu and "CUDAExecutionProvider" in rt.get_available_providers():
45
+ providers = ["CUDAExecutionProvider"]
46
+
47
+ self.session_infer = rt.InferenceSession(
48
+ session_infer_path, providers=providers,
49
+ )
50
+
51
+ def remove_background(
52
+ self,
53
+ img: np.ndarray,
54
+ alpha_min: float,
55
+ alpha_max: float,
56
+ ) -> list:
57
+ img = np.array(img)
58
+ mask = get_mask(self.session_infer, img)
59
+ mask[mask < alpha_min] = 0.0
60
+ mask[mask > alpha_max] = 1.0
61
+ img_after = (mask * img).astype(np.uint8)
62
+ mask = (mask * SCALE).astype(np.uint8)
63
+ img_after = np.concatenate([img_after, mask], axis=2, dtype=np.uint8)
64
+ return Image.fromarray(img_after)
65
+
66
+
67
+ def set_seed(seed):
68
+ random.seed(seed)
69
+ np.random.seed(seed)
70
+ torch.manual_seed(seed)
71
+ torch.cuda.manual_seed_all(seed)
72
+
73
+
74
+ def process_image(image, totensor, width, height):
75
+ assert image.mode == "RGBA"
76
+
77
+ # Find non-transparent pixels
78
+ non_transparent = np.nonzero(np.array(image)[..., 3])
79
+ min_x, max_x = non_transparent[1].min(), non_transparent[1].max()
80
+ min_y, max_y = non_transparent[0].min(), non_transparent[0].max()
81
+ image = image.crop((min_x, min_y, max_x, max_y))
82
+
83
+ # paste to center
84
+ max_dim = max(image.width, image.height)
85
+ max_height = int(max_dim * 1.2)
86
+ max_width = int(max_dim / (height/width) * 1.2)
87
+ new_image = Image.new("RGBA", (max_width, max_height))
88
+ left = (max_width - image.width) // 2
89
+ top = (max_height - image.height) // 2
90
+ new_image.paste(image, (left, top))
91
+
92
+ image = new_image.resize((width, height), resample=Image.BICUBIC)
93
+ image = np.array(image)
94
+ image = image.astype(np.float32) / 255.
95
+ assert image.shape[-1] == 4 # RGBA
96
+ alpha = image[..., 3:4]
97
+ bg_color = np.array([1., 1., 1.], dtype=np.float32)
98
+ image = image[..., :3] * alpha + bg_color * (1 - alpha)
99
+ return totensor(image)
100
+
101
+
102
+ @torch.no_grad()
103
+ def inference(validation_pipeline, bkg_remover, input_image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer,
104
+ text_encoder, pretrained_model_path, generator, validation, val_width, val_height, unet_condition_type,
105
+ use_noise=True, noise_d=256, crop=False, seed=100, timestep=20):
106
+ set_seed(seed)
107
+
108
+ totensor = transforms.ToTensor()
109
+
110
+ prompts = "high quality, best quality"
111
+ prompt_ids = tokenizer(
112
+ prompts, max_length=tokenizer.model_max_length, padding="max_length", truncation=True,
113
+ return_tensors="pt"
114
+ ).input_ids[0]
115
+
116
+ # (B*Nv, 3, H, W)
117
+ B = 1
118
+ if input_image.mode != "RGBA":
119
+ # remove background
120
+ input_image = bkg_remover.remove_background(input_image, 0.1, 0.9)
121
+ imgs_in = process_image(input_image, totensor, val_width, val_height)
122
+ imgs_in = rearrange(imgs_in.unsqueeze(0).unsqueeze(0), "B Nv C H W -> (B Nv) C H W")
123
+
124
+ with torch.autocast('cuda' if torch.cuda.is_available() else 'cpu', dtype=weight_dtype):
125
+ imgs_in = imgs_in.to(device=device)
126
+ # B*Nv images
127
+ out = validation_pipeline(prompt=prompts, image=imgs_in.to(weight_dtype), generator=generator,
128
+ num_inference_steps=timestep, prompt_ids=prompt_ids,
129
+ height=val_height, width=val_width, unet_condition_type=unet_condition_type,
130
+ use_noise=use_noise, **validation,)
131
+ out = rearrange(out, "B C f H W -> (B f) C H W", f=1)
132
+
133
+ img_buf = io.BytesIO()
134
+ save_image(out[0], img_buf, format='PNG')
135
+ img_buf.seek(0)
136
+ img = Image.open(img_buf)
137
+
138
+ torch.cuda.empty_cache()
139
+ return img
140
+
141
+
142
+ @torch.no_grad()
143
+ def main(
144
+ input_dir: str,
145
+ output_dir: str,
146
+ pretrained_model_path: str,
147
+ validation: Dict,
148
+ local_crossattn: bool = True,
149
+ unet_from_pretrained_kwargs=None,
150
+ unet_condition_type=None,
151
+ use_noise=True,
152
+ noise_d=256,
153
+ seed: int = 42,
154
+ timestep: int = 40,
155
+ width_input: int = 640,
156
+ height_input: int = 1024,
157
+ ):
158
+ *_, config = inspect.getargvalues(inspect.currentframe())
159
+
160
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
161
+ text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
162
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(pretrained_model_path, subfolder="image_encoder")
163
+ feature_extractor = CLIPImageProcessor()
164
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
165
+ unet = UNetMV2DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs)
166
+ ref_unet = UNetMV2DRefModel.from_pretrained_2d(pretrained_model_path, subfolder="ref_unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs)
167
+
168
+ text_encoder.to(device, dtype=weight_dtype)
169
+ image_encoder.to(device, dtype=weight_dtype)
170
+ vae.to(device, dtype=weight_dtype)
171
+ ref_unet.to(device, dtype=weight_dtype)
172
+ unet.to(device, dtype=weight_dtype)
173
+
174
+ vae.requires_grad_(False)
175
+ unet.requires_grad_(False)
176
+ ref_unet.requires_grad_(False)
177
+
178
+ # set pipeline
179
+ noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler-zerosnr")
180
+ validation_pipeline = CanonicalizationPipeline(
181
+ vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, ref_unet=ref_unet,feature_extractor=feature_extractor,image_encoder=image_encoder,
182
+ scheduler=noise_scheduler
183
+ )
184
+ validation_pipeline.set_progress_bar_config(disable=True)
185
+
186
+ bkg_remover = BkgRemover()
187
+
188
+ def canonicalize(image, width, height, seed, timestep):
189
+ generator = torch.Generator(device=device).manual_seed(seed)
190
+ return inference(
191
+ validation_pipeline, bkg_remover, image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer, text_encoder,
192
+ pretrained_model_path, generator, validation, width, height, unet_condition_type,
193
+ use_noise=use_noise, noise_d=noise_d, crop=True, seed=seed, timestep=timestep
194
+ )
195
+
196
+ img_paths = sorted(glob.glob(os.path.join(input_dir, "*.png")))
197
+ os.makedirs(output_dir, exist_ok=True)
198
+
199
+ for path in tqdm(img_paths):
200
+ img_input = Image.open(path)
201
+ if np.array(img_input)[..., 3].min() == 255:
202
+ # convert to RGB
203
+ img_input = img_input.convert("RGB")
204
+ img_output = canonicalize(img_input, width_input, height_input, seed, timestep)
205
+ img_output.save(os.path.join(output_dir, f"{os.path.basename(path).split('.')[0]}.png"))
206
+
207
+ if __name__ == "__main__":
208
+ parser = argparse.ArgumentParser()
209
+ parser.add_argument("--config", type=str, default="./configs/canonicalization-infer.yaml")
210
+ parser.add_argument("--input_dir", type=str, default="./input_cases")
211
+ parser.add_argument("--output_dir", type=str, default="./result/apose")
212
+ parser.add_argument("--seed", type=int, default=42)
213
+ args = parser.parse_args()
214
+
215
+ main(**OmegaConf.load(args.config), seed=args.seed, input_dir=args.input_dir, output_dir=args.output_dir)
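For completeness, the batch entry point above can also be invoked from Python; the call below simply mirrors the `__main__` block with its argparse defaults spelled out (the seed value is arbitrary).

```python
# Programmatic equivalent of the CLI entry point above, using the default config and directories.
from omegaconf import OmegaConf
from infer_canonicalize import main

cfg = OmegaConf.load('./configs/canonicalization-infer.yaml')
main(**cfg, seed=42, input_dir='./input_cases', output_dir='./result/apose')
```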
infer_multiview.py ADDED
@@ -0,0 +1,274 @@
1
+ import argparse
2
+ import os
3
+ import cv2
4
+ import glob
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ from typing import Dict, Optional, List
8
+ from omegaconf import OmegaConf, DictConfig
9
+ from PIL import Image
10
+ from pathlib import Path
11
+ from dataclasses import dataclass
12
+ from typing import Dict
13
+ import torch
14
+ import torch.nn.functional as F
15
+ import torch.utils.checkpoint
16
+ import torchvision.transforms.functional as TF
17
+ from torch.utils.data import Dataset, DataLoader
18
+ from torchvision import transforms
19
+ from torchvision.utils import make_grid, save_image
20
+ from accelerate.utils import set_seed
21
+ from tqdm.auto import tqdm
22
+ from einops import rearrange, repeat
23
+ from multiview.pipeline_multiclass import StableUnCLIPImg2ImgPipeline
24
+
25
+ weight_dtype = torch.float16
26
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
27
+
28
+ def tensor_to_numpy(tensor):
29
+ return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
30
+
31
+
32
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
33
+
34
+ def nonzero_normalize_depth(depth, mask=None):
35
+ if mask.max() > 0: # not all transparent
36
+ nonzero_depth_min = depth[mask > 0].min()
37
+ else:
38
+ nonzero_depth_min = 0
39
+ depth = (depth - nonzero_depth_min) / depth.max()
40
+ return np.clip(depth, 0, 1)
41
+
42
+
43
+ class SingleImageData(Dataset):
44
+ def __init__(self,
45
+ input_dir,
46
+ prompt_embeds_path='./multiview/fixed_prompt_embeds_6view',
47
+ image_transforms=[],
48
+ total_views=6,
49
+ ext="png",
50
+ return_paths=True,
51
+ ) -> None:
52
+ """Create a dataset from a folder of images.
53
+ If you pass in a root directory it will be searched for images
54
+ ending in ext (ext can be a list)
55
+ """
56
+ self.input_dir = Path(input_dir)
57
+ self.return_paths = return_paths
58
+ self.total_views = total_views
59
+
60
+ self.paths = glob.glob(str(self.input_dir / f'*.{ext}'))
61
+
62
+ print('============= length of dataset %d =============' % len(self.paths))
63
+ self.tform = image_transforms
64
+ self.normal_text_embeds = torch.load(f'{prompt_embeds_path}/normal_embeds.pt')
65
+ self.color_text_embeds = torch.load(f'{prompt_embeds_path}/clr_embeds.pt')
66
+
67
+
68
+ def __len__(self):
69
+ return len(self.paths)
70
+
71
+
72
+ def load_rgb(self, path, color):
73
+ img = plt.imread(path)
74
+ img = Image.fromarray(np.uint8(img * 255.))
75
+ new_img = Image.new("RGB", (1024, 1024))
76
+ # white background
77
+ width, height = img.size
78
+ new_width = int(width / height * 1024)
79
+ img = img.resize((new_width, 1024))
80
+ new_img.paste((255, 255, 255), (0, 0, 1024, 1024))
81
+ offset = (1024 - new_width) // 2
82
+ new_img.paste(img, (offset, 0))
83
+ return new_img
84
+
85
+ def __getitem__(self, index):
86
+ data = {}
87
+ filename = self.paths[index]
88
+
89
+ if self.return_paths:
90
+ data["path"] = str(filename)
91
+ color = 1.0
92
+ cond_im_rgb = self.process_im(self.load_rgb(filename, color))
93
+ cond_im_rgb = torch.stack([cond_im_rgb] * self.total_views, dim=0)
94
+
95
+ data["image_cond_rgb"] = cond_im_rgb
96
+ data["normal_prompt_embeddings"] = self.normal_text_embeds
97
+ data["color_prompt_embeddings"] = self.color_text_embeds
98
+ data["filename"] = filename.split('/')[-1]
99
+
100
+ return data
101
+
102
+ def process_im(self, im):
103
+ im = im.convert("RGB")
104
+ return self.tform(im)
105
+
106
+ def tensor_to_image(self, tensor):
107
+ return Image.fromarray(np.uint8(tensor.numpy() * 255.))
108
+
109
+
110
+ @dataclass
111
+ class TestConfig:
112
+ pretrained_model_name_or_path: str
113
+ pretrained_unet_path:Optional[str]
114
+ revision: Optional[str]
115
+ validation_dataset: Dict
116
+ save_dir: str
117
+ seed: Optional[int]
118
+ validation_batch_size: int
119
+ dataloader_num_workers: int
120
+ save_mode: str
121
+ local_rank: int
122
+
123
+ pipe_kwargs: Dict
124
+ pipe_validation_kwargs: Dict
125
+ unet_from_pretrained_kwargs: Dict
126
+ validation_grid_nrow: int
127
+ camera_embedding_lr_mult: float
128
+
129
+ num_views: int
130
+ camera_embedding_type: str
131
+
132
+ pred_type: str
133
+ regress_elevation: bool
134
+ enable_xformers_memory_efficient_attention: bool
135
+
136
+ cond_on_normals: bool
137
+ cond_on_colors: bool
138
+
139
+ regress_elevation: bool
140
+ regress_focal_length: bool
141
+
142
+
143
+
144
+ def convert_to_numpy(tensor):
145
+ return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
146
+
147
+ def save_image(tensor, fp):
148
+ ndarr = convert_to_numpy(tensor)
149
+ save_image_numpy(ndarr, fp)
150
+ return ndarr
151
+
152
+ def save_image_numpy(ndarr, fp):
153
+ im = Image.fromarray(ndarr)
154
+ # pad to square
155
+ if im.size[0] != im.size[1]:
156
+ size = max(im.size)
157
+ new_im = Image.new("RGB", (size, size))
158
+ # set to white
159
+ new_im.paste((255, 255, 255), (0, 0, size, size))
160
+ new_im.paste(im, ((size - im.size[0]) // 2, (size - im.size[1]) // 2))
161
+ im = new_im
162
+ # resize to 1024x1024
163
+ im = im.resize((1024, 1024), Image.LANCZOS)
164
+ im.save(fp)
165
+
166
+ def run_multiview_infer(dataloader, pipeline, cfg: TestConfig, save_dir, num_levels=3):
167
+ if cfg.seed is None:
168
+ generator = None
169
+ else:
170
+ generator = torch.Generator(device=pipeline.unet.device).manual_seed(cfg.seed)
171
+
172
+ images_cond = []
173
+ for _, batch in tqdm(enumerate(dataloader)):
174
+ torch.cuda.empty_cache()
175
+ images_cond.append(batch['image_cond_rgb'][:, 0].cuda())
176
+ imgs_in = torch.cat([batch['image_cond_rgb']]*2, dim=0).cuda()
177
+ num_views = imgs_in.shape[1]
178
+ imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")# (B*Nv, 3, H, W)
179
+
180
+ target_h, target_w = imgs_in.shape[-2], imgs_in.shape[-1]
181
+
182
+ normal_prompt_embeddings, clr_prompt_embeddings = batch['normal_prompt_embeddings'].cuda(), batch['color_prompt_embeddings'].cuda()
183
+ prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0)
184
+ prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C")
185
+
186
+ # B*Nv images
187
+ unet_out = pipeline(
188
+ imgs_in, None, prompt_embeds=prompt_embeddings,
189
+ generator=generator, guidance_scale=3.0, output_type='pt', num_images_per_prompt=1,
190
+ height=cfg.height, width=cfg.width,
191
+ num_inference_steps=40, eta=1.0,
192
+ num_levels=num_levels,
193
+ )
194
+
195
+ for level in range(num_levels):
196
+ out = unet_out[level].images
197
+ bsz = out.shape[0] // 2
198
+
199
+ normals_pred = out[:bsz]
200
+ images_pred = out[bsz:]
201
+
202
+ cur_dir = save_dir
203
+ os.makedirs(cur_dir, exist_ok=True)
204
+
205
+ for i in range(bsz//num_views):
206
+ scene = batch['filename'][i].split('.')[0]
207
+ scene_dir = os.path.join(cur_dir, scene, f'level{level}')
208
+ os.makedirs(scene_dir, exist_ok=True)
209
+
210
+ img_in_ = images_cond[-1][i].to(out.device)
211
+ for j in range(num_views):
212
+ view = VIEWS[j]
213
+ idx = i*num_views + j
214
+ normal = normals_pred[idx]
215
+ color = images_pred[idx]
216
+
217
+ ## save color and normal---------------------
218
+ normal_filename = f"normal_{j}.png"
219
+ rgb_filename = f"color_{j}.png"
220
+ save_image(normal, os.path.join(scene_dir, normal_filename))
221
+ save_image(color, os.path.join(scene_dir, rgb_filename))
222
+
223
+ torch.cuda.empty_cache()
224
+
225
+ def load_multiview_pipeline(cfg):
226
+ pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
227
+ cfg.pretrained_path,
228
+ torch_dtype=torch.float16,)
229
+ pipeline.unet.enable_xformers_memory_efficient_attention()
230
+ if torch.cuda.is_available():
231
+ pipeline.to(device)
232
+ return pipeline
233
+
234
+ def main(
235
+ cfg: TestConfig
236
+ ):
237
+ set_seed(cfg.seed)
238
+ pipeline = load_multiview_pipeline(cfg)
239
+ if torch.cuda.is_available():
240
+ pipeline.to(device)
241
+
242
+ image_transforms = [transforms.Resize(int(max(cfg.height, cfg.width))),
243
+ transforms.CenterCrop((cfg.height, cfg.width)),
244
+ transforms.ToTensor(),
245
+ transforms.Lambda(lambda x: x * 2. - 1),
246
+ ]
247
+ image_transforms = transforms.Compose(image_transforms)
248
+ dataset = SingleImageData(image_transforms=image_transforms, input_dir=cfg.input_dir, total_views=cfg.num_views)
249
+ dataloader = torch.utils.data.DataLoader(
250
+ dataset, batch_size=1, shuffle=False, num_workers=1
251
+ )
252
+ os.makedirs(cfg.output_dir, exist_ok=True)
253
+
254
+ with torch.no_grad():
255
+ run_multiview_infer(dataloader, pipeline, cfg, cfg.output_dir, num_levels=cfg.num_levels)
256
+
257
+
258
+ if __name__ == '__main__':
259
+ parser = argparse.ArgumentParser()
260
+ parser.add_argument("--seed", type=int, default=42)
261
+ parser.add_argument("--num_views", type=int, default=6)
262
+ parser.add_argument("--num_levels", type=int, default=3)
263
+ parser.add_argument("--pretrained_path", type=str, default='./ckpt/StdGEN-multiview-1024')
264
+ parser.add_argument("--height", type=int, default=1024)
265
+ parser.add_argument("--width", type=int, default=576)
266
+ parser.add_argument("--input_dir", type=str, default='./result/apose')
267
+ parser.add_argument("--output_dir", type=str, default='./result/multiview')
268
+ cfg = parser.parse_args()
269
+
270
+ if cfg.num_views == 6:
271
+ VIEWS = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
272
+ else:
273
+ raise NotImplementedError(f"Number of views {cfg.num_views} not supported")
274
+ main(cfg)
infer_refine.py ADDED
@@ -0,0 +1,353 @@
1
+ import cv2
2
+ import numpy as np
3
+ import os
4
+ import trimesh
5
+ import argparse
6
+ import torch
7
+ import scipy
8
+ from PIL import Image
9
+
10
+ from refine.mesh_refine import geo_refine
11
+ from refine.func import make_star_cameras_orthographic
12
+ from refine.render import NormalsRenderer, calc_vertex_normals
13
+
14
+ from pytorch3d.structures import Meshes
15
+ from sklearn.neighbors import KDTree
16
+
17
+ from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
18
+
19
+ sam = sam_model_registry["vit_h"](checkpoint="./ckpt/sam_vit_h_4b8939.pth").cuda()
20
+ generator = SamAutomaticMaskGenerator(
21
+ model=sam,
22
+ points_per_side=64,
23
+ pred_iou_thresh=0.80,
24
+ stability_score_thresh=0.92,
25
+ crop_n_layers=1,
26
+ crop_n_points_downscale_factor=2,
27
+ min_mask_region_area=100,
28
+ )
29
+
30
+
31
+ def fix_vert_color_glb(mesh_path):
32
+ from pygltflib import GLTF2, Material, PbrMetallicRoughness
33
+ obj1 = GLTF2().load(mesh_path)
34
+ obj1.meshes[0].primitives[0].material = 0
35
+ obj1.materials.append(Material(
36
+ pbrMetallicRoughness = PbrMetallicRoughness(
37
+ baseColorFactor = [1.0, 1.0, 1.0, 1.0],
38
+ metallicFactor = 0.,
39
+ roughnessFactor = 1.0,
40
+ ),
41
+ emissiveFactor = [0.0, 0.0, 0.0],
42
+ doubleSided = True,
43
+ ))
44
+ obj1.save(mesh_path)
45
+
46
+
47
+ def srgb_to_linear(c_srgb):
48
+ c_linear = np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4)
49
+ return c_linear.clip(0, 1.)
50
+
51
+
52
+ def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True):
53
+ # convert from pytorch3d meshes to trimesh mesh
54
+ vertices = meshes.verts_packed().cpu().float().numpy()
55
+ triangles = meshes.faces_packed().cpu().long().numpy()
56
+ np_color = meshes.textures.verts_features_packed().cpu().float().numpy()
57
+ if save_glb_path.endswith(".glb"):
58
+ # rotate 180 along +Y
59
+ vertices[:, [0, 2]] = -vertices[:, [0, 2]]
60
+
61
+ if apply_sRGB_to_LinearRGB:
62
+ np_color = srgb_to_linear(np_color)
63
+ assert vertices.shape[0] == np_color.shape[0]
64
+ assert np_color.shape[1] == 3
65
+ assert 0 <= np_color.min() and np_color.max() <= 1.001, f"min={np_color.min()}, max={np_color.max()}"
66
+ np_color = np.clip(np_color, 0, 1)
67
+ mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color)
68
+ mesh.remove_unreferenced_vertices()
69
+ # save mesh
70
+ mesh.export(save_glb_path)
71
+ if save_glb_path.endswith(".glb"):
72
+ fix_vert_color_glb(save_glb_path)
73
+ print(f"saving to {save_glb_path}")
74
+
75
+
76
+ def calc_horizontal_offset(target_img, source_img):
77
+ target_mask = target_img.astype(np.float32).sum(axis=-1) > 750
78
+ source_mask = source_img.astype(np.float32).sum(axis=-1) > 750
79
+ best_offset = -114514  # sentinel; tracks the best overlap count seen so far (best_offset_value holds the winning shift)
80
+ for offset in range(-200, 200):
81
+ offset_mask = np.roll(source_mask, offset, axis=1)
82
+ overlap = (target_mask & offset_mask).sum()
83
+ if overlap > best_offset:
84
+ best_offset = overlap
85
+ best_offset_value = offset
86
+ return best_offset_value
87
+
88
+
89
+ def calc_horizontal_offset2(target_mask, source_img):
90
+ source_mask = source_img.astype(np.float32).sum(axis=-1) > 750
91
+ best_offset = -114514  # sentinel; tracks the best overlap count seen so far (best_offset_value holds the winning shift)
92
+ for offset in range(-200, 200):
93
+ offset_mask = np.roll(source_mask, offset, axis=1)
94
+ overlap = (target_mask & offset_mask).sum()
95
+ if overlap > best_offset:
96
+ best_offset = overlap
97
+ best_offset_value = offset
98
+ return best_offset_value
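Both helpers above do the same thing: brute-force a horizontal circular shift of the source silhouette that maximizes overlap with the target mask, and return the winning shift. A stripped-down sketch of the idea (names are illustrative):

```python
# Minimal sketch of the circular-shift silhouette alignment used by
# calc_horizontal_offset / calc_horizontal_offset2 above.
import numpy as np

def best_roll(target_mask, source_mask, search=200):
    best_overlap, best_shift = -1, 0
    for shift in range(-search, search):
        overlap = int((target_mask & np.roll(source_mask, shift, axis=1)).sum())
        if overlap > best_overlap:
            best_overlap, best_shift = overlap, shift
    return best_shift
```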
99
+
100
+
101
+ def get_distract_mask(color_0, color_1, normal_0=None, normal_1=None, thres=0.25, ratio=0.50, outside_thres=0.10, outside_ratio=0.20):
102
+ distract_area = np.abs(color_0 - color_1).sum(axis=-1) > thres
103
+ if normal_0 is not None and normal_1 is not None:
104
+ distract_area |= np.abs(normal_0 - normal_1).sum(axis=-1) > thres
105
+ labeled_array, num_features = scipy.ndimage.label(distract_area)
106
+ results = []
107
+
108
+ random_sampled_points = []
109
+
110
+ for i in range(num_features + 1):
111
+ if np.sum(labeled_array == i) > 1000 and np.sum(labeled_array == i) < 100000:
112
+ results.append((i, np.sum(labeled_array == i)))
113
+ # random sample a point in the area
114
+ points = np.argwhere(labeled_array == i)
115
+ random_sampled_points.append(points[np.random.randint(0, points.shape[0])])
116
+
117
+ results = sorted(results, key=lambda x: x[1], reverse=True) # [1:]
118
+ distract_mask = np.zeros_like(distract_area)
119
+ distract_bbox = np.zeros_like(distract_area)
120
+ for i, _ in results:
121
+ distract_mask |= labeled_array == i
122
+ bbox = np.argwhere(labeled_array == i)
123
+ min_x, min_y = bbox.min(axis=0)
124
+ max_x, max_y = bbox.max(axis=0)
125
+ distract_bbox[min_x:max_x, min_y:max_y] = 1
126
+
127
+ points = np.array(random_sampled_points)[:, ::-1]
128
+ labels = np.ones(len(points), dtype=np.int32)
129
+
130
+ masks = generator.generate((color_1 * 255).astype(np.uint8))
131
+
132
+ outside_area = np.abs(color_0 - color_1).sum(axis=-1) < outside_thres
133
+
134
+ final_mask = np.zeros_like(distract_mask)
135
+ for iii, mask in enumerate(masks):
136
+ mask['segmentation'] = cv2.resize(mask['segmentation'].astype(np.float32), (1024, 1024)) > 0.5
137
+ intersection = np.logical_and(mask['segmentation'], distract_mask).sum()
138
+ total = mask['segmentation'].sum()
139
+ iou = intersection / total
140
+ outside_intersection = np.logical_and(mask['segmentation'], outside_area).sum()
141
+ outside_total = mask['segmentation'].sum()
142
+ outside_iou = outside_intersection / outside_total
143
+ if iou > ratio and outside_iou < outside_ratio:
144
+ final_mask |= mask['segmentation']
145
+
146
+ # calculate coverage
147
+ intersection = np.logical_and(final_mask, distract_mask).sum()
148
+ total = distract_mask.sum()
149
+ coverage = intersection / total
150
+
151
+ if coverage < 0.8:
152
+ # use original distract mask
153
+ final_mask = (distract_mask.copy() * 255).astype(np.uint8)
154
+ final_mask = cv2.dilate(final_mask, np.ones((3, 3), np.uint8), iterations=3)
155
+ labeled_array_dilate, num_features_dilate = scipy.ndimage.label(final_mask)
156
+ for i in range(num_features_dilate + 1):
157
+ if np.sum(labeled_array_dilate == i) < 200:
158
+ final_mask[labeled_array_dilate == i] = 255
159
+
160
+ final_mask = cv2.erode(final_mask, np.ones((3, 3), np.uint8), iterations=3)
161
+ final_mask = final_mask > 127
162
+
163
+ return distract_mask, distract_bbox, random_sampled_points, final_mask
164
+
165
+
166
+ if __name__ == '__main__':
167
+ parser = argparse.ArgumentParser()
168
+ parser.add_argument('--input_mv_dir', type=str, default='result/multiview')
169
+ parser.add_argument('--input_obj_dir', type=str, default='result/slrm')
170
+ parser.add_argument('--output_dir', type=str, default='result/refined')
171
+ parser.add_argument('--outside_ratio', type=float, default=0.20)
172
+ parser.add_argument('--no_decompose', action='store_true')
173
+ args = parser.parse_args()
174
+
175
+ for test_idx in os.listdir(args.input_mv_dir):
176
+ mv_root_dir = os.path.join(args.input_mv_dir, test_idx)
177
+ obj_dir = os.path.join(args.input_obj_dir, test_idx)
178
+
179
+ fixed_v, fixed_f = None, None
180
+ flow_vert, flow_vector = None, None
181
+ last_colors, last_normals = None, None
182
+ last_front_color, last_front_normal = None, None
183
+ distract_mask = None
184
+
185
+ mv, proj = make_star_cameras_orthographic(8, 1, r=1.2)
186
+ mv = mv[[4, 3, 2, 0, 6, 5]]
187
+ renderer = NormalsRenderer(mv,proj,(1024,1024))
188
+
189
+ if not args.no_decompose:
190
+ for name_idx, level in zip([3, 1, 2], [2, 1, 0]):
191
+ mesh = trimesh.load(obj_dir + f'_{name_idx}.obj')
192
+ new_mesh = mesh.split(only_watertight=False)
193
+ new_mesh = [ j for j in new_mesh if len(j.vertices) >= 300 ]
194
+ mesh = trimesh.Scene(new_mesh).dump(concatenate=True)
195
+ mesh_v, mesh_f = mesh.vertices, mesh.faces
196
+
197
+ if last_colors is None:
198
+ images = renderer.render(
199
+ torch.tensor(mesh_v, device='cuda').float(),
200
+ torch.ones_like(torch.from_numpy(mesh_v), device='cuda').float(),
201
+ torch.tensor(mesh_f, device='cuda'),
202
+ )
203
+ mask = (images[..., 3] < 0.9).cpu().numpy()
204
+
205
+ colors, normals = [], []
206
+ for i in range(6):
207
+ color_path = os.path.join(mv_root_dir, f'level{level}', f'color_{i}.png')
208
+ normal_path = os.path.join(mv_root_dir, f'level{level}', f'normal_{i}.png')
209
+ color = cv2.imread(color_path)
210
+ normal = cv2.imread(normal_path)
211
+ color = color[..., ::-1]
212
+ normal = normal[..., ::-1]
213
+
214
+ if last_colors is not None:
215
+ offset = calc_horizontal_offset(np.array(last_colors[i]), color)
216
+ # print('offset', i, offset)
217
+ else:
218
+ offset = calc_horizontal_offset2(mask[i], color)
219
+ # print('init offset', i, offset)
220
+
221
+ if offset != 0:
222
+ color = np.roll(color, offset, axis=1)
223
+ normal = np.roll(normal, offset, axis=1)
224
+
225
+ color = Image.fromarray(color)
226
+ normal = Image.fromarray(normal)
227
+ colors.append(color)
228
+ normals.append(normal)
229
+
230
+ if last_front_color is not None and level == 0:
231
+ original_mask, distract_bbox, _, distract_mask = get_distract_mask(last_front_color, np.array(colors[0]).astype(np.float32) / 255.0, outside_ratio=args.outside_ratio)
232
+ cv2.imwrite(f'{args.output_dir}/{test_idx}/distract_mask.png', distract_mask.astype(np.uint8) * 255)
233
+ else:
234
+ distract_mask = None
235
+ distract_bbox = None
236
+
237
+ last_front_color = np.array(colors[0]).astype(np.float32) / 255.0
238
+ last_front_normal = np.array(normals[0]).astype(np.float32) / 255.0
239
+
240
+ if last_colors is None:
241
+ from copy import deepcopy
242
+ last_colors, last_normals = deepcopy(colors), deepcopy(normals)
243
+
244
+ # mesh flow: pull vertices toward the fixed anchor geometry, weighted over the nearest neighbors
245
+ if fixed_v is not None and fixed_f is not None and level == 1:
246
+ t = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f)
247
+
248
+ fixed_v_cpu = fixed_v.cpu().numpy()
249
+ kdtree_anchor = KDTree(fixed_v_cpu)
250
+ kdtree_mesh_v = KDTree(mesh_v)
251
+ _, idx_anchor = kdtree_anchor.query(mesh_v, k=1)
252
+ _, idx_mesh_v = kdtree_mesh_v.query(mesh_v, k=25)
253
+ idx_anchor = idx_anchor.squeeze()
254
+ neighbors = torch.tensor(mesh_v).cuda()[idx_mesh_v] # V, 25, 3
255
+ # calculate the distances neighbors [V, 25, 3]; mesh_v [V, 3] -> [V, 25]
256
+ neighbor_dists = torch.norm(neighbors - torch.tensor(mesh_v).cuda()[:, None], dim=-1)
257
+ neighbor_dists[neighbor_dists > 0.06] = 114514.
258
+ neighbor_weights = torch.exp(-neighbor_dists * 1.)
259
+ neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True)
260
+ anchors = fixed_v[idx_anchor] # V, 3
261
+ anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3
262
+ dis_anchor = torch.clamp(((anchors - torch.tensor(mesh_v).cuda()) * anchor_normals).sum(-1), min=0) + 0.01
263
+ vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3
264
+ vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3
265
+ weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3
266
+ mesh_v += weighted_vec_anchor.cpu().numpy()
267
+
268
+ t = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f)
269
+
270
+ mesh_v = torch.tensor(mesh_v, device='cuda', dtype=torch.float32)
271
+ mesh_f = torch.tensor(mesh_f, device='cuda')
272
+
273
+ new_mesh, simp_v, simp_f = geo_refine(mesh_v, mesh_f, colors, normals, fixed_v=fixed_v, fixed_f=fixed_f, distract_mask=distract_mask, distract_bbox=distract_bbox)
274
+
275
+ # mesh flow: pull vertices toward the fixed anchor geometry, weighted over the nearest neighbors
276
+ try:
277
+ if fixed_v is not None and fixed_f is not None and level != 0:
278
+ new_mesh_v = new_mesh.verts_packed().cpu().numpy()
279
+
280
+ fixed_v_cpu = fixed_v.cpu().numpy()
281
+ kdtree_anchor = KDTree(fixed_v_cpu)
282
+ kdtree_mesh_v = KDTree(new_mesh_v)
283
+ _, idx_anchor = kdtree_anchor.query(new_mesh_v, k=1)
284
+ _, idx_mesh_v = kdtree_mesh_v.query(new_mesh_v, k=25)
285
+ idx_anchor = idx_anchor.squeeze()
286
+ neighbors = torch.tensor(new_mesh_v).cuda()[idx_mesh_v] # V, 25, 3
287
+ # calculate the distances neighbors [V, 25, 3]; new_mesh_v [V, 3] -> [V, 25]
288
+ neighbor_dists = torch.norm(neighbors - torch.tensor(new_mesh_v).cuda()[:, None], dim=-1)
289
+ neighbor_dists[neighbor_dists > 0.06] = 114514.
290
+ neighbor_weights = torch.exp(-neighbor_dists * 1.)
291
+ neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True)
292
+ anchors = fixed_v[idx_anchor] # V, 3
293
+ anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3
294
+ dis_anchor = torch.clamp(((anchors - torch.tensor(new_mesh_v).cuda()) * anchor_normals).sum(-1), min=0) + 0.01
295
+ vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3
296
+ vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3
297
+ weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3
298
+ new_mesh_v += weighted_vec_anchor.cpu().numpy()
299
+
300
+ # replace new_mesh verts with new_mesh_v
301
+ new_mesh = Meshes(verts=[torch.tensor(new_mesh_v, device='cuda')], faces=new_mesh.faces_list(), textures=new_mesh.textures)
302
+
303
+ except Exception as e:
304
+ pass  # if anchoring fails, fall back to the unanchored refined mesh
305
+
306
+ os.makedirs(f'{args.output_dir}/{test_idx}', exist_ok=True)
307
+ save_py3dmesh_with_trimesh_fast(new_mesh, f'{args.output_dir}/{test_idx}/out_{level}.glb', apply_sRGB_to_LinearRGB=False)
308
+
309
+ if fixed_v is None:
310
+ fixed_v, fixed_f = simp_v, simp_f
311
+ else:
312
+ fixed_f = torch.cat([fixed_f, simp_f + fixed_v.shape[0]], dim=0)
313
+ fixed_v = torch.cat([fixed_v, simp_v], dim=0)
314
+
315
+
316
+ else:
317
+ mesh = trimesh.load(obj_dir + f'_0.obj')
318
+ mesh_v, mesh_f = mesh.vertices, mesh.faces
319
+
320
+ images = renderer.render(
321
+ torch.tensor(mesh_v, device='cuda').float(),
322
+ torch.ones_like(torch.from_numpy(mesh_v), device='cuda').float(),
323
+ torch.tensor(mesh_f, device='cuda'),
324
+ )
325
+ mask = (images[..., 3] < 0.9).cpu().numpy()
326
+
327
+ colors, normals = [], []
328
+ for i in range(6):
329
+ color_path = os.path.join(mv_root_dir, f'level0', f'color_{i}.png')
330
+ normal_path = os.path.join(mv_root_dir, f'level0', f'normal_{i}.png')
331
+ color = cv2.imread(color_path)
332
+ normal = cv2.imread(normal_path)
333
+ color = color[..., ::-1]
334
+ normal = normal[..., ::-1]
335
+
336
+ offset = calc_horizontal_offset2(mask[i], color)
337
+
338
+ if offset != 0:
339
+ color = np.roll(color, offset, axis=1)
340
+ normal = np.roll(normal, offset, axis=1)
341
+
342
+ color = Image.fromarray(color)
343
+ normal = Image.fromarray(normal)
344
+ colors.append(color)
345
+ normals.append(normal)
346
+
347
+ mesh_v = torch.tensor(mesh_v, device='cuda', dtype=torch.float32)
348
+ mesh_f = torch.tensor(mesh_f, device='cuda')
349
+
350
+ new_mesh, _, _ = geo_refine(mesh_v, mesh_f, colors, normals, no_decompose=True, expansion_weight=0.)
351
+
352
+ os.makedirs(f'{args.output_dir}/{test_idx}', exist_ok=True)
353
+ save_py3dmesh_with_trimesh_fast(new_mesh, f'{args.output_dir}/{test_idx}/out_nodecomp.glb', apply_sRGB_to_LinearRGB=False)
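The decomposed path above repeatedly "flows" the current level's vertices toward the already-fixed geometry: each vertex takes an offset along its nearest anchor's normal (clamped to be non-negative), and that offset is smoothed over the 25 nearest mesh vertices with exponential distance weights. A CPU-only sketch of that anchoring step, mirroring the tensor code above with illustrative names:

```python
# Minimal numpy/sklearn sketch of the KD-tree mesh-flow anchoring used above.
import numpy as np
from sklearn.neighbors import KDTree

def anchored_flow(verts, anchor_verts, anchor_normals, k=25, max_dist=0.06):
    _, idx_anchor = KDTree(anchor_verts).query(verts, k=1)
    _, idx_nn = KDTree(verts).query(verts, k=k)
    idx_anchor = idx_anchor.squeeze(1)

    anchors = anchor_verts[idx_anchor]                          # (V, 3)
    normals = anchor_normals[idx_anchor]                        # (V, 3)
    dist = np.clip(((anchors - verts) * normals).sum(-1), 0, None) + 0.01
    flow = dist[:, None] * normals                              # per-vertex offset, (V, 3)

    # distance-weighted average of the offsets over k nearest neighbors
    nn = verts[idx_nn]                                          # (V, k, 3)
    d = np.linalg.norm(nn - verts[:, None], axis=-1)            # (V, k)
    d[d > max_dist] = 1e6                                       # effectively drop far neighbors
    w = np.exp(-d)
    w /= w.sum(axis=1, keepdims=True)
    return verts + (flow[idx_nn] * w[..., None]).sum(axis=1)
```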
infer_slrm.py ADDED
@@ -0,0 +1,199 @@
1
+ import os
2
+ import imageio
3
+ import numpy as np
4
+ import torch
5
+ import cv2
6
+ import glob
7
+ import matplotlib.pyplot as plt
8
+ from PIL import Image
9
+ from torchvision.transforms import v2
10
+ from pytorch_lightning import seed_everything
11
+ from omegaconf import OmegaConf
12
+ from tqdm import tqdm
13
+
14
+ from slrm.utils.train_util import instantiate_from_config
15
+ from slrm.utils.camera_util import (
16
+ FOV_to_intrinsics,
17
+ get_circular_camera_poses,
18
+ )
19
+ from slrm.utils.mesh_util import save_obj, save_glb
20
+ from slrm.utils.infer_util import images_to_video  # note: shadowed by the local images_to_video defined below
21
+
22
+ from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
23
+
24
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25
+
26
+ def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
27
+ """
28
+ Get the rendering camera parameters.
29
+ """
30
+ c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
31
+ if is_flexicubes:
32
+ cameras = torch.linalg.inv(c2ws)
33
+ cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
34
+ else:
35
+ extrinsics = c2ws.flatten(-2)
36
+ intrinsics = FOV_to_intrinsics(30.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
37
+ cameras = torch.cat([extrinsics, intrinsics], dim=-1)
38
+ cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
39
+ return cameras
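In the non-FlexiCubes branch, each render camera is packed as one flat vector: the flattened camera-to-world extrinsics concatenated with the flattened intrinsics. A shape-only sketch, assuming `get_circular_camera_poses` returns an `(M, 4, 4)` tensor and `FOV_to_intrinsics` a `(3, 3)` matrix (those shapes are an assumption, and the stand-in tensors are illustrative):

```python
# Shape-only sketch of the camera packing above.
import torch

M = 120
c2ws = torch.eye(4).repeat(M, 1, 1)                        # stand-in for get_circular_camera_poses(...)
K = torch.eye(3)                                           # stand-in for FOV_to_intrinsics(30.0)

extrinsics = c2ws.flatten(-2)                              # (M, 16)
intrinsics = K.unsqueeze(0).repeat(M, 1, 1).flatten(-2)    # (M, 9)
cameras = torch.cat([extrinsics, intrinsics], dim=-1)      # (M, 25) per-camera vector
assert cameras.shape == (M, 25)
```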
40
+
41
+
42
+ def images_to_video(images, output_dir, fps=30):
43
+ # images: (N, C, H, W)
44
+ os.makedirs(os.path.dirname(output_dir), exist_ok=True)
45
+ frames = []
46
+ for i in range(images.shape[0]):
47
+ frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8).clip(0, 255)
48
+ assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
49
+ f"Frame shape mismatch: {frame.shape} vs {images.shape}"
50
+ assert frame.min() >= 0 and frame.max() <= 255, \
51
+ f"Frame value out of range: {frame.min()} ~ {frame.max()}"
52
+ frames.append(frame)
53
+ imageio.mimwrite(output_dir, np.stack(frames), fps=fps, codec='h264')
54
+
55
+
56
+ ###############################################################################
57
+ # Configuration.
58
+ ###############################################################################
59
+
60
+ seed_everything(0)
61
+
62
+ config_path = 'configs/mesh-slrm-infer.yaml'
63
+ config = OmegaConf.load(config_path)
64
+ config_name = os.path.basename(config_path).replace('.yaml', '')
65
+ model_config = config.model_config
66
+ infer_config = config.infer_config
67
+
68
+ IS_FLEXICUBES = config_name.startswith('mesh')
69
+
70
+ device = torch.device('cuda')
71
+
72
+ # load reconstruction model
73
+ print('Loading reconstruction model ...')
74
+ model = instantiate_from_config(model_config)
75
+ state_dict = torch.load(infer_config.model_path, map_location='cpu')
76
+ model.load_state_dict(state_dict, strict=False)
77
+
78
+ model = model.to(device)
79
+ if IS_FLEXICUBES:
80
+ model.init_flexicubes_geometry(device, fovy=30.0, is_ortho=model.is_ortho)
81
+ model = model.eval()
82
+
83
+ print('Loading Finished!')
84
+
85
+ def make_mesh(mesh_fpath, planes, level=None):
86
+
87
+ mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
88
+ mesh_dirname = os.path.dirname(mesh_fpath)
89
+ mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
90
+
91
+ with torch.no_grad():
92
+ # get mesh
93
+ mesh_out = model.extract_mesh(
94
+ planes,
95
+ use_texture_map=False,
96
+ levels=torch.tensor([level]).to(device),
97
+ **infer_config,
98
+ )
99
+
100
+ vertices, faces, vertex_colors = mesh_out
101
+ vertices = vertices[:, [1, 2, 0]]
102
+
103
+ save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
104
+ save_obj(vertices, faces, vertex_colors, mesh_fpath)
105
+
106
+ return mesh_fpath, mesh_glb_fpath
107
+
108
+
109
+ def make3d(images, name, output_dir):
110
+ input_cameras = torch.tensor(np.load('slrm/cameras.npy')).to(device)
111
+
112
+ render_cameras = get_render_cameras(
113
+ batch_size=1, radius=4.5, elevation=20.0, is_flexicubes=IS_FLEXICUBES).to(device)
114
+
115
+ images = images.unsqueeze(0).to(device)
116
+ images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
117
+
118
+ mesh_fpath = os.path.join(output_dir, f"{name}.obj")
119
+
120
+ mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
121
+ mesh_dirname = os.path.dirname(mesh_fpath)
122
+ video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4")
123
+
124
+ with torch.no_grad():
125
+ # get triplane
126
+ planes = model.forward_planes(images, input_cameras.float())
127
+
128
+ # get video
129
+ chunk_size = 20 if IS_FLEXICUBES else 1
130
+ render_size = 512
131
+
132
+ frames = [ [] for _ in range(4) ]
133
+ for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
134
+ if IS_FLEXICUBES:
135
+ frame = model.forward_geometry_separate(
136
+ planes,
137
+ render_cameras[:, i:i+chunk_size],
138
+ render_size=render_size,
139
+ levels=torch.tensor([0]).to(device),
140
+ )['imgs']
141
+ for j in range(4):
142
+ frames[j].append(frame[j])
143
+ else:
144
+ frame = model.synthesizer(
145
+ planes,
146
+ cameras=render_cameras[:, i:i+chunk_size],
147
+ render_size=render_size,
148
+ )['images_rgb']
149
+ frames.append(frame)
150
+
151
+ for j in range(4):
152
+ frames[j] = torch.cat(frames[j], dim=1)
153
+ video_fpath_j = video_fpath.replace('.mp4', f'_{j}.mp4')
154
+ images_to_video(
155
+ frames[j][0],
156
+ video_fpath_j,
157
+ fps=30,
158
+ )
159
+
160
+ _, mesh_glb_fpath = make_mesh(mesh_fpath.replace(mesh_fpath[-4:], f'_{j}{mesh_fpath[-4:]}'), planes, level=[0, 3, 4, 2][j])
161
+
162
+ return video_fpath, mesh_fpath, mesh_glb_fpath
163
+
164
+
165
+ if __name__ == '__main__':
166
+ import argparse
167
+ parser = argparse.ArgumentParser()
168
+ parser.add_argument('--input_dir', type=str, default="result/multiview")
169
+ parser.add_argument('--output_dir', type=str, default="result/slrm")
170
+ args = parser.parse_args()
171
+
172
+ paths = glob.glob(args.input_dir + '/*')
173
+ os.makedirs(args.output_dir, exist_ok=True)
174
+
175
+ def load_rgb(path):
176
+ img = plt.imread(path)
177
+ img = Image.fromarray(np.uint8(img * 255.))
178
+ return img
179
+
180
+ for path in tqdm(paths):
181
+ name = path.split('/')[-1]
182
+ index_targets = [
183
+ 'level0/color_0.png',
184
+ 'level0/color_1.png',
185
+ 'level0/color_2.png',
186
+ 'level0/color_3.png',
187
+ 'level0/color_4.png',
188
+ 'level0/color_5.png',
189
+ ]
190
+ imgs = []
191
+ for index_target in index_targets:
192
+ img = load_rgb(os.path.join(path, index_target))
193
+ imgs.append(img)
194
+
195
+ imgs = np.stack(imgs, axis=0).astype(np.float32) / 255.0
196
+ imgs = torch.from_numpy(np.array(imgs)).permute(0, 3, 1, 2).contiguous().float() # (6, 3, 1024, 1024)
197
+
198
+ video_fpath, mesh_fpath, mesh_glb_fpath = make3d(imgs, name, args.output_dir)
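`make3d` expects the six level-0 color views as a `(6, 3, H, W)` float tensor in `[0, 1]`. A minimal sketch of meeting that contract directly with PIL/NumPy (the scene path is illustrative, and `make3d` is assumed to be available from this module):

```python
# Minimal sketch: build the (6, 3, H, W) input expected by make3d from the
# level-0 multiview colors and run reconstruction.
import numpy as np
import torch
from PIL import Image

views = [np.asarray(Image.open(f"result/multiview/scene/level0/color_{i}.png").convert("RGB"))
         for i in range(6)]
imgs = torch.from_numpy(np.stack(views).astype(np.float32) / 255.0)  # (6, H, W, 3) in [0, 1]
imgs = imgs.permute(0, 3, 1, 2).contiguous()                          # (6, 3, H, W)

video_fpath, mesh_fpath, mesh_glb_fpath = make3d(imgs, "scene", "result/slrm")
```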
199
+
input_cases/1.png ADDED
input_cases/2.png ADDED
input_cases/3.png ADDED
input_cases/4.png ADDED
input_cases/ayaka.png ADDED
input_cases/firefly2.png ADDED
input_cases_apose/1.png ADDED
input_cases_apose/2.png ADDED
input_cases_apose/3.png ADDED
input_cases_apose/4.png ADDED
input_cases_apose/ayaka.png ADDED
input_cases_apose/belle.png ADDED
input_cases_apose/firefly.png ADDED
multiview/__init__.py ADDED
File without changes
multiview/fixed_prompt_embeds_6view/clr_embeds.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9e51666588d0f075e031262744d371e12076160231aab19a531dbf7ab976e4d
3
+ size 946932
multiview/fixed_prompt_embeds_6view/normal_embeds.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:53dfcd17f62fbfd8aeba60b1b05fa7559d72179738fd048e2ac1d53e5be5ed9d
3
+ size 946941
multiview/models/transformer_mv2d_image.py ADDED
@@ -0,0 +1,995 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
23
+ from diffusers.utils import BaseOutput, deprecate
24
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
25
+ from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
26
+ from diffusers.models.embeddings import PatchEmbed
27
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
28
+ from diffusers.models.modeling_utils import ModelMixin
29
+ from diffusers.utils.import_utils import is_xformers_available
30
+
31
+ from einops import rearrange, repeat
32
+ import pdb
33
+ import random
34
+
35
+
36
+ if is_xformers_available():
37
+ import xformers
38
+ import xformers.ops
39
+ else:
40
+ xformers = None
41
+
42
+ def my_repeat(tensor, num_repeats):
43
+ """
44
+ Repeat a tensor along its first (batch) dimension, one copy per view
45
+ """
46
+ if len(tensor.shape) == 3:
47
+ return repeat(tensor, "b d c -> (b v) d c", v=num_repeats)
48
+ elif len(tensor.shape) == 4:
49
+ return repeat(tensor, "a b d c -> (a v) b d c", v=num_repeats)
50
+
51
+
52
+ @dataclass
53
+ class TransformerMV2DModelOutput(BaseOutput):
54
+ """
55
+ The output of [`Transformer2DModel`].
56
+
57
+ Args:
58
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
59
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
60
+ distributions for the unnoised latent pixels.
61
+ """
62
+
63
+ sample: torch.FloatTensor
64
+
65
+
66
+ class TransformerMV2DModel(ModelMixin, ConfigMixin):
67
+ """
68
+ A 2D Transformer model for image-like data.
69
+
70
+ Parameters:
71
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
72
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
73
+ in_channels (`int`, *optional*):
74
+ The number of channels in the input and output (specify if the input is **continuous**).
75
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
76
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
77
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
78
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
79
+ This is fixed during training since it is used to learn a number of position embeddings.
80
+ num_vector_embeds (`int`, *optional*):
81
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
82
+ Includes the class for the masked latent pixel.
83
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
84
+ num_embeds_ada_norm ( `int`, *optional*):
85
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
86
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
87
+ added to the hidden states.
88
+
89
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
90
+ attention_bias (`bool`, *optional*):
91
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
92
+ """
93
+
94
+ @register_to_config
95
+ def __init__(
96
+ self,
97
+ num_attention_heads: int = 16,
98
+ attention_head_dim: int = 88,
99
+ in_channels: Optional[int] = None,
100
+ out_channels: Optional[int] = None,
101
+ num_layers: int = 1,
102
+ dropout: float = 0.0,
103
+ norm_num_groups: int = 32,
104
+ cross_attention_dim: Optional[int] = None,
105
+ attention_bias: bool = False,
106
+ sample_size: Optional[int] = None,
107
+ num_vector_embeds: Optional[int] = None,
108
+ patch_size: Optional[int] = None,
109
+ activation_fn: str = "geglu",
110
+ num_embeds_ada_norm: Optional[int] = None,
111
+ use_linear_projection: bool = False,
112
+ only_cross_attention: bool = False,
113
+ upcast_attention: bool = False,
114
+ norm_type: str = "layer_norm",
115
+ norm_elementwise_affine: bool = True,
116
+ num_views: int = 1,
117
+ cd_attention_last: bool=False,
118
+ cd_attention_mid: bool=False,
119
+ multiview_attention: bool=True,
120
+ sparse_mv_attention: bool = False,
121
+ mvcd_attention: bool=False
122
+ ):
123
+ super().__init__()
124
+ self.use_linear_projection = use_linear_projection
125
+ self.num_attention_heads = num_attention_heads
126
+ self.attention_head_dim = attention_head_dim
127
+ inner_dim = num_attention_heads * attention_head_dim
128
+
129
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
130
+ # Define whether input is continuous or discrete depending on configuration
131
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
132
+ self.is_input_vectorized = num_vector_embeds is not None
133
+ self.is_input_patches = in_channels is not None and patch_size is not None
134
+
135
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
136
+ deprecation_message = (
137
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
138
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
139
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
140
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
141
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
142
+ )
143
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
144
+ norm_type = "ada_norm"
145
+
146
+ if self.is_input_continuous and self.is_input_vectorized:
147
+ raise ValueError(
148
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
149
+ " sure that either `in_channels` or `num_vector_embeds` is None."
150
+ )
151
+ elif self.is_input_vectorized and self.is_input_patches:
152
+ raise ValueError(
153
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
154
+ " sure that either `num_vector_embeds` or `num_patches` is None."
155
+ )
156
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
157
+ raise ValueError(
158
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
159
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
160
+ )
161
+
162
+ # 2. Define input layers
163
+ if self.is_input_continuous:
164
+ self.in_channels = in_channels
165
+
166
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
167
+ if use_linear_projection:
168
+ self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
169
+ else:
170
+ self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
171
+ elif self.is_input_vectorized:
172
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
173
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
174
+
175
+ self.height = sample_size
176
+ self.width = sample_size
177
+ self.num_vector_embeds = num_vector_embeds
178
+ self.num_latent_pixels = self.height * self.width
179
+
180
+ self.latent_image_embedding = ImagePositionalEmbeddings(
181
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
182
+ )
183
+ elif self.is_input_patches:
184
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
185
+
186
+ self.height = sample_size
187
+ self.width = sample_size
188
+
189
+ self.patch_size = patch_size
190
+ self.pos_embed = PatchEmbed(
191
+ height=sample_size,
192
+ width=sample_size,
193
+ patch_size=patch_size,
194
+ in_channels=in_channels,
195
+ embed_dim=inner_dim,
196
+ )
197
+
198
+ # 3. Define transformers blocks
199
+ self.transformer_blocks = nn.ModuleList(
200
+ [
201
+ BasicMVTransformerBlock(
202
+ inner_dim,
203
+ num_attention_heads,
204
+ attention_head_dim,
205
+ dropout=dropout,
206
+ cross_attention_dim=cross_attention_dim,
207
+ activation_fn=activation_fn,
208
+ num_embeds_ada_norm=num_embeds_ada_norm,
209
+ attention_bias=attention_bias,
210
+ only_cross_attention=only_cross_attention,
211
+ upcast_attention=upcast_attention,
212
+ norm_type=norm_type,
213
+ norm_elementwise_affine=norm_elementwise_affine,
214
+ num_views=num_views,
215
+ cd_attention_last=cd_attention_last,
216
+ cd_attention_mid=cd_attention_mid,
217
+ multiview_attention=multiview_attention,
218
+ sparse_mv_attention=sparse_mv_attention,
219
+ mvcd_attention=mvcd_attention
220
+ )
221
+ for d in range(num_layers)
222
+ ]
223
+ )
224
+
225
+ # 4. Define output layers
226
+ self.out_channels = in_channels if out_channels is None else out_channels
227
+ if self.is_input_continuous:
228
+ # TODO: should use out_channels for continuous projections
229
+ if use_linear_projection:
230
+ self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
231
+ else:
232
+ self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
233
+ elif self.is_input_vectorized:
234
+ self.norm_out = nn.LayerNorm(inner_dim)
235
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
236
+ elif self.is_input_patches:
237
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
238
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
239
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
240
+
241
+ def forward(
242
+ self,
243
+ hidden_states: torch.Tensor,
244
+ encoder_hidden_states: Optional[torch.Tensor] = None,
245
+ timestep: Optional[torch.LongTensor] = None,
246
+ class_labels: Optional[torch.LongTensor] = None,
247
+ cross_attention_kwargs: Dict[str, Any] = None,
248
+ attention_mask: Optional[torch.Tensor] = None,
249
+ encoder_attention_mask: Optional[torch.Tensor] = None,
250
+ return_dict: bool = True,
251
+ ):
252
+ """
253
+ The [`Transformer2DModel`] forward method.
254
+
255
+ Args:
256
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
257
+ Input `hidden_states`.
258
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
259
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
260
+ self-attention.
261
+ timestep ( `torch.LongTensor`, *optional*):
262
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
263
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
264
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
265
+ `AdaLayerZeroNorm`.
266
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
267
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
268
+
269
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
270
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
271
+
272
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
273
+ above. This bias will be added to the cross-attention scores.
274
+ return_dict (`bool`, *optional*, defaults to `True`):
275
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
276
+ tuple.
277
+
278
+ Returns:
279
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
280
+ `tuple` where the first element is the sample tensor.
281
+ """
282
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
283
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
284
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
285
+ # expects mask of shape:
286
+ # [batch, key_tokens]
287
+ # adds singleton query_tokens dimension:
288
+ # [batch, 1, key_tokens]
289
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
290
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
291
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
292
+ if attention_mask is not None and attention_mask.ndim == 2:
293
+ # assume that mask is expressed as:
294
+ # (1 = keep, 0 = discard)
295
+ # convert mask into a bias that can be added to attention scores:
296
+ # (keep = +0, discard = -10000.0)
297
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
298
+ attention_mask = attention_mask.unsqueeze(1)
299
+
300
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
301
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
302
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
303
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
304
+
305
+ # 1. Input
306
+ if self.is_input_continuous:
307
+ batch, _, height, width = hidden_states.shape
308
+ residual = hidden_states
309
+
310
+ hidden_states = self.norm(hidden_states)
311
+ if not self.use_linear_projection:
312
+ hidden_states = self.proj_in(hidden_states)
313
+ inner_dim = hidden_states.shape[1]
314
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
315
+ else:
316
+ inner_dim = hidden_states.shape[1]
317
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
318
+ hidden_states = self.proj_in(hidden_states)
319
+ elif self.is_input_vectorized:
320
+ hidden_states = self.latent_image_embedding(hidden_states)
321
+ elif self.is_input_patches:
322
+ hidden_states = self.pos_embed(hidden_states)
323
+
324
+ # 2. Blocks
325
+ for block in self.transformer_blocks:
326
+ hidden_states = block(
327
+ hidden_states,
328
+ attention_mask=attention_mask,
329
+ encoder_hidden_states=encoder_hidden_states,
330
+ encoder_attention_mask=encoder_attention_mask,
331
+ timestep=timestep,
332
+ cross_attention_kwargs=cross_attention_kwargs,
333
+ class_labels=class_labels,
334
+ )
335
+
336
+ # 3. Output
337
+ if self.is_input_continuous:
338
+ if not self.use_linear_projection:
339
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
340
+ hidden_states = self.proj_out(hidden_states)
341
+ else:
342
+ hidden_states = self.proj_out(hidden_states)
343
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
344
+
345
+ output = hidden_states + residual
346
+ elif self.is_input_vectorized:
347
+ hidden_states = self.norm_out(hidden_states)
348
+ logits = self.out(hidden_states)
349
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
350
+ logits = logits.permute(0, 2, 1)
351
+
352
+ # log(p(x_0))
353
+ output = F.log_softmax(logits.double(), dim=1).float()
354
+ elif self.is_input_patches:
355
+ # TODO: cleanup!
356
+ conditioning = self.transformer_blocks[0].norm1.emb(
357
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
358
+ )
359
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
360
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
361
+ hidden_states = self.proj_out_2(hidden_states)
362
+
363
+ # unpatchify
364
+ height = width = int(hidden_states.shape[1] ** 0.5)
365
+ hidden_states = hidden_states.reshape(
366
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
367
+ )
368
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
369
+ output = hidden_states.reshape(
370
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
371
+ )
372
+
373
+ if not return_dict:
374
+ return (output,)
375
+
376
+ return TransformerMV2DModelOutput(sample=output)
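The mask handling at the top of `forward` turns a boolean keep-mask into an additive bias with a singleton query dimension so it can be broadcast onto attention scores. A tiny sketch of that conversion (values are illustrative):

```python
# Sketch of the keep-mask -> additive-bias conversion performed in forward():
# kept positions add 0, discarded positions add -10000, and the singleton
# query dimension lets the bias broadcast over attention scores.
import torch

mask = torch.tensor([[1., 1., 0., 1.]])        # (batch, key_tokens); 1 = keep, 0 = discard
bias = (1 - mask) * -10000.0
bias = bias.unsqueeze(1)                        # (batch, 1, key_tokens)
# broadcasts against scores of shape (batch, heads, query_tokens, key_tokens)
```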
377
+
378
+
379
+ @maybe_allow_in_graph
380
+ class BasicMVTransformerBlock(nn.Module):
381
+ r"""
382
+ A basic Transformer block.
383
+
384
+ Parameters:
385
+ dim (`int`): The number of channels in the input and output.
386
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
387
+ attention_head_dim (`int`): The number of channels in each head.
388
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
389
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
390
+ only_cross_attention (`bool`, *optional*):
391
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
392
+ double_self_attention (`bool`, *optional*):
393
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
394
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
395
+ num_embeds_ada_norm (:
396
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
397
+ attention_bias (:
398
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
399
+ """
400
+
401
+ def __init__(
402
+ self,
403
+ dim: int,
404
+ num_attention_heads: int,
405
+ attention_head_dim: int,
406
+ dropout=0.0,
407
+ cross_attention_dim: Optional[int] = None,
408
+ activation_fn: str = "geglu",
409
+ num_embeds_ada_norm: Optional[int] = None,
410
+ attention_bias: bool = False,
411
+ only_cross_attention: bool = False,
412
+ double_self_attention: bool = False,
413
+ upcast_attention: bool = False,
414
+ norm_elementwise_affine: bool = True,
415
+ norm_type: str = "layer_norm",
416
+ final_dropout: bool = False,
417
+ num_views: int = 1,
418
+ cd_attention_last: bool = False,
419
+ cd_attention_mid: bool = False,
420
+ multiview_attention: bool = True,
421
+ sparse_mv_attention: bool = False,
422
+ mvcd_attention: bool = False
423
+ ):
424
+ super().__init__()
425
+ self.only_cross_attention = only_cross_attention
426
+
427
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
428
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
429
+
430
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
431
+ raise ValueError(
432
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
433
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
434
+ )
435
+
436
+ # Define 3 blocks. Each block has its own normalization layer.
437
+ # 1. Self-Attn
438
+ if self.use_ada_layer_norm:
439
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
440
+ elif self.use_ada_layer_norm_zero:
441
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
442
+ else:
443
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
444
+
445
+ self.multiview_attention = multiview_attention
446
+ self.sparse_mv_attention = sparse_mv_attention
447
+ self.mvcd_attention = mvcd_attention
448
+
449
+ self.attn1 = CustomAttention(
450
+ query_dim=dim,
451
+ heads=num_attention_heads,
452
+ dim_head=attention_head_dim,
453
+ dropout=dropout,
454
+ bias=attention_bias,
455
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
456
+ upcast_attention=upcast_attention,
457
+ processor=MVAttnProcessor()
458
+ )
459
+
460
+ # 2. Cross-Attn
461
+ if cross_attention_dim is not None or double_self_attention:
462
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
463
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
464
+ # the second cross attention block.
465
+ self.norm2 = (
466
+ AdaLayerNorm(dim, num_embeds_ada_norm)
467
+ if self.use_ada_layer_norm
468
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
469
+ )
470
+ self.attn2 = Attention(
471
+ query_dim=dim,
472
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
473
+ heads=num_attention_heads,
474
+ dim_head=attention_head_dim,
475
+ dropout=dropout,
476
+ bias=attention_bias,
477
+ upcast_attention=upcast_attention,
478
+ ) # is self-attn if encoder_hidden_states is none
479
+ else:
480
+ self.norm2 = None
481
+ self.attn2 = None
482
+
483
+ # 3. Feed-forward
484
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
485
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
486
+
487
+ # let chunk size default to None
488
+ self._chunk_size = None
489
+ self._chunk_dim = 0
490
+
491
+ self.num_views = num_views
492
+
493
+ self.cd_attention_last = cd_attention_last
494
+
495
+ if self.cd_attention_last:
496
+ # Joint task -Attn
497
+ self.attn_joint_last = CustomJointAttention(
498
+ query_dim=dim,
499
+ heads=num_attention_heads,
500
+ dim_head=attention_head_dim,
501
+ dropout=dropout,
502
+ bias=attention_bias,
503
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
504
+ upcast_attention=upcast_attention,
505
+ processor=JointAttnProcessor()
506
+ )
507
+ nn.init.zeros_(self.attn_joint_last.to_out[0].weight.data)
508
+ self.norm_joint_last = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
509
+
510
+
511
+ self.cd_attention_mid = cd_attention_mid
512
+
513
+ if self.cd_attention_mid:
514
+ print("cross-domain attn in the middle")
515
+ # Joint task -Attn
516
+ self.attn_joint_mid = CustomJointAttention(
517
+ query_dim=dim,
518
+ heads=num_attention_heads,
519
+ dim_head=attention_head_dim,
520
+ dropout=dropout,
521
+ bias=attention_bias,
522
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
523
+ upcast_attention=upcast_attention,
524
+ processor=JointAttnProcessor()
525
+ )
526
+ nn.init.zeros_(self.attn_joint_mid.to_out[0].weight.data)
527
+ self.norm_joint_mid = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
528
+
529
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
530
+ # Sets chunk feed-forward
531
+ self._chunk_size = chunk_size
532
+ self._chunk_dim = dim
533
+
534
+ def forward(
535
+ self,
536
+ hidden_states: torch.FloatTensor,
537
+ attention_mask: Optional[torch.FloatTensor] = None,
538
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
539
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
540
+ timestep: Optional[torch.LongTensor] = None,
541
+ cross_attention_kwargs: Dict[str, Any] = None,
542
+ class_labels: Optional[torch.LongTensor] = None,
543
+ ):
544
+ assert attention_mask is None # not supported yet
545
+ # Notice that normalization is always applied before the real computation in the following blocks.
546
+ # 1. Self-Attention
547
+ if self.use_ada_layer_norm:
548
+ norm_hidden_states = self.norm1(hidden_states, timestep)
549
+ elif self.use_ada_layer_norm_zero:
550
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
551
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
552
+ )
553
+ else:
554
+ norm_hidden_states = self.norm1(hidden_states)
555
+
556
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
557
+
558
+ attn_output = self.attn1(
559
+ norm_hidden_states,
560
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
561
+ attention_mask=attention_mask,
562
+ num_views=self.num_views,
563
+ multiview_attention=self.multiview_attention,
564
+ sparse_mv_attention=self.sparse_mv_attention,
565
+ mvcd_attention=self.mvcd_attention,
566
+ **cross_attention_kwargs,
567
+ )
568
+
569
+
570
+ if self.use_ada_layer_norm_zero:
571
+ attn_output = gate_msa.unsqueeze(1) * attn_output
572
+ hidden_states = attn_output + hidden_states
573
+
574
+ # joint attention twice
575
+ if self.cd_attention_mid:
576
+ norm_hidden_states = (
577
+ self.norm_joint_mid(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_mid(hidden_states)
578
+ )
579
+ hidden_states = self.attn_joint_mid(norm_hidden_states) + hidden_states
580
+
581
+ # 2. Cross-Attention
582
+ if self.attn2 is not None:
583
+ norm_hidden_states = (
584
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
585
+ )
586
+
587
+ attn_output = self.attn2(
588
+ norm_hidden_states,
589
+ encoder_hidden_states=encoder_hidden_states,
590
+ attention_mask=encoder_attention_mask,
591
+ **cross_attention_kwargs,
592
+ )
593
+ hidden_states = attn_output + hidden_states
594
+
595
+ # 3. Feed-forward
596
+ norm_hidden_states = self.norm3(hidden_states)
597
+
598
+ if self.use_ada_layer_norm_zero:
599
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
600
+
601
+ if self._chunk_size is not None:
602
+ # "feed_forward_chunk_size" can be used to save memory
603
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
604
+ raise ValueError(
605
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
606
+ )
607
+
608
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
609
+ ff_output = torch.cat(
610
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
611
+ dim=self._chunk_dim,
612
+ )
613
+ else:
614
+ ff_output = self.ff(norm_hidden_states)
615
+
616
+ if self.use_ada_layer_norm_zero:
617
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
618
+
619
+ hidden_states = ff_output + hidden_states
620
+
621
+ if self.cd_attention_last:
622
+ norm_hidden_states = (
623
+ self.norm_joint_last(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_last(hidden_states)
624
+ )
625
+ hidden_states = self.attn_joint_last(norm_hidden_states) + hidden_states
626
+
627
+ return hidden_states
628
+
629
+
630
+ class CustomAttention(Attention):
631
+ def set_use_memory_efficient_attention_xformers(
632
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
633
+ ):
634
+ processor = XFormersMVAttnProcessor()
635
+ self.set_processor(processor)
636
+ # print("using xformers attention processor")
637
+
638
+
639
+ class CustomJointAttention(Attention):
640
+ def set_use_memory_efficient_attention_xformers(
641
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
642
+ ):
643
+ processor = XFormersJointAttnProcessor()
644
+ self.set_processor(processor)
645
+ # print("using xformers attention processor")
646
+
647
+ class MVAttnProcessor:
648
+ r"""
649
+ Default processor for performing attention-related computations.
650
+ """
651
+
652
+ def __call__(
653
+ self,
654
+ attn: Attention,
655
+ hidden_states,
656
+ encoder_hidden_states=None,
657
+ attention_mask=None,
658
+ temb=None,
659
+ num_views=1,
660
+ multiview_attention=True, sparse_mv_attention=False, mvcd_attention=False  # accepted (unused here) so calls from BasicMVTransformerBlock do not fail when xformers is disabled
661
+ ):
662
+ residual = hidden_states
663
+
664
+ if attn.spatial_norm is not None:
665
+ hidden_states = attn.spatial_norm(hidden_states, temb)
666
+
667
+ input_ndim = hidden_states.ndim
668
+
669
+ if input_ndim == 4:
670
+ batch_size, channel, height, width = hidden_states.shape
671
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
672
+
673
+ batch_size, sequence_length, _ = (
674
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
675
+ )
676
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
677
+
678
+ if attn.group_norm is not None:
679
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
680
+
681
+ query = attn.to_q(hidden_states)
682
+
683
+ if encoder_hidden_states is None:
684
+ encoder_hidden_states = hidden_states
685
+ elif attn.norm_cross:
686
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
687
+
688
+ key = attn.to_k(encoder_hidden_states)
689
+ value = attn.to_v(encoder_hidden_states)
690
+
691
+ # multi-view self-attention
692
+ if multiview_attention:
693
+ if num_views <= 6:
694
+ # after switching to xformers attention, training with 6 views became feasible
695
+ key = rearrange(key, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
696
+ value = rearrange(value, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
697
+ else: # apply sparse attention
698
+ raise NotImplementedError("sparse attention not implemented yet.")
699
+
700
+ query = attn.head_to_batch_dim(query).contiguous()
701
+ key = attn.head_to_batch_dim(key).contiguous()
702
+ value = attn.head_to_batch_dim(value).contiguous()
703
+
704
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
705
+ hidden_states = torch.bmm(attention_probs, value)
706
+ hidden_states = attn.batch_to_head_dim(hidden_states)
707
+
708
+ # linear proj
709
+ hidden_states = attn.to_out[0](hidden_states)
710
+ # dropout
711
+ hidden_states = attn.to_out[1](hidden_states)
712
+
713
+ if input_ndim == 4:
714
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
715
+
716
+ if attn.residual_connection:
717
+ hidden_states = hidden_states + residual
718
+
719
+ hidden_states = hidden_states / attn.rescale_output_factor
720
+
721
+ return hidden_states
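In the dense multi-view path above, the keys and values of all views are stacked into one sequence and handed back to every view's queries, so each view attends to the tokens of all views. A shape-only sketch of that expansion (sizes are illustrative):

```python
# Shape-only sketch of the dense multi-view key/value expansion used in
# MVAttnProcessor above; b, t, d, c are illustrative sizes.
import torch
from einops import rearrange

b, t, d, c = 2, 6, 4, 8                      # batch, views, tokens per view, channels
key = torch.randn(b * t, d, c)               # per-view keys, laid out as (b*t, d, c)

key_mv = rearrange(key, "(b t) d c -> b (t d) c", t=t)    # (b, t*d, c): all views stacked
key_mv = key_mv.repeat_interleave(t, dim=0)                # (b*t, t*d, c): one copy per view's queries
assert key_mv.shape == (b * t, t * d, c)
```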
722
+
723
+
724
+ class XFormersMVAttnProcessor:
725
+ r"""
726
+ Default processor for performing attention-related computations.
727
+ """
728
+
729
+ def __call__(
730
+ self,
731
+ attn: Attention,
732
+ hidden_states,
733
+ encoder_hidden_states=None,
734
+ attention_mask=None,
735
+ temb=None,
736
+ num_views=1,
737
+ multiview_attention=True,
738
+ sparse_mv_attention=False,
739
+ mvcd_attention=False,
740
+ ):
741
+ residual = hidden_states
742
+
743
+ if attn.spatial_norm is not None:
744
+ hidden_states = attn.spatial_norm(hidden_states, temb)
745
+
746
+ input_ndim = hidden_states.ndim
747
+
748
+ if input_ndim == 4:
749
+ batch_size, channel, height, width = hidden_states.shape
750
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
751
+
752
+ batch_size, sequence_length, _ = (
753
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
754
+ )
755
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
756
+
757
+ # from yuancheng; here attention_mask is None
758
+ if attention_mask is not None:
759
+ # expand our mask's singleton query_tokens dimension:
760
+ # [batch*heads, 1, key_tokens] ->
761
+ # [batch*heads, query_tokens, key_tokens]
762
+ # so that it can be added as a bias onto the attention scores that xformers computes:
763
+ # [batch*heads, query_tokens, key_tokens]
764
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
765
+ _, query_tokens, _ = hidden_states.shape
766
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
767
+
768
+ if attn.group_norm is not None:
769
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
770
+
771
+ query = attn.to_q(hidden_states)
772
+
773
+ if encoder_hidden_states is None:
774
+ encoder_hidden_states = hidden_states
775
+ elif attn.norm_cross:
776
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
777
+
778
+ key_raw = attn.to_k(encoder_hidden_states)
779
+ value_raw = attn.to_v(encoder_hidden_states)
780
+
781
+ # multi-view self-attention
782
+ if multiview_attention:
783
+ if not sparse_mv_attention:
784
+ key = my_repeat(rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
785
+ value = my_repeat(rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
786
+ else:
787
+ key_front = my_repeat(rearrange(key_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views) # [(b t), d, c]
788
+ value_front = my_repeat(rearrange(value_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views)
789
+ key = torch.cat([key_front, key_raw], dim=1) # shape (b t) (2 d) c
790
+ value = torch.cat([value_front, value_raw], dim=1)
791
+
792
+ if mvcd_attention:
793
+ # memory efficient, cross domain attention
794
+ key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2) # keys shape (b t) d c
795
+ value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2)
796
+ key_cross = torch.concat([key_1, key_0], dim=0)
797
+ value_cross = torch.concat([value_1, value_0], dim=0) # shape (b t) d c
798
+ key = torch.cat([key, key_cross], dim=1)
799
+ value = torch.cat([value, value_cross], dim=1) # shape (b t) (t+1 d) c
800
+ else:
801
+ # print("don't use multiview attention.")
802
+ key = key_raw
803
+ value = value_raw
804
+
805
+ query = attn.head_to_batch_dim(query)
806
+ key = attn.head_to_batch_dim(key)
807
+ value = attn.head_to_batch_dim(value)
808
+
809
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
810
+ hidden_states = attn.batch_to_head_dim(hidden_states)
811
+
812
+ # linear proj
813
+ hidden_states = attn.to_out[0](hidden_states)
814
+ # dropout
815
+ hidden_states = attn.to_out[1](hidden_states)
816
+
817
+ if input_ndim == 4:
818
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
819
+
820
+ if attn.residual_connection:
821
+ hidden_states = hidden_states + residual
822
+
823
+ hidden_states = hidden_states / attn.rescale_output_factor
824
+
825
+ return hidden_states
826
+
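Note: a standalone sketch of the two key constructions used by XFormersMVAttnProcessor above: the sparse multi-view path (each view attends to itself plus the front view) and the cross-domain mvcd path (the two batch halves, e.g. color and normal, are swapped and appended). The shapes are made up and einops.repeat stands in for the repo's my_repeat helper; this is illustrative only, not part of the committed file:
import torch
from einops import rearrange, repeat
T, D, C = 6, 4, 8                          # views, tokens per view, channels
key_raw = torch.randn(2 * T, D, C)         # two domains stacked along the batch, T views each
# sparse multi-view: prepend the front (first) view's tokens to every view's own tokens
key_front = rearrange(key_raw, "(b t) d c -> b t d c", t=T)[:, 0]   # front view of each group
key_front = repeat(key_front, "b d c -> (b t) d c", t=T)
key = torch.cat([key_front, key_raw], dim=1)                        # (b t) (2 d) c
# cross-domain (mvcd): swap the two batch halves and append them along the sequence,
# so each domain also attends to the other domain's tokens
key_0, key_1 = torch.chunk(key_raw, chunks=2, dim=0)
key_cross = torch.cat([key_1, key_0], dim=0)
key = torch.cat([key, key_cross], dim=1)                            # (b t) (3 d) c
print(key.shape)                                                    # torch.Size([12, 12, 8])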
827
+
828
+
829
+ class XFormersJointAttnProcessor:
830
+ r"""
831
+ Default processor for performing attention-related computations.
832
+ """
833
+
834
+ def __call__(
835
+ self,
836
+ attn: Attention,
837
+ hidden_states,
838
+ encoder_hidden_states=None,
839
+ attention_mask=None,
840
+ temb=None,
841
+ num_tasks=2
842
+ ):
843
+
844
+ residual = hidden_states
845
+
846
+ if attn.spatial_norm is not None:
847
+ hidden_states = attn.spatial_norm(hidden_states, temb)
848
+
849
+ input_ndim = hidden_states.ndim
850
+
851
+ if input_ndim == 4:
852
+ batch_size, channel, height, width = hidden_states.shape
853
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
854
+
855
+ batch_size, sequence_length, _ = (
856
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
857
+ )
858
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
859
+
860
+ # from yuancheng; here attention_mask is None
861
+ if attention_mask is not None:
862
+ # expand our mask's singleton query_tokens dimension:
863
+ # [batch*heads, 1, key_tokens] ->
864
+ # [batch*heads, query_tokens, key_tokens]
865
+ # so that it can be added as a bias onto the attention scores that xformers computes:
866
+ # [batch*heads, query_tokens, key_tokens]
867
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
868
+ _, query_tokens, _ = hidden_states.shape
869
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
870
+
871
+ if attn.group_norm is not None:
872
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
873
+
874
+ query = attn.to_q(hidden_states)
875
+
876
+ if encoder_hidden_states is None:
877
+ encoder_hidden_states = hidden_states
878
+ elif attn.norm_cross:
879
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
880
+
881
+ key = attn.to_k(encoder_hidden_states)
882
+ value = attn.to_v(encoder_hidden_states)
883
+
884
+ assert num_tasks == 2 # only support two tasks now
885
+
886
+ key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
887
+ value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
888
+ key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
889
+ value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
890
+ key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
891
+ value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
892
+
893
+
894
+ query = attn.head_to_batch_dim(query).contiguous()
895
+ key = attn.head_to_batch_dim(key).contiguous()
896
+ value = attn.head_to_batch_dim(value).contiguous()
897
+
898
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
899
+ hidden_states = attn.batch_to_head_dim(hidden_states)
900
+
901
+ # linear proj
902
+ hidden_states = attn.to_out[0](hidden_states)
903
+ # dropout
904
+ hidden_states = attn.to_out[1](hidden_states)
905
+
906
+ if input_ndim == 4:
907
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
908
+
909
+ if attn.residual_connection:
910
+ hidden_states = hidden_states + residual
911
+
912
+ hidden_states = hidden_states / attn.rescale_output_factor
913
+
914
+ return hidden_states
915
+
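Note: both joint-attention processors build their keys/values identically: split the batch into the two tasks, lay both tasks' tokens side by side along the sequence axis, and tile the result over the batch so each task's queries can attend to both. A minimal sketch with made-up shapes (illustrative only, not part of the committed file):
import torch
B, D, C = 2, 4, 8                          # per-task batch, tokens, channels
key = torch.randn(2 * B, D, C)             # first half: task 0, second half: task 1
key_0, key_1 = torch.chunk(key, chunks=2, dim=0)
key_joint = torch.cat([key_0, key_1], dim=1)    # (b, 2d, c): both tasks' tokens concatenated
key_joint = torch.cat([key_joint] * 2, dim=0)   # (2b, 2d, c): shared by both tasks' queries
print(key_joint.shape)                          # torch.Size([4, 8, 8])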
916
+
917
+ class JointAttnProcessor:
918
+ r"""
919
+ Default processor for performing attention-related computations.
920
+ """
921
+
922
+ def __call__(
923
+ self,
924
+ attn: Attention,
925
+ hidden_states,
926
+ encoder_hidden_states=None,
927
+ attention_mask=None,
928
+ temb=None,
929
+ num_tasks=2
930
+ ):
931
+
932
+ residual = hidden_states
933
+
934
+ if attn.spatial_norm is not None:
935
+ hidden_states = attn.spatial_norm(hidden_states, temb)
936
+
937
+ input_ndim = hidden_states.ndim
938
+
939
+ if input_ndim == 4:
940
+ batch_size, channel, height, width = hidden_states.shape
941
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
942
+
943
+ batch_size, sequence_length, _ = (
944
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
945
+ )
946
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
947
+
948
+
949
+ if attn.group_norm is not None:
950
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
951
+
952
+ query = attn.to_q(hidden_states)
953
+
954
+ if encoder_hidden_states is None:
955
+ encoder_hidden_states = hidden_states
956
+ elif attn.norm_cross:
957
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
958
+
959
+ key = attn.to_k(encoder_hidden_states)
960
+ value = attn.to_v(encoder_hidden_states)
961
+
962
+ assert num_tasks == 2 # only support two tasks now
963
+
964
+ key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
965
+ value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
966
+ key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
967
+ value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
968
+ key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
969
+ value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
970
+
971
+
972
+ query = attn.head_to_batch_dim(query).contiguous()
973
+ key = attn.head_to_batch_dim(key).contiguous()
974
+ value = attn.head_to_batch_dim(value).contiguous()
975
+
976
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
977
+ hidden_states = torch.bmm(attention_probs, value)
978
+ hidden_states = attn.batch_to_head_dim(hidden_states)
979
+
980
+ # linear proj
981
+ hidden_states = attn.to_out[0](hidden_states)
982
+ # dropout
983
+ hidden_states = attn.to_out[1](hidden_states)
984
+
985
+ if input_ndim == 4:
986
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
987
+
988
+ if attn.residual_connection:
989
+ hidden_states = hidden_states + residual
990
+
991
+ hidden_states = hidden_states / attn.rescale_output_factor
992
+
993
+ return hidden_states
994
+
995
+
multiview/models/transformer_mv2d_rowwise.py ADDED
@@ -0,0 +1,972 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
23
+ from diffusers.utils import BaseOutput, deprecate
24
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
25
+ from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
26
+ from diffusers.models.embeddings import PatchEmbed
27
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
28
+ from diffusers.models.modeling_utils import ModelMixin
29
+ from diffusers.utils.import_utils import is_xformers_available
30
+
31
+ from einops import rearrange
32
+ import pdb
33
+ import random
34
+ import math
35
+
36
+
37
+ if is_xformers_available():
38
+ import xformers
39
+ import xformers.ops
40
+ else:
41
+ xformers = None
42
+
43
+
44
+ @dataclass
45
+ class TransformerMV2DModelOutput(BaseOutput):
46
+ """
47
+ The output of [`Transformer2DModel`].
48
+
49
+ Args:
50
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
51
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
52
+ distributions for the unnoised latent pixels.
53
+ """
54
+
55
+ sample: torch.FloatTensor
56
+
57
+
58
+ class TransformerMV2DModel(ModelMixin, ConfigMixin):
59
+ """
60
+ A 2D Transformer model for image-like data.
61
+
62
+ Parameters:
63
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
64
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
65
+ in_channels (`int`, *optional*):
66
+ The number of channels in the input and output (specify if the input is **continuous**).
67
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
68
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
69
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
70
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
71
+ This is fixed during training since it is used to learn a number of position embeddings.
72
+ num_vector_embeds (`int`, *optional*):
73
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
74
+ Includes the class for the masked latent pixel.
75
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
76
+ num_embeds_ada_norm ( `int`, *optional*):
77
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
78
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
79
+ added to the hidden states.
80
+
81
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
82
+ attention_bias (`bool`, *optional*):
83
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
84
+ """
85
+
86
+ @register_to_config
87
+ def __init__(
88
+ self,
89
+ num_attention_heads: int = 16,
90
+ attention_head_dim: int = 88,
91
+ in_channels: Optional[int] = None,
92
+ out_channels: Optional[int] = None,
93
+ num_layers: int = 1,
94
+ dropout: float = 0.0,
95
+ norm_num_groups: int = 32,
96
+ cross_attention_dim: Optional[int] = None,
97
+ attention_bias: bool = False,
98
+ sample_size: Optional[int] = None,
99
+ num_vector_embeds: Optional[int] = None,
100
+ patch_size: Optional[int] = None,
101
+ activation_fn: str = "geglu",
102
+ num_embeds_ada_norm: Optional[int] = None,
103
+ use_linear_projection: bool = False,
104
+ only_cross_attention: bool = False,
105
+ upcast_attention: bool = False,
106
+ norm_type: str = "layer_norm",
107
+ norm_elementwise_affine: bool = True,
108
+ num_views: int = 1,
109
+ cd_attention_last: bool=False,
110
+ cd_attention_mid: bool=False,
111
+ multiview_attention: bool=True,
112
+ sparse_mv_attention: bool = True, # not used
113
+ mvcd_attention: bool=False
114
+ ):
115
+ super().__init__()
116
+ self.use_linear_projection = use_linear_projection
117
+ self.num_attention_heads = num_attention_heads
118
+ self.attention_head_dim = attention_head_dim
119
+ inner_dim = num_attention_heads * attention_head_dim
120
+
121
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
122
+ # Define whether input is continuous or discrete depending on configuration
123
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
124
+ self.is_input_vectorized = num_vector_embeds is not None
125
+ self.is_input_patches = in_channels is not None and patch_size is not None
126
+
127
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
128
+ deprecation_message = (
129
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
130
+ " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
131
+ " Please make sure to update the config accordingly, as leaving `norm_type` unset might lead to incorrect"
132
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
133
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
134
+ )
135
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
136
+ norm_type = "ada_norm"
137
+
138
+ if self.is_input_continuous and self.is_input_vectorized:
139
+ raise ValueError(
140
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
141
+ " sure that either `in_channels` or `num_vector_embeds` is None."
142
+ )
143
+ elif self.is_input_vectorized and self.is_input_patches:
144
+ raise ValueError(
145
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
146
+ " sure that either `num_vector_embeds` or `num_patches` is None."
147
+ )
148
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
149
+ raise ValueError(
150
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
151
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
152
+ )
153
+
154
+ # 2. Define input layers
155
+ if self.is_input_continuous:
156
+ self.in_channels = in_channels
157
+
158
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
159
+ if use_linear_projection:
160
+ self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
161
+ else:
162
+ self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
163
+ elif self.is_input_vectorized:
164
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
165
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
166
+
167
+ self.height = sample_size
168
+ self.width = sample_size
169
+ self.num_vector_embeds = num_vector_embeds
170
+ self.num_latent_pixels = self.height * self.width
171
+
172
+ self.latent_image_embedding = ImagePositionalEmbeddings(
173
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
174
+ )
175
+ elif self.is_input_patches:
176
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
177
+
178
+ self.height = sample_size
179
+ self.width = sample_size
180
+
181
+ self.patch_size = patch_size
182
+ self.pos_embed = PatchEmbed(
183
+ height=sample_size,
184
+ width=sample_size,
185
+ patch_size=patch_size,
186
+ in_channels=in_channels,
187
+ embed_dim=inner_dim,
188
+ )
189
+
190
+ # 3. Define transformers blocks
191
+ self.transformer_blocks = nn.ModuleList(
192
+ [
193
+ BasicMVTransformerBlock(
194
+ inner_dim,
195
+ num_attention_heads,
196
+ attention_head_dim,
197
+ dropout=dropout,
198
+ cross_attention_dim=cross_attention_dim,
199
+ activation_fn=activation_fn,
200
+ num_embeds_ada_norm=num_embeds_ada_norm,
201
+ attention_bias=attention_bias,
202
+ only_cross_attention=only_cross_attention,
203
+ upcast_attention=upcast_attention,
204
+ norm_type=norm_type,
205
+ norm_elementwise_affine=norm_elementwise_affine,
206
+ num_views=num_views,
207
+ cd_attention_last=cd_attention_last,
208
+ cd_attention_mid=cd_attention_mid,
209
+ multiview_attention=multiview_attention,
210
+ mvcd_attention=mvcd_attention
211
+ )
212
+ for d in range(num_layers)
213
+ ]
214
+ )
215
+
216
+ # 4. Define output layers
217
+ self.out_channels = in_channels if out_channels is None else out_channels
218
+ if self.is_input_continuous:
219
+ # TODO: should use out_channels for continuous projections
220
+ if use_linear_projection:
221
+ self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
222
+ else:
223
+ self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
224
+ elif self.is_input_vectorized:
225
+ self.norm_out = nn.LayerNorm(inner_dim)
226
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
227
+ elif self.is_input_patches:
228
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
229
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
230
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
231
+
232
+ def forward(
233
+ self,
234
+ hidden_states: torch.Tensor,
235
+ encoder_hidden_states: Optional[torch.Tensor] = None,
236
+ timestep: Optional[torch.LongTensor] = None,
237
+ class_labels: Optional[torch.LongTensor] = None,
238
+ cross_attention_kwargs: Dict[str, Any] = None,
239
+ attention_mask: Optional[torch.Tensor] = None,
240
+ encoder_attention_mask: Optional[torch.Tensor] = None,
241
+ return_dict: bool = True,
242
+ ):
243
+ """
244
+ The [`Transformer2DModel`] forward method.
245
+
246
+ Args:
247
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
248
+ Input `hidden_states`.
249
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
250
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
251
+ self-attention.
252
+ timestep ( `torch.LongTensor`, *optional*):
253
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
254
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
255
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
256
+ `AdaLayerNormZero`.
257
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
258
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
259
+
260
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
261
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
262
+
263
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
264
+ above. This bias will be added to the cross-attention scores.
265
+ return_dict (`bool`, *optional*, defaults to `True`):
266
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
267
+ tuple.
268
+
269
+ Returns:
270
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
271
+ `tuple` where the first element is the sample tensor.
272
+ """
273
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
274
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
275
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
276
+ # expects mask of shape:
277
+ # [batch, key_tokens]
278
+ # adds singleton query_tokens dimension:
279
+ # [batch, 1, key_tokens]
280
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
281
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
282
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
283
+ if attention_mask is not None and attention_mask.ndim == 2:
284
+ # assume that mask is expressed as:
285
+ # (1 = keep, 0 = discard)
286
+ # convert mask into a bias that can be added to attention scores:
287
+ # (keep = +0, discard = -10000.0)
288
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
289
+ attention_mask = attention_mask.unsqueeze(1)
290
+
291
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
292
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
293
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
294
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
295
+
296
+ # 1. Input
297
+ if self.is_input_continuous:
298
+ batch, _, height, width = hidden_states.shape
299
+ residual = hidden_states
300
+
301
+ hidden_states = self.norm(hidden_states)
302
+ if not self.use_linear_projection:
303
+ hidden_states = self.proj_in(hidden_states)
304
+ inner_dim = hidden_states.shape[1]
305
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
306
+ else:
307
+ inner_dim = hidden_states.shape[1]
308
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
309
+ hidden_states = self.proj_in(hidden_states)
310
+ elif self.is_input_vectorized:
311
+ hidden_states = self.latent_image_embedding(hidden_states)
312
+ elif self.is_input_patches:
313
+ hidden_states = self.pos_embed(hidden_states)
314
+
315
+ # 2. Blocks
316
+ for block in self.transformer_blocks:
317
+ hidden_states = block(
318
+ hidden_states,
319
+ attention_mask=attention_mask,
320
+ encoder_hidden_states=encoder_hidden_states,
321
+ encoder_attention_mask=encoder_attention_mask,
322
+ timestep=timestep,
323
+ cross_attention_kwargs=cross_attention_kwargs,
324
+ class_labels=class_labels,
325
+ )
326
+
327
+ # 3. Output
328
+ if self.is_input_continuous:
329
+ if not self.use_linear_projection:
330
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
331
+ hidden_states = self.proj_out(hidden_states)
332
+ else:
333
+ hidden_states = self.proj_out(hidden_states)
334
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
335
+
336
+ output = hidden_states + residual
337
+ elif self.is_input_vectorized:
338
+ hidden_states = self.norm_out(hidden_states)
339
+ logits = self.out(hidden_states)
340
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
341
+ logits = logits.permute(0, 2, 1)
342
+
343
+ # log(p(x_0))
344
+ output = F.log_softmax(logits.double(), dim=1).float()
345
+ elif self.is_input_patches:
346
+ # TODO: cleanup!
347
+ conditioning = self.transformer_blocks[0].norm1.emb(
348
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
349
+ )
350
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
351
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
352
+ hidden_states = self.proj_out_2(hidden_states)
353
+
354
+ # unpatchify
355
+ height = width = int(hidden_states.shape[1] ** 0.5)
356
+ hidden_states = hidden_states.reshape(
357
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
358
+ )
359
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
360
+ output = hidden_states.reshape(
361
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
362
+ )
363
+
364
+ if not return_dict:
365
+ return (output,)
366
+
367
+ return TransformerMV2DModelOutput(sample=output)
368
+
369
+
370
+ @maybe_allow_in_graph
371
+ class BasicMVTransformerBlock(nn.Module):
372
+ r"""
373
+ A basic Transformer block.
374
+
375
+ Parameters:
376
+ dim (`int`): The number of channels in the input and output.
377
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
378
+ attention_head_dim (`int`): The number of channels in each head.
379
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
380
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
381
+ only_cross_attention (`bool`, *optional*):
382
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
383
+ double_self_attention (`bool`, *optional*):
384
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
385
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
386
+ num_embeds_ada_norm (:
387
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
388
+ attention_bias (:
389
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
390
+ """
391
+
392
+ def __init__(
393
+ self,
394
+ dim: int,
395
+ num_attention_heads: int,
396
+ attention_head_dim: int,
397
+ dropout=0.0,
398
+ cross_attention_dim: Optional[int] = None,
399
+ activation_fn: str = "geglu",
400
+ num_embeds_ada_norm: Optional[int] = None,
401
+ attention_bias: bool = False,
402
+ only_cross_attention: bool = False,
403
+ double_self_attention: bool = False,
404
+ upcast_attention: bool = False,
405
+ norm_elementwise_affine: bool = True,
406
+ norm_type: str = "layer_norm",
407
+ final_dropout: bool = False,
408
+ num_views: int = 1,
409
+ cd_attention_last: bool = False,
410
+ cd_attention_mid: bool = False,
411
+ multiview_attention: bool = True,
412
+ mvcd_attention: bool = False,
413
+ rowwise_attention: bool = True
414
+ ):
415
+ super().__init__()
416
+ self.only_cross_attention = only_cross_attention
417
+
418
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
419
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
420
+
421
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
422
+ raise ValueError(
423
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
424
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
425
+ )
426
+
427
+ # Define 3 blocks. Each block has its own normalization layer.
428
+ # 1. Self-Attn
429
+ if self.use_ada_layer_norm:
430
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
431
+ elif self.use_ada_layer_norm_zero:
432
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
433
+ else:
434
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
435
+
436
+ self.multiview_attention = multiview_attention
437
+ self.mvcd_attention = mvcd_attention
438
+ self.rowwise_attention = multiview_attention and rowwise_attention
439
+
440
+ # rowwise multiview attention
441
+
442
+ print('INFO: using row-wise attention...')
443
+
444
+ self.attn1 = CustomAttention(
445
+ query_dim=dim,
446
+ heads=num_attention_heads,
447
+ dim_head=attention_head_dim,
448
+ dropout=dropout,
449
+ bias=attention_bias,
450
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
451
+ upcast_attention=upcast_attention,
452
+ processor=MVAttnProcessor()
453
+ )
454
+
455
+ # 2. Cross-Attn
456
+ if cross_attention_dim is not None or double_self_attention:
457
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
458
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
459
+ # the second cross attention block.
460
+ self.norm2 = (
461
+ AdaLayerNorm(dim, num_embeds_ada_norm)
462
+ if self.use_ada_layer_norm
463
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
464
+ )
465
+ self.attn2 = Attention(
466
+ query_dim=dim,
467
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
468
+ heads=num_attention_heads,
469
+ dim_head=attention_head_dim,
470
+ dropout=dropout,
471
+ bias=attention_bias,
472
+ upcast_attention=upcast_attention,
473
+ ) # is self-attn if encoder_hidden_states is none
474
+ else:
475
+ self.norm2 = None
476
+ self.attn2 = None
477
+
478
+ # 3. Feed-forward
479
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
480
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
481
+
482
+ # let chunk size default to None
483
+ self._chunk_size = None
484
+ self._chunk_dim = 0
485
+
486
+ self.num_views = num_views
487
+
488
+ self.cd_attention_last = cd_attention_last
489
+
490
+ if self.cd_attention_last:
491
+ # Joint task -Attn
492
+ self.attn_joint = CustomJointAttention(
493
+ query_dim=dim,
494
+ heads=num_attention_heads,
495
+ dim_head=attention_head_dim,
496
+ dropout=dropout,
497
+ bias=attention_bias,
498
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
499
+ upcast_attention=upcast_attention,
500
+ processor=JointAttnProcessor()
501
+ )
502
+ nn.init.zeros_(self.attn_joint.to_out[0].weight.data)
503
+ self.norm_joint = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
504
+
505
+
506
+ self.cd_attention_mid = cd_attention_mid
507
+
508
+ if self.cd_attention_mid:
509
+ print("joint twice")
510
+ # Joint task -Attn
511
+ self.attn_joint_twice = CustomJointAttention(
512
+ query_dim=dim,
513
+ heads=num_attention_heads,
514
+ dim_head=attention_head_dim,
515
+ dropout=dropout,
516
+ bias=attention_bias,
517
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
518
+ upcast_attention=upcast_attention,
519
+ processor=JointAttnProcessor()
520
+ )
521
+ nn.init.zeros_(self.attn_joint_twice.to_out[0].weight.data)
522
+ self.norm_joint_twice = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
523
+
524
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
525
+ # Sets chunk feed-forward
526
+ self._chunk_size = chunk_size
527
+ self._chunk_dim = dim
528
+
529
+ def forward(
530
+ self,
531
+ hidden_states: torch.FloatTensor,
532
+ attention_mask: Optional[torch.FloatTensor] = None,
533
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
534
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
535
+ timestep: Optional[torch.LongTensor] = None,
536
+ cross_attention_kwargs: Dict[str, Any] = None,
537
+ class_labels: Optional[torch.LongTensor] = None,
538
+ ):
539
+ assert attention_mask is None # not supported yet
540
+ # Notice that normalization is always applied before the real computation in the following blocks.
541
+ # 1. Self-Attention
542
+ if self.use_ada_layer_norm:
543
+ norm_hidden_states = self.norm1(hidden_states, timestep)
544
+ elif self.use_ada_layer_norm_zero:
545
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
546
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
547
+ )
548
+ else:
549
+ norm_hidden_states = self.norm1(hidden_states)
550
+
551
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
552
+
553
+ attn_output = self.attn1(
554
+ norm_hidden_states,
555
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
556
+ attention_mask=attention_mask,
557
+ multiview_attention=self.multiview_attention,
558
+ mvcd_attention=self.mvcd_attention,
559
+ num_views=self.num_views,
560
+ **cross_attention_kwargs,
561
+ )
562
+
563
+ if self.use_ada_layer_norm_zero:
564
+ attn_output = gate_msa.unsqueeze(1) * attn_output
565
+ hidden_states = attn_output + hidden_states
566
+
567
+ # joint attention twice
568
+ if self.cd_attention_mid:
569
+ norm_hidden_states = (
570
+ self.norm_joint_twice(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_twice(hidden_states)
571
+ )
572
+ hidden_states = self.attn_joint_twice(norm_hidden_states) + hidden_states
573
+
574
+ # 2. Cross-Attention
575
+ if self.attn2 is not None:
576
+ norm_hidden_states = (
577
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
578
+ )
579
+
580
+ attn_output = self.attn2(
581
+ norm_hidden_states,
582
+ encoder_hidden_states=encoder_hidden_states,
583
+ attention_mask=encoder_attention_mask,
584
+ **cross_attention_kwargs,
585
+ )
586
+ hidden_states = attn_output + hidden_states
587
+
588
+ # 3. Feed-forward
589
+ norm_hidden_states = self.norm3(hidden_states)
590
+
591
+ if self.use_ada_layer_norm_zero:
592
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
593
+
594
+ if self._chunk_size is not None:
595
+ # "feed_forward_chunk_size" can be used to save memory
596
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
597
+ raise ValueError(
598
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
599
+ )
600
+
601
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
602
+ ff_output = torch.cat(
603
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
604
+ dim=self._chunk_dim,
605
+ )
606
+ else:
607
+ ff_output = self.ff(norm_hidden_states)
608
+
609
+ if self.use_ada_layer_norm_zero:
610
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
611
+
612
+ hidden_states = ff_output + hidden_states
613
+
614
+ if self.cd_attention_last:
615
+ norm_hidden_states = (
616
+ self.norm_joint(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint(hidden_states)
617
+ )
618
+ hidden_states = self.attn_joint(norm_hidden_states) + hidden_states
619
+
620
+ return hidden_states
621
+
622
+
623
+ class CustomAttention(Attention):
624
+ def set_use_memory_efficient_attention_xformers(
625
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
626
+ ):
627
+ processor = XFormersMVAttnProcessor()
628
+ self.set_processor(processor)
629
+ # print("using xformers attention processor")
630
+
631
+
632
+ class CustomJointAttention(Attention):
633
+ def set_use_memory_efficient_attention_xformers(
634
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
635
+ ):
636
+ processor = XFormersJointAttnProcessor()
637
+ self.set_processor(processor)
638
+ # print("using xformers attention processor")
639
+
640
+ class MVAttnProcessor:
641
+ r"""
642
+ Default processor for performing attention-related computations.
643
+ """
644
+
645
+ def __call__(
646
+ self,
647
+ attn: Attention,
648
+ hidden_states,
649
+ encoder_hidden_states=None,
650
+ attention_mask=None,
651
+ temb=None,
652
+ num_views=1,
653
+ multiview_attention=True
654
+ ):
655
+ residual = hidden_states
656
+
657
+ if attn.spatial_norm is not None:
658
+ hidden_states = attn.spatial_norm(hidden_states, temb)
659
+
660
+ input_ndim = hidden_states.ndim
661
+
662
+ if input_ndim == 4:
663
+ batch_size, channel, height, width = hidden_states.shape
664
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
665
+
666
+ batch_size, sequence_length, _ = (
667
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
668
+ )
669
+ height = int(math.sqrt(sequence_length))
670
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
671
+
672
+ if attn.group_norm is not None:
673
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
674
+
675
+ query = attn.to_q(hidden_states)
676
+
677
+ if encoder_hidden_states is None:
678
+ encoder_hidden_states = hidden_states
679
+ elif attn.norm_cross:
680
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
681
+
682
+ key = attn.to_k(encoder_hidden_states)
683
+ value = attn.to_v(encoder_hidden_states)
684
+
685
+ # multi-view self-attention
686
+ key = rearrange(key, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
687
+ value = rearrange(value, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
688
+ query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320])
689
+
690
+ query = attn.head_to_batch_dim(query).contiguous()
691
+ key = attn.head_to_batch_dim(key).contiguous()
692
+ value = attn.head_to_batch_dim(value).contiguous()
693
+
694
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
695
+ hidden_states = torch.bmm(attention_probs, value)
696
+ hidden_states = attn.batch_to_head_dim(hidden_states)
697
+
698
+ # linear proj
699
+ hidden_states = attn.to_out[0](hidden_states)
700
+ # dropout
701
+ hidden_states = attn.to_out[1](hidden_states)
702
+ hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
703
+ if input_ndim == 4:
704
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
705
+
706
+ if attn.residual_connection:
707
+ hidden_states = hidden_states + residual
708
+
709
+ hidden_states = hidden_states / attn.rescale_output_factor
710
+
711
+ return hidden_states
712
+
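Note: the row-wise MVAttnProcessor above restricts multi-view attention to tokens that share the same latent row, which keeps the sequence length at num_views * width instead of num_views * height * width. A minimal sketch of the rearrangement with made-up shapes (illustrative only, not part of the committed file):
import torch
from einops import rearrange
V, H, W, C = 6, 4, 4, 8                   # views, latent height, latent width, channels
tokens = torch.randn(V, H * W, C)         # per-view tokens, (b v) (h w) c with b = 1
# group tokens by latent row: each attention call now mixes the same row across all views
rows = rearrange(tokens, "(b v) (h w) c -> (b h) (v w) c", v=V, h=H)
print(rows.shape)                         # torch.Size([4, 24, 8])
# after attention the per-view layout is restored
tokens_back = rearrange(rows, "(b h) (v w) c -> (b v) (h w) c", v=V, h=H)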
713
+
714
+ class XFormersMVAttnProcessor:
715
+ r"""
716
+ Default processor for performing attention-related computations.
717
+ """
718
+
719
+ def __call__(
720
+ self,
721
+ attn: Attention,
722
+ hidden_states,
723
+ encoder_hidden_states=None,
724
+ attention_mask=None,
725
+ temb=None,
726
+ num_views=1,
727
+ multiview_attention=True,
728
+ mvcd_attention=False,
729
+ ):
730
+ residual = hidden_states
731
+
732
+ if attn.spatial_norm is not None:
733
+ hidden_states = attn.spatial_norm(hidden_states, temb)
734
+
735
+ input_ndim = hidden_states.ndim
736
+
737
+ if input_ndim == 4:
738
+ batch_size, channel, height, width = hidden_states.shape
739
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
740
+
741
+ batch_size, sequence_length, _ = (
742
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
743
+ )
744
+ height = int(math.sqrt(sequence_length))
745
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
746
+ # from yuancheng; here attention_mask is None
747
+ if attention_mask is not None:
748
+ # expand our mask's singleton query_tokens dimension:
749
+ # [batch*heads, 1, key_tokens] ->
750
+ # [batch*heads, query_tokens, key_tokens]
751
+ # so that it can be added as a bias onto the attention scores that xformers computes:
752
+ # [batch*heads, query_tokens, key_tokens]
753
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
754
+ _, query_tokens, _ = hidden_states.shape
755
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
756
+
757
+ if attn.group_norm is not None:
758
+ print('Warning: group norm is enabled; be careful when using it together with row-wise attention')
759
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
760
+
761
+ query = attn.to_q(hidden_states)
762
+
763
+ if encoder_hidden_states is None:
764
+ encoder_hidden_states = hidden_states
765
+ elif attn.norm_cross:
766
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
767
+
768
+ key_raw = attn.to_k(encoder_hidden_states)
769
+ value_raw = attn.to_v(encoder_hidden_states)
770
+
771
+ key = rearrange(key_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
772
+ value = rearrange(value_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
773
+ query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320])
774
+ if mvcd_attention:
775
+ # memory efficient, cross domain attention
776
+ key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2) # keys shape (b t) d c
777
+ value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2)
778
+ key_cross = torch.concat([key_1, key_0], dim=0)
779
+ value_cross = torch.concat([value_1, value_0], dim=0) # shape (b t) d c
780
+ key = torch.cat([key, key_cross], dim=1)
781
+ value = torch.cat([value, value_cross], dim=1) # shape (b t) (t+1 d) c
782
+
783
+
784
+ query = attn.head_to_batch_dim(query) # torch.Size([960, 384, 64])
785
+ key = attn.head_to_batch_dim(key)
786
+ value = attn.head_to_batch_dim(value)
787
+
788
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
789
+ hidden_states = attn.batch_to_head_dim(hidden_states)
790
+
791
+ # linear proj
792
+ hidden_states = attn.to_out[0](hidden_states)
793
+ # dropout
794
+ hidden_states = attn.to_out[1](hidden_states)
795
+ # print(hidden_states.shape)
796
+ hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
797
+ if input_ndim == 4:
798
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
799
+
800
+ if attn.residual_connection:
801
+ hidden_states = hidden_states + residual
802
+
803
+ hidden_states = hidden_states / attn.rescale_output_factor
804
+
805
+ return hidden_states
806
+
807
+
808
+ class XFormersJointAttnProcessor:
809
+ r"""
810
+ Default processor for performing attention-related computations.
811
+ """
812
+
813
+ def __call__(
814
+ self,
815
+ attn: Attention,
816
+ hidden_states,
817
+ encoder_hidden_states=None,
818
+ attention_mask=None,
819
+ temb=None,
820
+ num_tasks=2
821
+ ):
822
+
823
+ residual = hidden_states
824
+
825
+ if attn.spatial_norm is not None:
826
+ hidden_states = attn.spatial_norm(hidden_states, temb)
827
+
828
+ input_ndim = hidden_states.ndim
829
+
830
+ if input_ndim == 4:
831
+ batch_size, channel, height, width = hidden_states.shape
832
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
833
+
834
+ batch_size, sequence_length, _ = (
835
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
836
+ )
837
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
838
+
839
+ # from yuancheng; here attention_mask is None
840
+ if attention_mask is not None:
841
+ # expand our mask's singleton query_tokens dimension:
842
+ # [batch*heads, 1, key_tokens] ->
843
+ # [batch*heads, query_tokens, key_tokens]
844
+ # so that it can be added as a bias onto the attention scores that xformers computes:
845
+ # [batch*heads, query_tokens, key_tokens]
846
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
847
+ _, query_tokens, _ = hidden_states.shape
848
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
849
+
850
+ if attn.group_norm is not None:
851
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
852
+
853
+ query = attn.to_q(hidden_states)
854
+
855
+ if encoder_hidden_states is None:
856
+ encoder_hidden_states = hidden_states
857
+ elif attn.norm_cross:
858
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
859
+
860
+ key = attn.to_k(encoder_hidden_states)
861
+ value = attn.to_v(encoder_hidden_states)
862
+
863
+ assert num_tasks == 2 # only support two tasks now
864
+
865
+ key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
866
+ value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
867
+ key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
868
+ value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
869
+ key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
870
+ value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
871
+
872
+
873
+ query = attn.head_to_batch_dim(query).contiguous()
874
+ key = attn.head_to_batch_dim(key).contiguous()
875
+ value = attn.head_to_batch_dim(value).contiguous()
876
+
877
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
878
+ hidden_states = attn.batch_to_head_dim(hidden_states)
879
+
880
+ # linear proj
881
+ hidden_states = attn.to_out[0](hidden_states)
882
+ # dropout
883
+ hidden_states = attn.to_out[1](hidden_states)
884
+
885
+ if input_ndim == 4:
886
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
887
+
888
+ if attn.residual_connection:
889
+ hidden_states = hidden_states + residual
890
+
891
+ hidden_states = hidden_states / attn.rescale_output_factor
892
+
893
+ return hidden_states
894
+
895
+
896
+ class JointAttnProcessor:
897
+ r"""
898
+ Default processor for performing attention-related computations.
899
+ """
900
+
901
+ def __call__(
902
+ self,
903
+ attn: Attention,
904
+ hidden_states,
905
+ encoder_hidden_states=None,
906
+ attention_mask=None,
907
+ temb=None,
908
+ num_tasks=2
909
+ ):
910
+
911
+ residual = hidden_states
912
+
913
+ if attn.spatial_norm is not None:
914
+ hidden_states = attn.spatial_norm(hidden_states, temb)
915
+
916
+ input_ndim = hidden_states.ndim
917
+
918
+ if input_ndim == 4:
919
+ batch_size, channel, height, width = hidden_states.shape
920
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
921
+
922
+ batch_size, sequence_length, _ = (
923
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
924
+ )
925
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
926
+
927
+
928
+ if attn.group_norm is not None:
929
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
930
+
931
+ query = attn.to_q(hidden_states)
932
+
933
+ if encoder_hidden_states is None:
934
+ encoder_hidden_states = hidden_states
935
+ elif attn.norm_cross:
936
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
937
+
938
+ key = attn.to_k(encoder_hidden_states)
939
+ value = attn.to_v(encoder_hidden_states)
940
+
941
+ assert num_tasks == 2 # only support two tasks now
942
+
943
+ key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
944
+ value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
945
+ key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
946
+ value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
947
+ key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
948
+ value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
949
+
950
+
951
+ query = attn.head_to_batch_dim(query).contiguous()
952
+ key = attn.head_to_batch_dim(key).contiguous()
953
+ value = attn.head_to_batch_dim(value).contiguous()
954
+
955
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
956
+ hidden_states = torch.bmm(attention_probs, value)
957
+ hidden_states = attn.batch_to_head_dim(hidden_states)
958
+
959
+ # linear proj
960
+ hidden_states = attn.to_out[0](hidden_states)
961
+ # dropout
962
+ hidden_states = attn.to_out[1](hidden_states)
963
+
964
+ if input_ndim == 4:
965
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
966
+
967
+ if attn.residual_connection:
968
+ hidden_states = hidden_states + residual
969
+
970
+ hidden_states = hidden_states / attn.rescale_output_factor
971
+
972
+ return hidden_states
multiview/models/transformer_mv2d_self_rowwise.py ADDED
@@ -0,0 +1,1042 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
23
+ from diffusers.utils import BaseOutput, deprecate
24
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
25
+ from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
26
+ from diffusers.models.embeddings import PatchEmbed
27
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
28
+ from diffusers.models.modeling_utils import ModelMixin
29
+ from diffusers.utils.import_utils import is_xformers_available
30
+
31
+ from einops import rearrange
32
+ import pdb
33
+ import random
34
+ import math
35
+
36
+
37
+ if is_xformers_available():
38
+ import xformers
39
+ import xformers.ops
40
+ else:
41
+ xformers = None
42
+
43
+
44
+ @dataclass
45
+ class TransformerMV2DModelOutput(BaseOutput):
46
+ """
47
+ The output of [`Transformer2DModel`].
48
+
49
+ Args:
50
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
51
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
52
+ distributions for the unnoised latent pixels.
53
+ """
54
+
55
+ sample: torch.FloatTensor
56
+
57
+
58
+ class TransformerMV2DModel(ModelMixin, ConfigMixin):
59
+ """
60
+ A 2D Transformer model for image-like data.
61
+
62
+ Parameters:
63
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
64
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
65
+ in_channels (`int`, *optional*):
66
+ The number of channels in the input and output (specify if the input is **continuous**).
67
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
68
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
69
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
70
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
71
+ This is fixed during training since it is used to learn a number of position embeddings.
72
+ num_vector_embeds (`int`, *optional*):
73
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
74
+ Includes the class for the masked latent pixel.
75
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
76
+ num_embeds_ada_norm ( `int`, *optional*):
77
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
78
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
79
+ added to the hidden states.
80
+
81
+ During inference, you can denoise for up to but not more than `num_embeds_ada_norm` steps.
82
+ attention_bias (`bool`, *optional*):
83
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
84
+ """
85
+
86
+ @register_to_config
87
+ def __init__(
88
+ self,
89
+ num_attention_heads: int = 16,
90
+ attention_head_dim: int = 88,
91
+ in_channels: Optional[int] = None,
92
+ out_channels: Optional[int] = None,
93
+ num_layers: int = 1,
94
+ dropout: float = 0.0,
95
+ norm_num_groups: int = 32,
96
+ cross_attention_dim: Optional[int] = None,
97
+ attention_bias: bool = False,
98
+ sample_size: Optional[int] = None,
99
+ num_vector_embeds: Optional[int] = None,
100
+ patch_size: Optional[int] = None,
101
+ activation_fn: str = "geglu",
102
+ num_embeds_ada_norm: Optional[int] = None,
103
+ use_linear_projection: bool = False,
104
+ only_cross_attention: bool = False,
105
+ upcast_attention: bool = False,
106
+ norm_type: str = "layer_norm",
107
+ norm_elementwise_affine: bool = True,
108
+ num_views: int = 1,
109
+ cd_attention_mid: bool=False,
110
+ cd_attention_last: bool=False,
111
+ multiview_attention: bool=True,
112
+ sparse_mv_attention: bool = True, # not used
113
+ mvcd_attention: bool=False,
114
+ use_dino: bool=False
115
+ ):
116
+ super().__init__()
117
+ self.use_linear_projection = use_linear_projection
118
+ self.num_attention_heads = num_attention_heads
119
+ self.attention_head_dim = attention_head_dim
120
+ inner_dim = num_attention_heads * attention_head_dim
121
+
122
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
123
+ # Define whether input is continuous or discrete depending on configuration
124
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
125
+ self.is_input_vectorized = num_vector_embeds is not None
126
+ self.is_input_patches = in_channels is not None and patch_size is not None
127
+
128
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
129
+ deprecation_message = (
130
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
131
+ " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
132
+ " Please make sure to update the config accordingly, as leaving `norm_type` unchanged might lead to incorrect"
133
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
134
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
135
+ )
136
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
137
+ norm_type = "ada_norm"
138
+
139
+ if self.is_input_continuous and self.is_input_vectorized:
140
+ raise ValueError(
141
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
142
+ " sure that either `in_channels` or `num_vector_embeds` is None."
143
+ )
144
+ elif self.is_input_vectorized and self.is_input_patches:
145
+ raise ValueError(
146
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
147
+ " sure that either `num_vector_embeds` or `num_patches` is None."
148
+ )
149
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
150
+ raise ValueError(
151
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
152
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
153
+ )
154
+
155
+ # 2. Define input layers
156
+ if self.is_input_continuous:
157
+ self.in_channels = in_channels
158
+
159
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
160
+ if use_linear_projection:
161
+ self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
162
+ else:
163
+ self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
164
+ elif self.is_input_vectorized:
165
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
166
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
167
+
168
+ self.height = sample_size
169
+ self.width = sample_size
170
+ self.num_vector_embeds = num_vector_embeds
171
+ self.num_latent_pixels = self.height * self.width
172
+
173
+ self.latent_image_embedding = ImagePositionalEmbeddings(
174
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
175
+ )
176
+ elif self.is_input_patches:
177
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
178
+
179
+ self.height = sample_size
180
+ self.width = sample_size
181
+
182
+ self.patch_size = patch_size
183
+ self.pos_embed = PatchEmbed(
184
+ height=sample_size,
185
+ width=sample_size,
186
+ patch_size=patch_size,
187
+ in_channels=in_channels,
188
+ embed_dim=inner_dim,
189
+ )
190
+
191
+ # 3. Define transformers blocks
192
+ self.transformer_blocks = nn.ModuleList(
193
+ [
194
+ BasicMVTransformerBlock(
195
+ inner_dim,
196
+ num_attention_heads,
197
+ attention_head_dim,
198
+ dropout=dropout,
199
+ cross_attention_dim=cross_attention_dim,
200
+ activation_fn=activation_fn,
201
+ num_embeds_ada_norm=num_embeds_ada_norm,
202
+ attention_bias=attention_bias,
203
+ only_cross_attention=only_cross_attention,
204
+ upcast_attention=upcast_attention,
205
+ norm_type=norm_type,
206
+ norm_elementwise_affine=norm_elementwise_affine,
207
+ num_views=num_views,
208
+ cd_attention_last=cd_attention_last,
209
+ cd_attention_mid=cd_attention_mid,
210
+ multiview_attention=multiview_attention,
211
+ mvcd_attention=mvcd_attention,
212
+ use_dino=use_dino
213
+ )
214
+ for d in range(num_layers)
215
+ ]
216
+ )
217
+
218
+ # 4. Define output layers
219
+ self.out_channels = in_channels if out_channels is None else out_channels
220
+ if self.is_input_continuous:
221
+ # TODO: should use out_channels for continuous projections
222
+ if use_linear_projection:
223
+ self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
224
+ else:
225
+ self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
226
+ elif self.is_input_vectorized:
227
+ self.norm_out = nn.LayerNorm(inner_dim)
228
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
229
+ elif self.is_input_patches:
230
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
231
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
232
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
233
+
234
+ def forward(
235
+ self,
236
+ hidden_states: torch.Tensor,
237
+ encoder_hidden_states: Optional[torch.Tensor] = None,
238
+ dino_feature: Optional[torch.Tensor] = None,
239
+ timestep: Optional[torch.LongTensor] = None,
240
+ class_labels: Optional[torch.LongTensor] = None,
241
+ cross_attention_kwargs: Dict[str, Any] = None,
242
+ attention_mask: Optional[torch.Tensor] = None,
243
+ encoder_attention_mask: Optional[torch.Tensor] = None,
244
+ hw_ratio: Optional[torch.FloatTensor] = 1.5,
245
+ return_dict: bool = True,
246
+ ):
247
+ """
248
+ The [`Transformer2DModel`] forward method.
249
+
250
+ Args:
251
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
252
+ Input `hidden_states`.
253
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
254
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
255
+ self-attention.
256
+ timestep ( `torch.LongTensor`, *optional*):
257
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
258
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
259
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
260
+ `AdaLayerNormZero`.
261
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
262
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
263
+
264
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
265
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
266
+
267
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
268
+ above. This bias will be added to the cross-attention scores.
269
+ return_dict (`bool`, *optional*, defaults to `True`):
270
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
271
+ tuple.
272
+
273
+ Returns:
274
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
275
+ `tuple` where the first element is the sample tensor.
276
+ """
277
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
278
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
279
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
280
+ # expects mask of shape:
281
+ # [batch, key_tokens]
282
+ # adds singleton query_tokens dimension:
283
+ # [batch, 1, key_tokens]
284
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
285
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
286
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
287
+ if attention_mask is not None and attention_mask.ndim == 2:
288
+ # assume that mask is expressed as:
289
+ # (1 = keep, 0 = discard)
290
+ # convert mask into a bias that can be added to attention scores:
291
+ # (keep = +0, discard = -10000.0)
292
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
293
+ attention_mask = attention_mask.unsqueeze(1)
294
+
295
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
296
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
297
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
298
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
299
+
300
+ # 1. Input
301
+ if self.is_input_continuous:
302
+ batch, _, height, width = hidden_states.shape
303
+ residual = hidden_states
304
+
305
+ hidden_states = self.norm(hidden_states)
306
+ if not self.use_linear_projection:
307
+ hidden_states = self.proj_in(hidden_states)
308
+ inner_dim = hidden_states.shape[1]
309
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
310
+ else:
311
+ inner_dim = hidden_states.shape[1]
312
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
313
+ hidden_states = self.proj_in(hidden_states)
314
+ elif self.is_input_vectorized:
315
+ hidden_states = self.latent_image_embedding(hidden_states)
316
+ elif self.is_input_patches:
317
+ hidden_states = self.pos_embed(hidden_states)
318
+
319
+ # 2. Blocks
320
+ for block in self.transformer_blocks:
321
+ hidden_states = block(
322
+ hidden_states,
323
+ attention_mask=attention_mask,
324
+ encoder_hidden_states=encoder_hidden_states,
325
+ dino_feature=dino_feature,
326
+ encoder_attention_mask=encoder_attention_mask,
327
+ timestep=timestep,
328
+ cross_attention_kwargs=cross_attention_kwargs,
329
+ class_labels=class_labels,
330
+ hw_ratio=hw_ratio,
331
+ )
332
+
333
+ # 3. Output
334
+ if self.is_input_continuous:
335
+ if not self.use_linear_projection:
336
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
337
+ hidden_states = self.proj_out(hidden_states)
338
+ else:
339
+ hidden_states = self.proj_out(hidden_states)
340
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
341
+
342
+ output = hidden_states + residual
343
+ elif self.is_input_vectorized:
344
+ hidden_states = self.norm_out(hidden_states)
345
+ logits = self.out(hidden_states)
346
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
347
+ logits = logits.permute(0, 2, 1)
348
+
349
+ # log(p(x_0))
350
+ output = F.log_softmax(logits.double(), dim=1).float()
351
+ elif self.is_input_patches:
352
+ # TODO: cleanup!
353
+ conditioning = self.transformer_blocks[0].norm1.emb(
354
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
355
+ )
356
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
357
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
358
+ hidden_states = self.proj_out_2(hidden_states)
359
+
360
+ # unpatchify
361
+ height = width = int(hidden_states.shape[1] ** 0.5)
362
+ hidden_states = hidden_states.reshape(
363
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
364
+ )
365
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
366
+ output = hidden_states.reshape(
367
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
368
+ )
369
+
370
+ if not return_dict:
371
+ return (output,)
372
+
373
+ return TransformerMV2DModelOutput(sample=output)
374
+
375
+
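The mask handling at the top of `forward` above turns a boolean keep/discard mask into an additive bias before it reaches the attention layers. A minimal, self-contained sketch of that conversion, using toy tensors whose shapes and values are purely illustrative:

import torch

# toy keep/discard mask (1 = keep, 0 = discard), shape [batch, key_tokens]; values are made up
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0],
                     [1.0, 0.0, 1.0, 0.0]])

# same conversion as in forward(): keep -> 0.0, discard -> -10000.0,
# with a singleton query_tokens dimension so it broadcasts over the score matrix
bias = (1 - mask) * -10000.0
bias = bias.unsqueeze(1)                       # [batch, 1, key_tokens]

scores = torch.zeros(2, 3, 4)                  # stand-in attention scores [batch, query_tokens, key_tokens]
probs = (scores + bias).softmax(dim=-1)
print(probs[0, 0])                             # discarded keys receive ~0 attention weight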
376
+ @maybe_allow_in_graph
377
+ class BasicMVTransformerBlock(nn.Module):
378
+ r"""
379
+ A basic multi-view Transformer block.
380
+
381
+ Parameters:
382
+ dim (`int`): The number of channels in the input and output.
383
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
384
+ attention_head_dim (`int`): The number of channels in each head.
385
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
386
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
387
+ only_cross_attention (`bool`, *optional*):
388
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
389
+ double_self_attention (`bool`, *optional*):
390
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
391
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
392
+ num_embeds_ada_norm (`int`, *optional*):
393
+ The number of diffusion steps used during training. See `Transformer2DModel`.
394
+ attention_bias (`bool`, *optional*, defaults to `False`):
395
+ Configure if the attentions should contain a bias parameter.
396
+ """
397
+
398
+ def __init__(
399
+ self,
400
+ dim: int,
401
+ num_attention_heads: int,
402
+ attention_head_dim: int,
403
+ dropout=0.0,
404
+ cross_attention_dim: Optional[int] = None,
405
+ activation_fn: str = "geglu",
406
+ num_embeds_ada_norm: Optional[int] = None,
407
+ attention_bias: bool = False,
408
+ only_cross_attention: bool = False,
409
+ double_self_attention: bool = False,
410
+ upcast_attention: bool = False,
411
+ norm_elementwise_affine: bool = True,
412
+ norm_type: str = "layer_norm",
413
+ final_dropout: bool = False,
414
+ num_views: int = 1,
415
+ cd_attention_last: bool = False,
416
+ cd_attention_mid: bool = False,
417
+ multiview_attention: bool = True,
418
+ mvcd_attention: bool = False,
419
+ rowwise_attention: bool = True,
420
+ use_dino: bool = False
421
+ ):
422
+ super().__init__()
423
+ self.only_cross_attention = only_cross_attention
424
+
425
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
426
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
427
+
428
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
429
+ raise ValueError(
430
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
431
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
432
+ )
433
+
434
+ # Define 3 blocks. Each block has its own normalization layer.
435
+ # 1. Self-Attn
436
+ if self.use_ada_layer_norm:
437
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
438
+ elif self.use_ada_layer_norm_zero:
439
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
440
+ else:
441
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
442
+
443
+ self.multiview_attention = multiview_attention
444
+ self.mvcd_attention = mvcd_attention
445
+ self.cd_attention_mid = cd_attention_mid
446
+ self.rowwise_attention = multiview_attention and rowwise_attention
447
+
448
+ if mvcd_attention and (not cd_attention_mid):
449
+ # add cross domain attn to self attn
450
+ self.attn1 = CustomJointAttention(
451
+ query_dim=dim,
452
+ heads=num_attention_heads,
453
+ dim_head=attention_head_dim,
454
+ dropout=dropout,
455
+ bias=attention_bias,
456
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
457
+ upcast_attention=upcast_attention,
458
+ processor=JointAttnProcessor()
459
+ )
460
+ else:
461
+ self.attn1 = Attention(
462
+ query_dim=dim,
463
+ heads=num_attention_heads,
464
+ dim_head=attention_head_dim,
465
+ dropout=dropout,
466
+ bias=attention_bias,
467
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
468
+ upcast_attention=upcast_attention
469
+ )
470
+ # 1.1 rowwise multiview attention
471
+ if self.rowwise_attention:
472
+ # print('INFO: using self+row_wise mv attention...')
473
+ self.norm_mv = (
474
+ AdaLayerNorm(dim, num_embeds_ada_norm)
475
+ if self.use_ada_layer_norm
476
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
477
+ )
478
+ self.attn_mv = CustomAttention(
479
+ query_dim=dim,
480
+ heads=num_attention_heads,
481
+ dim_head=attention_head_dim,
482
+ dropout=dropout,
483
+ bias=attention_bias,
484
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
485
+ upcast_attention=upcast_attention,
486
+ processor=MVAttnProcessor()
487
+ )
488
+ nn.init.zeros_(self.attn_mv.to_out[0].weight.data)
489
+ else:
490
+ self.norm_mv = None
491
+ self.attn_mv = None
492
+
493
+ # # 1.2 rowwise cross-domain attn
494
+ # if mvcd_attention:
495
+ # self.attn_joint = CustomJointAttention(
496
+ # query_dim=dim,
497
+ # heads=num_attention_heads,
498
+ # dim_head=attention_head_dim,
499
+ # dropout=dropout,
500
+ # bias=attention_bias,
501
+ # cross_attention_dim=cross_attention_dim if only_cross_attention else None,
502
+ # upcast_attention=upcast_attention,
503
+ # processor=JointAttnProcessor()
504
+ # )
505
+ # nn.init.zeros_(self.attn_joint.to_out[0].weight.data)
506
+ # self.norm_joint = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
507
+ # else:
508
+ # self.attn_joint = None
509
+ # self.norm_joint = None
510
+
511
+ # 2. Cross-Attn
512
+ if cross_attention_dim is not None or double_self_attention:
513
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
514
+ # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
515
+ # the second cross attention block.
516
+ self.norm2 = (
517
+ AdaLayerNorm(dim, num_embeds_ada_norm)
518
+ if self.use_ada_layer_norm
519
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
520
+ )
521
+ self.attn2 = Attention(
522
+ query_dim=dim,
523
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
524
+ heads=num_attention_heads,
525
+ dim_head=attention_head_dim,
526
+ dropout=dropout,
527
+ bias=attention_bias,
528
+ upcast_attention=upcast_attention,
529
+ ) # is self-attn if encoder_hidden_states is none
530
+ else:
531
+ self.norm2 = None
532
+ self.attn2 = None
533
+
534
+ # 3. Feed-forward
535
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
536
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
537
+
538
+ # let chunk size default to None
539
+ self._chunk_size = None
540
+ self._chunk_dim = 0
541
+
542
+ self.num_views = num_views
543
+
544
+
545
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
546
+ # Sets chunk feed-forward
547
+ self._chunk_size = chunk_size
548
+ self._chunk_dim = dim
549
+
550
+ def forward(
551
+ self,
552
+ hidden_states: torch.FloatTensor,
553
+ attention_mask: Optional[torch.FloatTensor] = None,
554
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
555
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
556
+ timestep: Optional[torch.LongTensor] = None,
557
+ cross_attention_kwargs: Dict[str, Any] = None,
558
+ class_labels: Optional[torch.LongTensor] = None,
559
+ dino_feature: Optional[torch.FloatTensor] = None,
560
+ hw_ratio: Optional[torch.FloatTensor] = 1.5,
561
+ ):
562
+ assert attention_mask is None # not supported yet
563
+ # Notice that normalization is always applied before the real computation in the following blocks.
564
+ # 1. Self-Attention
565
+ if self.use_ada_layer_norm:
566
+ norm_hidden_states = self.norm1(hidden_states, timestep)
567
+ elif self.use_ada_layer_norm_zero:
568
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
569
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
570
+ )
571
+ else:
572
+ norm_hidden_states = self.norm1(hidden_states)
573
+
574
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
575
+
576
+ attn_output = self.attn1(
577
+ norm_hidden_states,
578
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
579
+ attention_mask=attention_mask,
580
+ # multiview_attention=self.multiview_attention,
581
+ # mvcd_attention=self.mvcd_attention,
582
+ **cross_attention_kwargs,
583
+ )
584
+
585
+ if self.use_ada_layer_norm_zero:
586
+ attn_output = gate_msa.unsqueeze(1) * attn_output
587
+ hidden_states = attn_output + hidden_states
588
+
589
+ # import pdb;pdb.set_trace()
590
+ # 1.1 row wise multiview attention
591
+ if self.rowwise_attention:
592
+ norm_hidden_states = (
593
+ self.norm_mv(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_mv(hidden_states)
594
+ )
595
+ attn_output = self.attn_mv(
596
+ norm_hidden_states,
597
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
598
+ attention_mask=attention_mask,
599
+ num_views=self.num_views,
600
+ multiview_attention=self.multiview_attention,
601
+ cd_attention_mid=self.cd_attention_mid,
602
+ hw_ratio=hw_ratio,
603
+ **cross_attention_kwargs,
604
+ )
605
+ hidden_states = attn_output + hidden_states
606
+
607
+
608
+ # 2. Cross-Attention
609
+ if self.attn2 is not None:
610
+ norm_hidden_states = (
611
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
612
+ )
613
+
614
+ attn_output = self.attn2(
615
+ norm_hidden_states,
616
+ encoder_hidden_states=encoder_hidden_states,
617
+ attention_mask=encoder_attention_mask,
618
+ **cross_attention_kwargs,
619
+ )
620
+ hidden_states = attn_output + hidden_states
621
+
622
+ # 3. Feed-forward
623
+ norm_hidden_states = self.norm3(hidden_states)
624
+
625
+ if self.use_ada_layer_norm_zero:
626
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
627
+
628
+ if self._chunk_size is not None:
629
+ # "feed_forward_chunk_size" can be used to save memory
630
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
631
+ raise ValueError(
632
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
633
+ )
634
+
635
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
636
+ ff_output = torch.cat(
637
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
638
+ dim=self._chunk_dim,
639
+ )
640
+ else:
641
+ ff_output = self.ff(norm_hidden_states)
642
+
643
+ if self.use_ada_layer_norm_zero:
644
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
645
+
646
+ hidden_states = ff_output + hidden_states
647
+
648
+ return hidden_states
649
+
650
+
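The `_chunk_size` branch in the block's forward pass above trades peak memory for extra kernel launches by running the feed-forward over slices of the sequence dimension. A small sketch of the same idea, with a plain `nn.Linear` standing in for the block's `FeedForward` and all sizes chosen only for illustration:

import torch
from torch import nn

ff = nn.Linear(320, 320)                        # stand-in for the block's FeedForward module
hidden = torch.randn(2, 1024, 320)              # [batch, tokens, channels], sizes are arbitrary

chunk_size, chunk_dim = 256, 1
assert hidden.shape[chunk_dim] % chunk_size == 0   # same divisibility requirement as in forward()

num_chunks = hidden.shape[chunk_dim] // chunk_size
chunked = torch.cat(
    [ff(piece) for piece in hidden.chunk(num_chunks, dim=chunk_dim)],
    dim=chunk_dim,
)

# chunking only changes how the work is scheduled, not the result
assert torch.allclose(chunked, ff(hidden), atol=1e-6)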
651
+ class CustomAttention(Attention):
652
+ def set_use_memory_efficient_attention_xformers(
653
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
654
+ ):
655
+ processor = XFormersMVAttnProcessor()
656
+ self.set_processor(processor)
657
+ # print("using xformers attention processor")
658
+
659
+
660
+ class CustomJointAttention(Attention):
661
+ def set_use_memory_efficient_attention_xformers(
662
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
663
+ ):
664
+ processor = XFormersJointAttnProcessor()
665
+ self.set_processor(processor)
666
+ # print("using xformers attention processor")
667
+
668
+ class MVAttnProcessor:
669
+ r"""
670
+ Processor for row-wise multi-view attention-related computations.
671
+ """
672
+
673
+ def __call__(
674
+ self,
675
+ attn: Attention,
676
+ hidden_states,
677
+ encoder_hidden_states=None,
678
+ attention_mask=None,
679
+ temb=None,
680
+ num_views=1,
681
+ cd_attention_mid=False
682
+ ):
683
+ residual = hidden_states
684
+
685
+ if attn.spatial_norm is not None:
686
+ hidden_states = attn.spatial_norm(hidden_states, temb)
687
+
688
+ input_ndim = hidden_states.ndim
689
+
690
+ if input_ndim == 4:
691
+ batch_size, channel, height, width = hidden_states.shape
692
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
693
+
694
+ batch_size, sequence_length, _ = (
695
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
696
+ )
697
+ height = int(math.sqrt(sequence_length))
698
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
699
+
700
+ if attn.group_norm is not None:
701
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
702
+
703
+ query = attn.to_q(hidden_states)
704
+
705
+ if encoder_hidden_states is None:
706
+ encoder_hidden_states = hidden_states
707
+ elif attn.norm_cross:
708
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
709
+
710
+ key = attn.to_k(encoder_hidden_states)
711
+ value = attn.to_v(encoder_hidden_states)
712
+
713
+ # print('query', query.shape, 'key', key.shape, 'value', value.shape)
714
+ #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
715
+ # pdb.set_trace()
716
+ # multi-view self-attention
717
+ def transpose(tensor):
718
+ tensor = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height)
719
+ tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # b v h w c
720
+ tensor = torch.cat([tensor_0, tensor_1], dim=3) # b v h 2w c
721
+ tensor = rearrange(tensor, "b v h w c -> (b h) (v w) c", v=num_views, h=height)
722
+ return tensor
723
+
724
+ if cd_attention_mid:
725
+ key = transpose(key)
726
+ value = transpose(value)
727
+ query = transpose(query)
728
+ else:
729
+ key = rearrange(key, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
730
+ value = rearrange(value, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
731
+ query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320])
732
+
733
+ query = attn.head_to_batch_dim(query).contiguous()
734
+ key = attn.head_to_batch_dim(key).contiguous()
735
+ value = attn.head_to_batch_dim(value).contiguous()
736
+
737
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
738
+ hidden_states = torch.bmm(attention_probs, value)
739
+ hidden_states = attn.batch_to_head_dim(hidden_states)
740
+
741
+ # linear proj
742
+ hidden_states = attn.to_out[0](hidden_states)
743
+ # dropout
744
+ hidden_states = attn.to_out[1](hidden_states)
745
+ if cd_attention_mid:
746
+ hidden_states = rearrange(hidden_states, "(b h) (v w) c -> b v h w c", v=num_views, h=height)
747
+ hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=3, chunks=2) # b v h w c
748
+ hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0) # 2b v h w c
749
+ hidden_states = rearrange(hidden_states, "b v h w c -> (b v) (h w) c", v=num_views, h=height)
750
+ else:
751
+ hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
752
+ if input_ndim == 4:
753
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
754
+
755
+ if attn.residual_connection:
756
+ hidden_states = hidden_states + residual
757
+
758
+ hidden_states = hidden_states / attn.rescale_output_factor
759
+
760
+ return hidden_states
761
+
762
+
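The heart of `MVAttnProcessor` above is the `rearrange` pattern `(b v) (h w) c -> (b h) (v w) c`, which regroups the per-view token sequences so that a single attention call covers the same image row across all views. A shape-only sketch, with the view count and feature sizes picked arbitrarily:

import torch
from einops import rearrange

b, v, h, w, c = 2, 6, 32, 32, 320               # assumed batch, views, latent height/width, channels
tokens = torch.randn(b * v, h * w, c)           # per-view token sequences, as seen by the processor

# one attention call now sees the same row of every view
rowwise = rearrange(tokens, "(b v) (h w) c -> (b h) (v w) c", v=v, h=h)
print(rowwise.shape)                            # torch.Size([64, 192, 320]) = (b*h, v*w, c)

# inverse mapping applied after the attention call
restored = rearrange(rowwise, "(b h) (v w) c -> (b v) (h w) c", v=v, h=h)
assert torch.equal(restored, tokens)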
763
+ class XFormersMVAttnProcessor:
764
+ r"""
765
+ xFormers-based processor for row-wise multi-view attention-related computations.
766
+ """
767
+
768
+ def __call__(
769
+ self,
770
+ attn: Attention,
771
+ hidden_states,
772
+ encoder_hidden_states=None,
773
+ attention_mask=None,
774
+ temb=None,
775
+ num_views=1,
776
+ multiview_attention=True,
777
+ cd_attention_mid=False,
778
+ hw_ratio=1.5
779
+ ):
780
+ # print(num_views)
781
+ residual = hidden_states
782
+
783
+ if attn.spatial_norm is not None:
784
+ hidden_states = attn.spatial_norm(hidden_states, temb)
785
+
786
+ input_ndim = hidden_states.ndim
787
+
788
+ if input_ndim == 4:
789
+ batch_size, channel, height, width = hidden_states.shape
790
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
791
+
792
+ batch_size, sequence_length, _ = (
793
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
794
+ )
795
+ height = int(math.sqrt(sequence_length*hw_ratio))
796
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
797
+ # from yuancheng; here attention_mask is None
798
+ if attention_mask is not None:
799
+ # expand our mask's singleton query_tokens dimension:
800
+ # [batch*heads, 1, key_tokens] ->
801
+ # [batch*heads, query_tokens, key_tokens]
802
+ # so that it can be added as a bias onto the attention scores that xformers computes:
803
+ # [batch*heads, query_tokens, key_tokens]
804
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
805
+ _, query_tokens, _ = hidden_states.shape
806
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
807
+
808
+ if attn.group_norm is not None:
809
+ print('Warning: group norm is enabled; be careful when combining it with row-wise attention')
810
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
811
+
812
+ query = attn.to_q(hidden_states)
813
+
814
+ if encoder_hidden_states is None:
815
+ encoder_hidden_states = hidden_states
816
+ elif attn.norm_cross:
817
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
818
+
819
+ key_raw = attn.to_k(encoder_hidden_states)
820
+ value_raw = attn.to_v(encoder_hidden_states)
821
+
822
+ # print('query', query.shape, 'key', key.shape, 'value', value.shape)
823
+ # pdb.set_trace()
824
+ def transpose(tensor):
825
+ tensor = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height)
826
+ tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # b v h w c
827
+ tensor = torch.cat([tensor_0, tensor_1], dim=3) # b v h 2w c
828
+ tensor = rearrange(tensor, "b v h w c -> (b h) (v w) c", v=num_views, h=height)
829
+ return tensor
830
+ # print(mvcd_attention)
831
+ # import pdb;pdb.set_trace()
832
+ if cd_attention_mid:
833
+ key = transpose(key_raw)
834
+ value = transpose(value_raw)
835
+ query = transpose(query)
836
+ else:
837
+ key = rearrange(key_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
838
+ value = rearrange(value_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
839
+ query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height) # torch.Size([192, 384, 320])
840
+
841
+
842
+ query = attn.head_to_batch_dim(query) # torch.Size([960, 384, 64])
843
+ key = attn.head_to_batch_dim(key)
844
+ value = attn.head_to_batch_dim(value)
845
+
846
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
847
+ hidden_states = attn.batch_to_head_dim(hidden_states)
848
+
849
+ # linear proj
850
+ hidden_states = attn.to_out[0](hidden_states)
851
+ # dropout
852
+ hidden_states = attn.to_out[1](hidden_states)
853
+
854
+ if cd_attention_mid:
855
+ hidden_states = rearrange(hidden_states, "(b h) (v w) c -> b v h w c", v=num_views, h=height)
856
+ hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=3, chunks=2) # b v h w c
857
+ hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0) # 2b v h w c
858
+ hidden_states = rearrange(hidden_states, "b v h w c -> (b v) (h w) c", v=num_views, h=height)
859
+ else:
860
+ hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
861
+ if input_ndim == 4:
862
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
863
+
864
+ if attn.residual_connection:
865
+ hidden_states = hidden_states + residual
866
+
867
+ hidden_states = hidden_states / attn.rescale_output_factor
868
+
869
+ return hidden_states
870
+
871
+
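Unlike `MVAttnProcessor`, the xFormers variant above recovers the latent height from a non-square token grid via `height = int(math.sqrt(sequence_length * hw_ratio))`, since `sequence_length = height * width` and `hw_ratio = height / width`. A quick arithmetic check with an assumed 48 x 32 latent:

import math

height, width = 48, 32                  # assumed non-square latent resolution (ratio 1.5)
sequence_length = height * width        # 1536 flattened tokens
hw_ratio = height / width               # 1.5, the value threaded through forward()

recovered = int(math.sqrt(sequence_length * hw_ratio))   # sqrt(2304) = 48
assert recovered == height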
872
+ class XFormersJointAttnProcessor:
873
+ r"""
874
+ xFormers-based processor for cross-domain joint attention-related computations.
875
+ """
876
+
877
+ def __call__(
878
+ self,
879
+ attn: Attention,
880
+ hidden_states,
881
+ encoder_hidden_states=None,
882
+ attention_mask=None,
883
+ temb=None,
884
+ num_tasks=2
885
+ ):
886
+ residual = hidden_states
887
+
888
+ if attn.spatial_norm is not None:
889
+ hidden_states = attn.spatial_norm(hidden_states, temb)
890
+
891
+ input_ndim = hidden_states.ndim
892
+
893
+ if input_ndim == 4:
894
+ batch_size, channel, height, width = hidden_states.shape
895
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
896
+
897
+ batch_size, sequence_length, _ = (
898
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
899
+ )
900
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
901
+
902
+ # from yuancheng; here attention_mask is None
903
+ if attention_mask is not None:
904
+ # expand our mask's singleton query_tokens dimension:
905
+ # [batch*heads, 1, key_tokens] ->
906
+ # [batch*heads, query_tokens, key_tokens]
907
+ # so that it can be added as a bias onto the attention scores that xformers computes:
908
+ # [batch*heads, query_tokens, key_tokens]
909
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
910
+ _, query_tokens, _ = hidden_states.shape
911
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
912
+
913
+ if attn.group_norm is not None:
914
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
915
+
916
+ query = attn.to_q(hidden_states)
917
+
918
+ if encoder_hidden_states is None:
919
+ encoder_hidden_states = hidden_states
920
+ elif attn.norm_cross:
921
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
922
+
923
+ key = attn.to_k(encoder_hidden_states)
924
+ value = attn.to_v(encoder_hidden_states)
925
+
926
+ assert num_tasks == 2 # only support two tasks now
927
+
928
+ def transpose(tensor):
929
+ tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # bv hw c
930
+ tensor = torch.cat([tensor_0, tensor_1], dim=1) # bv 2hw c
931
+ return tensor
932
+ key = transpose(key)
933
+ value = transpose(value)
934
+ query = transpose(query)
935
+ # from icecream import ic
936
+ # ic(key.shape, value.shape, query.shape)
937
+ # import pdb;pdb.set_trace()
938
+ query = attn.head_to_batch_dim(query).contiguous()
939
+ key = attn.head_to_batch_dim(key).contiguous()
940
+ value = attn.head_to_batch_dim(value).contiguous()
941
+
942
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
943
+ hidden_states = attn.batch_to_head_dim(hidden_states)
944
+
945
+ # linear proj
946
+ hidden_states = attn.to_out[0](hidden_states)
947
+ # dropout
948
+ hidden_states = attn.to_out[1](hidden_states)
949
+ hidden_states_normal, hidden_states_color = torch.chunk(hidden_states, dim=1, chunks=2)
950
+ hidden_states = torch.cat([hidden_states_normal, hidden_states_color], dim=0) # 2bv hw c
951
+
952
+ if input_ndim == 4:
953
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
954
+
955
+ if attn.residual_connection:
956
+ hidden_states = hidden_states + residual
957
+
958
+ hidden_states = hidden_states / attn.rescale_output_factor
959
+
960
+ return hidden_states
961
+
962
+
963
+ class JointAttnProcessor:
964
+ r"""
965
+ Processor for cross-domain joint attention-related computations.
966
+ """
967
+
968
+ def __call__(
969
+ self,
970
+ attn: Attention,
971
+ hidden_states,
972
+ encoder_hidden_states=None,
973
+ attention_mask=None,
974
+ temb=None,
975
+ num_tasks=2
976
+ ):
977
+
978
+ residual = hidden_states
979
+
980
+ if attn.spatial_norm is not None:
981
+ hidden_states = attn.spatial_norm(hidden_states, temb)
982
+
983
+ input_ndim = hidden_states.ndim
984
+
985
+ if input_ndim == 4:
986
+ batch_size, channel, height, width = hidden_states.shape
987
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
988
+
989
+ batch_size, sequence_length, _ = (
990
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
991
+ )
992
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
993
+
994
+
995
+ if attn.group_norm is not None:
996
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
997
+
998
+ query = attn.to_q(hidden_states)
999
+
1000
+ if encoder_hidden_states is None:
1001
+ encoder_hidden_states = hidden_states
1002
+ elif attn.norm_cross:
1003
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1004
+
1005
+ key = attn.to_k(encoder_hidden_states)
1006
+ value = attn.to_v(encoder_hidden_states)
1007
+
1008
+ assert num_tasks == 2 # only support two tasks now
1009
+
1010
+ def transpose(tensor):
1011
+ tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2) # bv hw c
1012
+ tensor = torch.cat([tensor_0, tensor_1], dim=1) # bv 2hw c
1013
+ return tensor
1014
+ key = transpose(key)
1015
+ value = transpose(value)
1016
+ query = transpose(query)
1017
+
1018
+
1019
+ query = attn.head_to_batch_dim(query).contiguous()
1020
+ key = attn.head_to_batch_dim(key).contiguous()
1021
+ value = attn.head_to_batch_dim(value).contiguous()
1022
+
1023
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
1024
+ hidden_states = torch.bmm(attention_probs, value)
1025
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1026
+
1027
+
1028
+ # linear proj
1029
+ hidden_states = attn.to_out[0](hidden_states)
1030
+ # dropout
1031
+ hidden_states = attn.to_out[1](hidden_states)
1032
+
1033
+ hidden_states = torch.cat([hidden_states[:, 0], hidden_states[:, 1]], dim=0) # 2bv hw c
1034
+ if input_ndim == 4:
1035
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1036
+
1037
+ if attn.residual_connection:
1038
+ hidden_states = hidden_states + residual
1039
+
1040
+ hidden_states = hidden_states / attn.rescale_output_factor
1041
+
1042
+ return hidden_states
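Both joint processors in this file let the two generated domains (for example normals and colors) attend to each other by folding their token sequences into one long sequence before attention and splitting them again afterwards. A shape-only sketch of that folding, with all sizes made up:

import torch

bv, hw, c = 12, 1024, 320                       # assumed (batch*views), tokens per view, channels
normal_tokens = torch.randn(bv, hw, c)
color_tokens = torch.randn(bv, hw, c)

# the processors receive both domains stacked along the batch axis ...
hidden = torch.cat([normal_tokens, color_tokens], dim=0)    # [2*bv, hw, c]

# ... and fold them into one sequence so each domain can attend to the other
first, second = torch.chunk(hidden, chunks=2, dim=0)
joint = torch.cat([first, second], dim=1)                   # [bv, 2*hw, c]

# after attention the long sequence is split back into per-domain token sets
out_first, out_second = torch.chunk(joint, chunks=2, dim=1)
unfolded = torch.cat([out_first, out_second], dim=0)        # [2*bv, hw, c], original layout
assert torch.equal(unfolded, hidden)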
multiview/models/unet_mv2d_blocks.py ADDED
@@ -0,0 +1,980 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional, Tuple
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.utils import is_torch_version, logging
22
+ from diffusers.models.normalization import AdaGroupNorm
23
+ from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
24
+ from diffusers.models.dual_transformer_2d import DualTransformer2DModel
25
+ from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
26
+
27
+ from diffusers.models.unets.unet_2d_blocks import DownBlock2D, ResnetDownsampleBlock2D, AttnDownBlock2D, CrossAttnDownBlock2D, SimpleCrossAttnDownBlock2D, SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, KCrossAttnDownBlock2D
28
+ from diffusers.models.unets.unet_2d_blocks import UpBlock2D, ResnetUpsampleBlock2D, CrossAttnUpBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D
29
+
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
+ def get_down_block(
35
+ down_block_type,
36
+ num_layers,
37
+ in_channels,
38
+ out_channels,
39
+ temb_channels,
40
+ add_downsample,
41
+ resnet_eps,
42
+ resnet_act_fn,
43
+ transformer_layers_per_block=1,
44
+ num_attention_heads=None,
45
+ resnet_groups=None,
46
+ cross_attention_dim=None,
47
+ downsample_padding=None,
48
+ dual_cross_attention=False,
49
+ use_linear_projection=False,
50
+ only_cross_attention=False,
51
+ upcast_attention=False,
52
+ resnet_time_scale_shift="default",
53
+ resnet_skip_time_act=False,
54
+ resnet_out_scale_factor=1.0,
55
+ cross_attention_norm=None,
56
+ attention_head_dim=None,
57
+ downsample_type=None,
58
+ num_views=1,
59
+ cd_attention_last: bool = False,
60
+ cd_attention_mid: bool = False,
61
+ multiview_attention: bool = True,
62
+ sparse_mv_attention: bool = False,
63
+ selfattn_block: str = "custom",
64
+ mvcd_attention: bool=False,
65
+ use_dino: bool = False
66
+ ):
67
+ # If attn head dim is not defined, we default it to the number of heads
68
+ if attention_head_dim is None:
69
+ logger.warn(
70
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
71
+ )
72
+ attention_head_dim = num_attention_heads
73
+
74
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
75
+ if down_block_type == "DownBlock2D":
76
+ return DownBlock2D(
77
+ num_layers=num_layers,
78
+ in_channels=in_channels,
79
+ out_channels=out_channels,
80
+ temb_channels=temb_channels,
81
+ add_downsample=add_downsample,
82
+ resnet_eps=resnet_eps,
83
+ resnet_act_fn=resnet_act_fn,
84
+ resnet_groups=resnet_groups,
85
+ downsample_padding=downsample_padding,
86
+ resnet_time_scale_shift=resnet_time_scale_shift,
87
+ )
88
+ elif down_block_type == "ResnetDownsampleBlock2D":
89
+ return ResnetDownsampleBlock2D(
90
+ num_layers=num_layers,
91
+ in_channels=in_channels,
92
+ out_channels=out_channels,
93
+ temb_channels=temb_channels,
94
+ add_downsample=add_downsample,
95
+ resnet_eps=resnet_eps,
96
+ resnet_act_fn=resnet_act_fn,
97
+ resnet_groups=resnet_groups,
98
+ resnet_time_scale_shift=resnet_time_scale_shift,
99
+ skip_time_act=resnet_skip_time_act,
100
+ output_scale_factor=resnet_out_scale_factor,
101
+ )
102
+ elif down_block_type == "AttnDownBlock2D":
103
+ if add_downsample is False:
104
+ downsample_type = None
105
+ else:
106
+ downsample_type = downsample_type or "conv" # default to 'conv'
107
+ return AttnDownBlock2D(
108
+ num_layers=num_layers,
109
+ in_channels=in_channels,
110
+ out_channels=out_channels,
111
+ temb_channels=temb_channels,
112
+ resnet_eps=resnet_eps,
113
+ resnet_act_fn=resnet_act_fn,
114
+ resnet_groups=resnet_groups,
115
+ downsample_padding=downsample_padding,
116
+ attention_head_dim=attention_head_dim,
117
+ resnet_time_scale_shift=resnet_time_scale_shift,
118
+ downsample_type=downsample_type,
119
+ )
120
+ elif down_block_type == "CrossAttnDownBlock2D":
121
+ if cross_attention_dim is None:
122
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
123
+ return CrossAttnDownBlock2D(
124
+ num_layers=num_layers,
125
+ transformer_layers_per_block=transformer_layers_per_block,
126
+ in_channels=in_channels,
127
+ out_channels=out_channels,
128
+ temb_channels=temb_channels,
129
+ add_downsample=add_downsample,
130
+ resnet_eps=resnet_eps,
131
+ resnet_act_fn=resnet_act_fn,
132
+ resnet_groups=resnet_groups,
133
+ downsample_padding=downsample_padding,
134
+ cross_attention_dim=cross_attention_dim,
135
+ num_attention_heads=num_attention_heads,
136
+ dual_cross_attention=dual_cross_attention,
137
+ use_linear_projection=use_linear_projection,
138
+ only_cross_attention=only_cross_attention,
139
+ upcast_attention=upcast_attention,
140
+ resnet_time_scale_shift=resnet_time_scale_shift,
141
+ )
142
+ # custom MV2D attention block
143
+ elif down_block_type == "CrossAttnDownBlockMV2D":
144
+ if cross_attention_dim is None:
145
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMV2D")
146
+ return CrossAttnDownBlockMV2D(
147
+ num_layers=num_layers,
148
+ transformer_layers_per_block=transformer_layers_per_block,
149
+ in_channels=in_channels,
150
+ out_channels=out_channels,
151
+ temb_channels=temb_channels,
152
+ add_downsample=add_downsample,
153
+ resnet_eps=resnet_eps,
154
+ resnet_act_fn=resnet_act_fn,
155
+ resnet_groups=resnet_groups,
156
+ downsample_padding=downsample_padding,
157
+ cross_attention_dim=cross_attention_dim,
158
+ num_attention_heads=num_attention_heads,
159
+ dual_cross_attention=dual_cross_attention,
160
+ use_linear_projection=use_linear_projection,
161
+ only_cross_attention=only_cross_attention,
162
+ upcast_attention=upcast_attention,
163
+ resnet_time_scale_shift=resnet_time_scale_shift,
164
+ num_views=num_views,
165
+ cd_attention_last=cd_attention_last,
166
+ cd_attention_mid=cd_attention_mid,
167
+ multiview_attention=multiview_attention,
168
+ sparse_mv_attention=sparse_mv_attention,
169
+ selfattn_block=selfattn_block,
170
+ mvcd_attention=mvcd_attention,
171
+ use_dino=use_dino
172
+ )
173
+ elif down_block_type == "SimpleCrossAttnDownBlock2D":
174
+ if cross_attention_dim is None:
175
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
176
+ return SimpleCrossAttnDownBlock2D(
177
+ num_layers=num_layers,
178
+ in_channels=in_channels,
179
+ out_channels=out_channels,
180
+ temb_channels=temb_channels,
181
+ add_downsample=add_downsample,
182
+ resnet_eps=resnet_eps,
183
+ resnet_act_fn=resnet_act_fn,
184
+ resnet_groups=resnet_groups,
185
+ cross_attention_dim=cross_attention_dim,
186
+ attention_head_dim=attention_head_dim,
187
+ resnet_time_scale_shift=resnet_time_scale_shift,
188
+ skip_time_act=resnet_skip_time_act,
189
+ output_scale_factor=resnet_out_scale_factor,
190
+ only_cross_attention=only_cross_attention,
191
+ cross_attention_norm=cross_attention_norm,
192
+ )
193
+ elif down_block_type == "SkipDownBlock2D":
194
+ return SkipDownBlock2D(
195
+ num_layers=num_layers,
196
+ in_channels=in_channels,
197
+ out_channels=out_channels,
198
+ temb_channels=temb_channels,
199
+ add_downsample=add_downsample,
200
+ resnet_eps=resnet_eps,
201
+ resnet_act_fn=resnet_act_fn,
202
+ downsample_padding=downsample_padding,
203
+ resnet_time_scale_shift=resnet_time_scale_shift,
204
+ )
205
+ elif down_block_type == "AttnSkipDownBlock2D":
206
+ return AttnSkipDownBlock2D(
207
+ num_layers=num_layers,
208
+ in_channels=in_channels,
209
+ out_channels=out_channels,
210
+ temb_channels=temb_channels,
211
+ add_downsample=add_downsample,
212
+ resnet_eps=resnet_eps,
213
+ resnet_act_fn=resnet_act_fn,
214
+ attention_head_dim=attention_head_dim,
215
+ resnet_time_scale_shift=resnet_time_scale_shift,
216
+ )
217
+ elif down_block_type == "DownEncoderBlock2D":
218
+ return DownEncoderBlock2D(
219
+ num_layers=num_layers,
220
+ in_channels=in_channels,
221
+ out_channels=out_channels,
222
+ add_downsample=add_downsample,
223
+ resnet_eps=resnet_eps,
224
+ resnet_act_fn=resnet_act_fn,
225
+ resnet_groups=resnet_groups,
226
+ downsample_padding=downsample_padding,
227
+ resnet_time_scale_shift=resnet_time_scale_shift,
228
+ )
229
+ elif down_block_type == "AttnDownEncoderBlock2D":
230
+ return AttnDownEncoderBlock2D(
231
+ num_layers=num_layers,
232
+ in_channels=in_channels,
233
+ out_channels=out_channels,
234
+ add_downsample=add_downsample,
235
+ resnet_eps=resnet_eps,
236
+ resnet_act_fn=resnet_act_fn,
237
+ resnet_groups=resnet_groups,
238
+ downsample_padding=downsample_padding,
239
+ attention_head_dim=attention_head_dim,
240
+ resnet_time_scale_shift=resnet_time_scale_shift,
241
+ )
242
+ elif down_block_type == "KDownBlock2D":
243
+ return KDownBlock2D(
244
+ num_layers=num_layers,
245
+ in_channels=in_channels,
246
+ out_channels=out_channels,
247
+ temb_channels=temb_channels,
248
+ add_downsample=add_downsample,
249
+ resnet_eps=resnet_eps,
250
+ resnet_act_fn=resnet_act_fn,
251
+ )
252
+ elif down_block_type == "KCrossAttnDownBlock2D":
253
+ return KCrossAttnDownBlock2D(
254
+ num_layers=num_layers,
255
+ in_channels=in_channels,
256
+ out_channels=out_channels,
257
+ temb_channels=temb_channels,
258
+ add_downsample=add_downsample,
259
+ resnet_eps=resnet_eps,
260
+ resnet_act_fn=resnet_act_fn,
261
+ cross_attention_dim=cross_attention_dim,
262
+ attention_head_dim=attention_head_dim,
263
+ add_self_attention=not add_downsample,
264
+ )
265
+ raise ValueError(f"{down_block_type} does not exist.")
266
+
267
+
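A hypothetical call into the factory above, exercising the plain `DownBlock2D` branch. It assumes the repository root is on `PYTHONPATH` and a diffusers version that matches the import paths at the top of this file; every parameter value below is chosen only for illustration and is not the repository's actual configuration:

import torch
from multiview.models.unet_mv2d_blocks import get_down_block   # assumed import path

block = get_down_block(
    "DownBlock2D",                 # dispatches to diffusers' stock block, no MV options needed
    num_layers=2,
    in_channels=320,
    out_channels=320,
    temb_channels=1280,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    resnet_groups=32,
    downsample_padding=1,
    attention_head_dim=8,          # provided explicitly to avoid the warning above
)

sample = torch.randn(2, 320, 64, 64)    # made-up latent batch
temb = torch.randn(2, 1280)             # made-up time embedding
hidden, residuals = block(sample, temb)
print(hidden.shape)                      # spatially halved: torch.Size([2, 320, 32, 32])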
268
+ def get_up_block(
269
+ up_block_type,
270
+ num_layers,
271
+ in_channels,
272
+ out_channels,
273
+ prev_output_channel,
274
+ temb_channels,
275
+ add_upsample,
276
+ resnet_eps,
277
+ resnet_act_fn,
278
+ transformer_layers_per_block=1,
279
+ num_attention_heads=None,
280
+ resnet_groups=None,
281
+ cross_attention_dim=None,
282
+ dual_cross_attention=False,
283
+ use_linear_projection=False,
284
+ only_cross_attention=False,
285
+ upcast_attention=False,
286
+ resnet_time_scale_shift="default",
287
+ resnet_skip_time_act=False,
288
+ resnet_out_scale_factor=1.0,
289
+ cross_attention_norm=None,
290
+ attention_head_dim=None,
291
+ upsample_type=None,
292
+ num_views=1,
293
+ cd_attention_last: bool = False,
294
+ cd_attention_mid: bool = False,
295
+ multiview_attention: bool = True,
296
+ sparse_mv_attention: bool = False,
297
+ selfattn_block: str = "custom",
298
+ mvcd_attention: bool=False,
299
+ use_dino: bool = False
300
+ ):
301
+ # If attn head dim is not defined, we default it to the number of heads
302
+ if attention_head_dim is None:
303
+ logger.warn(
304
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
305
+ )
306
+ attention_head_dim = num_attention_heads
307
+
308
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
309
+ if up_block_type == "UpBlock2D":
310
+ return UpBlock2D(
311
+ num_layers=num_layers,
312
+ in_channels=in_channels,
313
+ out_channels=out_channels,
314
+ prev_output_channel=prev_output_channel,
315
+ temb_channels=temb_channels,
316
+ add_upsample=add_upsample,
317
+ resnet_eps=resnet_eps,
318
+ resnet_act_fn=resnet_act_fn,
319
+ resnet_groups=resnet_groups,
320
+ resnet_time_scale_shift=resnet_time_scale_shift,
321
+ )
322
+ elif up_block_type == "ResnetUpsampleBlock2D":
323
+ return ResnetUpsampleBlock2D(
324
+ num_layers=num_layers,
325
+ in_channels=in_channels,
326
+ out_channels=out_channels,
327
+ prev_output_channel=prev_output_channel,
328
+ temb_channels=temb_channels,
329
+ add_upsample=add_upsample,
330
+ resnet_eps=resnet_eps,
331
+ resnet_act_fn=resnet_act_fn,
332
+ resnet_groups=resnet_groups,
333
+ resnet_time_scale_shift=resnet_time_scale_shift,
334
+ skip_time_act=resnet_skip_time_act,
335
+ output_scale_factor=resnet_out_scale_factor,
336
+ )
337
+ elif up_block_type == "CrossAttnUpBlock2D":
338
+ if cross_attention_dim is None:
339
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
340
+ return CrossAttnUpBlock2D(
341
+ num_layers=num_layers,
342
+ transformer_layers_per_block=transformer_layers_per_block,
343
+ in_channels=in_channels,
344
+ out_channels=out_channels,
345
+ prev_output_channel=prev_output_channel,
346
+ temb_channels=temb_channels,
347
+ add_upsample=add_upsample,
348
+ resnet_eps=resnet_eps,
349
+ resnet_act_fn=resnet_act_fn,
350
+ resnet_groups=resnet_groups,
351
+ cross_attention_dim=cross_attention_dim,
352
+ num_attention_heads=num_attention_heads,
353
+ dual_cross_attention=dual_cross_attention,
354
+ use_linear_projection=use_linear_projection,
355
+ only_cross_attention=only_cross_attention,
356
+ upcast_attention=upcast_attention,
357
+ resnet_time_scale_shift=resnet_time_scale_shift,
358
+ )
359
+ # custom MV2D attention block
360
+ elif up_block_type == "CrossAttnUpBlockMV2D":
361
+ if cross_attention_dim is None:
362
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMV2D")
363
+ return CrossAttnUpBlockMV2D(
364
+ num_layers=num_layers,
365
+ transformer_layers_per_block=transformer_layers_per_block,
366
+ in_channels=in_channels,
367
+ out_channels=out_channels,
368
+ prev_output_channel=prev_output_channel,
369
+ temb_channels=temb_channels,
370
+ add_upsample=add_upsample,
371
+ resnet_eps=resnet_eps,
372
+ resnet_act_fn=resnet_act_fn,
373
+ resnet_groups=resnet_groups,
374
+ cross_attention_dim=cross_attention_dim,
375
+ num_attention_heads=num_attention_heads,
376
+ dual_cross_attention=dual_cross_attention,
377
+ use_linear_projection=use_linear_projection,
378
+ only_cross_attention=only_cross_attention,
379
+ upcast_attention=upcast_attention,
380
+ resnet_time_scale_shift=resnet_time_scale_shift,
381
+ num_views=num_views,
382
+ cd_attention_last=cd_attention_last,
383
+ cd_attention_mid=cd_attention_mid,
384
+ multiview_attention=multiview_attention,
385
+ sparse_mv_attention=sparse_mv_attention,
386
+ selfattn_block=selfattn_block,
387
+ mvcd_attention=mvcd_attention,
388
+ use_dino=use_dino
389
+ )
390
+ elif up_block_type == "SimpleCrossAttnUpBlock2D":
391
+ if cross_attention_dim is None:
392
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
393
+ return SimpleCrossAttnUpBlock2D(
394
+ num_layers=num_layers,
395
+ in_channels=in_channels,
396
+ out_channels=out_channels,
397
+ prev_output_channel=prev_output_channel,
398
+ temb_channels=temb_channels,
399
+ add_upsample=add_upsample,
400
+ resnet_eps=resnet_eps,
401
+ resnet_act_fn=resnet_act_fn,
402
+ resnet_groups=resnet_groups,
403
+ cross_attention_dim=cross_attention_dim,
404
+ attention_head_dim=attention_head_dim,
405
+ resnet_time_scale_shift=resnet_time_scale_shift,
406
+ skip_time_act=resnet_skip_time_act,
407
+ output_scale_factor=resnet_out_scale_factor,
408
+ only_cross_attention=only_cross_attention,
409
+ cross_attention_norm=cross_attention_norm,
410
+ )
411
+ elif up_block_type == "AttnUpBlock2D":
412
+ if add_upsample is False:
413
+ upsample_type = None
414
+ else:
415
+ upsample_type = upsample_type or "conv" # default to 'conv'
416
+
417
+ return AttnUpBlock2D(
418
+ num_layers=num_layers,
419
+ in_channels=in_channels,
420
+ out_channels=out_channels,
421
+ prev_output_channel=prev_output_channel,
422
+ temb_channels=temb_channels,
423
+ resnet_eps=resnet_eps,
424
+ resnet_act_fn=resnet_act_fn,
425
+ resnet_groups=resnet_groups,
426
+ attention_head_dim=attention_head_dim,
427
+ resnet_time_scale_shift=resnet_time_scale_shift,
428
+ upsample_type=upsample_type,
429
+ )
430
+ elif up_block_type == "SkipUpBlock2D":
431
+ return SkipUpBlock2D(
432
+ num_layers=num_layers,
433
+ in_channels=in_channels,
434
+ out_channels=out_channels,
435
+ prev_output_channel=prev_output_channel,
436
+ temb_channels=temb_channels,
437
+ add_upsample=add_upsample,
438
+ resnet_eps=resnet_eps,
439
+ resnet_act_fn=resnet_act_fn,
440
+ resnet_time_scale_shift=resnet_time_scale_shift,
441
+ )
442
+ elif up_block_type == "AttnSkipUpBlock2D":
443
+ return AttnSkipUpBlock2D(
444
+ num_layers=num_layers,
445
+ in_channels=in_channels,
446
+ out_channels=out_channels,
447
+ prev_output_channel=prev_output_channel,
448
+ temb_channels=temb_channels,
449
+ add_upsample=add_upsample,
450
+ resnet_eps=resnet_eps,
451
+ resnet_act_fn=resnet_act_fn,
452
+ attention_head_dim=attention_head_dim,
453
+ resnet_time_scale_shift=resnet_time_scale_shift,
454
+ )
455
+ elif up_block_type == "UpDecoderBlock2D":
456
+ return UpDecoderBlock2D(
457
+ num_layers=num_layers,
458
+ in_channels=in_channels,
459
+ out_channels=out_channels,
460
+ add_upsample=add_upsample,
461
+ resnet_eps=resnet_eps,
462
+ resnet_act_fn=resnet_act_fn,
463
+ resnet_groups=resnet_groups,
464
+ resnet_time_scale_shift=resnet_time_scale_shift,
465
+ temb_channels=temb_channels,
466
+ )
467
+ elif up_block_type == "AttnUpDecoderBlock2D":
468
+ return AttnUpDecoderBlock2D(
469
+ num_layers=num_layers,
470
+ in_channels=in_channels,
471
+ out_channels=out_channels,
472
+ add_upsample=add_upsample,
473
+ resnet_eps=resnet_eps,
474
+ resnet_act_fn=resnet_act_fn,
475
+ resnet_groups=resnet_groups,
476
+ attention_head_dim=attention_head_dim,
477
+ resnet_time_scale_shift=resnet_time_scale_shift,
478
+ temb_channels=temb_channels,
479
+ )
480
+ elif up_block_type == "KUpBlock2D":
481
+ return KUpBlock2D(
482
+ num_layers=num_layers,
483
+ in_channels=in_channels,
484
+ out_channels=out_channels,
485
+ temb_channels=temb_channels,
486
+ add_upsample=add_upsample,
487
+ resnet_eps=resnet_eps,
488
+ resnet_act_fn=resnet_act_fn,
489
+ )
490
+ elif up_block_type == "KCrossAttnUpBlock2D":
491
+ return KCrossAttnUpBlock2D(
492
+ num_layers=num_layers,
493
+ in_channels=in_channels,
494
+ out_channels=out_channels,
495
+ temb_channels=temb_channels,
496
+ add_upsample=add_upsample,
497
+ resnet_eps=resnet_eps,
498
+ resnet_act_fn=resnet_act_fn,
499
+ cross_attention_dim=cross_attention_dim,
500
+ attention_head_dim=attention_head_dim,
501
+ )
502
+
503
+ raise ValueError(f"{up_block_type} does not exist.")
504
+
505
+
506
+ class UNetMidBlockMV2DCrossAttn(nn.Module):
507
+ def __init__(
508
+ self,
509
+ in_channels: int,
510
+ temb_channels: int,
511
+ dropout: float = 0.0,
512
+ num_layers: int = 1,
513
+ transformer_layers_per_block: int = 1,
514
+ resnet_eps: float = 1e-6,
515
+ resnet_time_scale_shift: str = "default",
516
+ resnet_act_fn: str = "swish",
517
+ resnet_groups: int = 32,
518
+ resnet_pre_norm: bool = True,
519
+ num_attention_heads=1,
520
+ output_scale_factor=1.0,
521
+ cross_attention_dim=1280,
522
+ dual_cross_attention=False,
523
+ use_linear_projection=False,
524
+ upcast_attention=False,
525
+ num_views: int = 1,
526
+ cd_attention_last: bool = False,
527
+ cd_attention_mid: bool = False,
528
+ multiview_attention: bool = True,
529
+ sparse_mv_attention: bool = False,
530
+ selfattn_block: str = "custom",
531
+ mvcd_attention: bool=False,
532
+ use_dino: bool = False
533
+ ):
534
+ super().__init__()
535
+
536
+ self.has_cross_attention = True
537
+ self.num_attention_heads = num_attention_heads
538
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
539
+ if selfattn_block == "custom":
540
+ from .transformer_mv2d import TransformerMV2DModel
541
+ elif selfattn_block == "rowwise":
542
+ from .transformer_mv2d_rowwise import TransformerMV2DModel
543
+ elif selfattn_block == "self_rowwise":
544
+ from .transformer_mv2d_self_rowwise import TransformerMV2DModel
545
+ else:
546
+ raise NotImplementedError
547
+
548
+ # there is always at least one resnet
549
+ resnets = [
550
+ ResnetBlock2D(
551
+ in_channels=in_channels,
552
+ out_channels=in_channels,
553
+ temb_channels=temb_channels,
554
+ eps=resnet_eps,
555
+ groups=resnet_groups,
556
+ dropout=dropout,
557
+ time_embedding_norm=resnet_time_scale_shift,
558
+ non_linearity=resnet_act_fn,
559
+ output_scale_factor=output_scale_factor,
560
+ pre_norm=resnet_pre_norm,
561
+ )
562
+ ]
563
+ attentions = []
564
+
565
+ for _ in range(num_layers):
566
+ if not dual_cross_attention:
567
+ attentions.append(
568
+ TransformerMV2DModel(
569
+ num_attention_heads,
570
+ in_channels // num_attention_heads,
571
+ in_channels=in_channels,
572
+ num_layers=transformer_layers_per_block,
573
+ cross_attention_dim=cross_attention_dim,
574
+ norm_num_groups=resnet_groups,
575
+ use_linear_projection=use_linear_projection,
576
+ upcast_attention=upcast_attention,
577
+ num_views=num_views,
578
+ cd_attention_last=cd_attention_last,
579
+ cd_attention_mid=cd_attention_mid,
580
+ multiview_attention=multiview_attention,
581
+ sparse_mv_attention=sparse_mv_attention,
582
+ mvcd_attention=mvcd_attention,
583
+ use_dino=use_dino
584
+ )
585
+ )
586
+ else:
587
+ raise NotImplementedError
588
+ resnets.append(
589
+ ResnetBlock2D(
590
+ in_channels=in_channels,
591
+ out_channels=in_channels,
592
+ temb_channels=temb_channels,
593
+ eps=resnet_eps,
594
+ groups=resnet_groups,
595
+ dropout=dropout,
596
+ time_embedding_norm=resnet_time_scale_shift,
597
+ non_linearity=resnet_act_fn,
598
+ output_scale_factor=output_scale_factor,
599
+ pre_norm=resnet_pre_norm,
600
+ )
601
+ )
602
+
603
+ self.attentions = nn.ModuleList(attentions)
604
+ self.resnets = nn.ModuleList(resnets)
605
+
606
+ def forward(
607
+ self,
608
+ hidden_states: torch.FloatTensor,
609
+ temb: Optional[torch.FloatTensor] = None,
610
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
611
+ attention_mask: Optional[torch.FloatTensor] = None,
612
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
613
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
614
+ dino_feature: Optional[torch.FloatTensor] = None
615
+ ) -> torch.FloatTensor:
616
+ hw_ratio = hidden_states.size(2) / hidden_states.size(3)
617
+ hidden_states = self.resnets[0](hidden_states, temb)
618
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
619
+ hidden_states = attn(
620
+ hidden_states,
621
+ encoder_hidden_states=encoder_hidden_states,
622
+ cross_attention_kwargs=cross_attention_kwargs,
623
+ attention_mask=attention_mask,
624
+ encoder_attention_mask=encoder_attention_mask,
625
+ dino_feature=dino_feature,
626
+ return_dict=False,
627
+ hw_ratio=hw_ratio,
628
+ )[0]
629
+ hidden_states = resnet(hidden_states, temb)
630
+
631
+ return hidden_states
632
+
633
+
634
+ class CrossAttnUpBlockMV2D(nn.Module):
635
+ def __init__(
636
+ self,
637
+ in_channels: int,
638
+ out_channels: int,
639
+ prev_output_channel: int,
640
+ temb_channels: int,
641
+ dropout: float = 0.0,
642
+ num_layers: int = 1,
643
+ transformer_layers_per_block: int = 1,
644
+ resnet_eps: float = 1e-6,
645
+ resnet_time_scale_shift: str = "default",
646
+ resnet_act_fn: str = "swish",
647
+ resnet_groups: int = 32,
648
+ resnet_pre_norm: bool = True,
649
+ num_attention_heads=1,
650
+ cross_attention_dim=1280,
651
+ output_scale_factor=1.0,
652
+ add_upsample=True,
653
+ dual_cross_attention=False,
654
+ use_linear_projection=False,
655
+ only_cross_attention=False,
656
+ upcast_attention=False,
657
+ num_views: int = 1,
658
+ cd_attention_last: bool = False,
659
+ cd_attention_mid: bool = False,
660
+ multiview_attention: bool = True,
661
+ sparse_mv_attention: bool = False,
662
+ selfattn_block: str = "custom",
663
+ mvcd_attention: bool=False,
664
+ use_dino: bool = False
665
+ ):
666
+ super().__init__()
667
+ resnets = []
668
+ attentions = []
669
+
670
+ self.has_cross_attention = True
671
+ self.num_attention_heads = num_attention_heads
672
+
673
+ if selfattn_block == "custom":
674
+ from .transformer_mv2d import TransformerMV2DModel
675
+ elif selfattn_block == "rowwise":
676
+ from .transformer_mv2d_rowwise import TransformerMV2DModel
677
+ elif selfattn_block == "self_rowwise":
678
+ from .transformer_mv2d_self_rowwise import TransformerMV2DModel
679
+ else:
680
+ raise NotImplementedError
681
+
682
+ for i in range(num_layers):
683
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
684
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
685
+
686
+ resnets.append(
687
+ ResnetBlock2D(
688
+ in_channels=resnet_in_channels + res_skip_channels,
689
+ out_channels=out_channels,
690
+ temb_channels=temb_channels,
691
+ eps=resnet_eps,
692
+ groups=resnet_groups,
693
+ dropout=dropout,
694
+ time_embedding_norm=resnet_time_scale_shift,
695
+ non_linearity=resnet_act_fn,
696
+ output_scale_factor=output_scale_factor,
697
+ pre_norm=resnet_pre_norm,
698
+ )
699
+ )
700
+ if not dual_cross_attention:
701
+ attentions.append(
702
+ TransformerMV2DModel(
703
+ num_attention_heads,
704
+ out_channels // num_attention_heads,
705
+ in_channels=out_channels,
706
+ num_layers=transformer_layers_per_block,
707
+ cross_attention_dim=cross_attention_dim,
708
+ norm_num_groups=resnet_groups,
709
+ use_linear_projection=use_linear_projection,
710
+ only_cross_attention=only_cross_attention,
711
+ upcast_attention=upcast_attention,
712
+ num_views=num_views,
713
+ cd_attention_last=cd_attention_last,
714
+ cd_attention_mid=cd_attention_mid,
715
+ multiview_attention=multiview_attention,
716
+ sparse_mv_attention=sparse_mv_attention,
717
+ mvcd_attention=mvcd_attention,
718
+ use_dino=use_dino
719
+ )
720
+ )
721
+ else:
722
+ raise NotImplementedError
723
+ self.attentions = nn.ModuleList(attentions)
724
+ self.resnets = nn.ModuleList(resnets)
725
+
726
+ if add_upsample:
727
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
728
+ else:
729
+ self.upsamplers = None
730
+
731
+ self.gradient_checkpointing = False
732
+
733
+ def forward(
734
+ self,
735
+ hidden_states: torch.FloatTensor,
736
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
737
+ temb: Optional[torch.FloatTensor] = None,
738
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
739
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
740
+ upsample_size: Optional[int] = None,
741
+ attention_mask: Optional[torch.FloatTensor] = None,
742
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
743
+ dino_feature: Optional[torch.FloatTensor] = None
744
+ ):
745
+ hw_ratio = hidden_states.size(2) / hidden_states.size(3)
746
+
747
+ for resnet, attn in zip(self.resnets, self.attentions):
748
+ # pop res hidden states
749
+ res_hidden_states = res_hidden_states_tuple[-1]
750
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
751
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
752
+
753
+ if self.training and self.gradient_checkpointing:
754
+
755
+ def create_custom_forward(module, return_dict=None):
756
+ def custom_forward(*inputs):
757
+ if return_dict is not None:
758
+ return module(*inputs, return_dict=return_dict)
759
+ else:
760
+ return module(*inputs)
761
+
762
+ return custom_forward
763
+
764
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
765
+ hidden_states = torch.utils.checkpoint.checkpoint(
766
+ create_custom_forward(resnet),
767
+ hidden_states,
768
+ temb,
769
+ **ckpt_kwargs,
770
+ )
771
+ hidden_states = torch.utils.checkpoint.checkpoint(
772
+ create_custom_forward(attn, return_dict=False),
773
+ hidden_states,
774
+ encoder_hidden_states,
775
+ dino_feature,
776
+ None, # timestep
777
+ None, # class_labels
778
+ cross_attention_kwargs,
779
+ attention_mask,
780
+ encoder_attention_mask,
781
+ hw_ratio,
782
+ **ckpt_kwargs,
783
+ )[0]
784
+ else:
785
+ hidden_states = resnet(hidden_states, temb)
786
+ hidden_states = attn(
787
+ hidden_states,
788
+ encoder_hidden_states=encoder_hidden_states,
789
+ cross_attention_kwargs=cross_attention_kwargs,
790
+ attention_mask=attention_mask,
791
+ encoder_attention_mask=encoder_attention_mask,
792
+ dino_feature=dino_feature,
793
+ hw_ratio=hw_ratio,
794
+ return_dict=False,
795
+ )[0]
796
+
797
+ if self.upsamplers is not None:
798
+ for upsampler in self.upsamplers:
799
+ hidden_states = upsampler(hidden_states, upsample_size)
800
+
801
+ return hidden_states
802
+
803
+
804
+ class CrossAttnDownBlockMV2D(nn.Module):
805
+ def __init__(
806
+ self,
807
+ in_channels: int,
808
+ out_channels: int,
809
+ temb_channels: int,
810
+ dropout: float = 0.0,
811
+ num_layers: int = 1,
812
+ transformer_layers_per_block: int = 1,
813
+ resnet_eps: float = 1e-6,
814
+ resnet_time_scale_shift: str = "default",
815
+ resnet_act_fn: str = "swish",
816
+ resnet_groups: int = 32,
817
+ resnet_pre_norm: bool = True,
818
+ num_attention_heads=1,
819
+ cross_attention_dim=1280,
820
+ output_scale_factor=1.0,
821
+ downsample_padding=1,
822
+ add_downsample=True,
823
+ dual_cross_attention=False,
824
+ use_linear_projection=False,
825
+ only_cross_attention=False,
826
+ upcast_attention=False,
827
+ num_views: int = 1,
828
+ cd_attention_last: bool = False,
829
+ cd_attention_mid: bool = False,
830
+ multiview_attention: bool = True,
831
+ sparse_mv_attention: bool = False,
832
+ selfattn_block: str = "custom",
833
+ mvcd_attention: bool=False,
834
+ use_dino: bool = False
835
+ ):
836
+ super().__init__()
837
+ resnets = []
838
+ attentions = []
839
+
840
+ self.has_cross_attention = True
841
+ self.num_attention_heads = num_attention_heads
842
+ if selfattn_block == "custom":
843
+ from .transformer_mv2d import TransformerMV2DModel
844
+ elif selfattn_block == "rowwise":
845
+ from .transformer_mv2d_rowwise import TransformerMV2DModel
846
+ elif selfattn_block == "self_rowwise":
847
+ from .transformer_mv2d_self_rowwise import TransformerMV2DModel
848
+ else:
849
+ raise NotImplementedError
850
+
851
+ for i in range(num_layers):
852
+ in_channels = in_channels if i == 0 else out_channels
853
+ resnets.append(
854
+ ResnetBlock2D(
855
+ in_channels=in_channels,
856
+ out_channels=out_channels,
857
+ temb_channels=temb_channels,
858
+ eps=resnet_eps,
859
+ groups=resnet_groups,
860
+ dropout=dropout,
861
+ time_embedding_norm=resnet_time_scale_shift,
862
+ non_linearity=resnet_act_fn,
863
+ output_scale_factor=output_scale_factor,
864
+ pre_norm=resnet_pre_norm,
865
+ )
866
+ )
867
+ if not dual_cross_attention:
868
+ attentions.append(
869
+ TransformerMV2DModel(
870
+ num_attention_heads,
871
+ out_channels // num_attention_heads,
872
+ in_channels=out_channels,
873
+ num_layers=transformer_layers_per_block,
874
+ cross_attention_dim=cross_attention_dim,
875
+ norm_num_groups=resnet_groups,
876
+ use_linear_projection=use_linear_projection,
877
+ only_cross_attention=only_cross_attention,
878
+ upcast_attention=upcast_attention,
879
+ num_views=num_views,
880
+ cd_attention_last=cd_attention_last,
881
+ cd_attention_mid=cd_attention_mid,
882
+ multiview_attention=multiview_attention,
883
+ sparse_mv_attention=sparse_mv_attention,
884
+ mvcd_attention=mvcd_attention,
885
+ use_dino=use_dino
886
+ )
887
+ )
888
+ else:
889
+ raise NotImplementedError
890
+ self.attentions = nn.ModuleList(attentions)
891
+ self.resnets = nn.ModuleList(resnets)
892
+
893
+ if add_downsample:
894
+ self.downsamplers = nn.ModuleList(
895
+ [
896
+ Downsample2D(
897
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
898
+ )
899
+ ]
900
+ )
901
+ else:
902
+ self.downsamplers = None
903
+
904
+ self.gradient_checkpointing = False
905
+
906
+ def forward(
907
+ self,
908
+ hidden_states: torch.FloatTensor,
909
+ temb: Optional[torch.FloatTensor] = None,
910
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
911
+ dino_feature: Optional[torch.FloatTensor] = None,
912
+ attention_mask: Optional[torch.FloatTensor] = None,
913
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
914
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
915
+ additional_residuals=None,
916
+ ):
917
+ output_states = ()
918
+
919
+ hw_ratio = hidden_states.size(2) / hidden_states.size(3)
920
+ blocks = list(zip(self.resnets, self.attentions))
921
+
922
+ for i, (resnet, attn) in enumerate(blocks):
923
+ if self.training and self.gradient_checkpointing:
924
+
925
+ def create_custom_forward(module, return_dict=None):
926
+ def custom_forward(*inputs):
927
+ if return_dict is not None:
928
+ return module(*inputs, return_dict=return_dict)
929
+ else:
930
+ return module(*inputs)
931
+
932
+ return custom_forward
933
+
934
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
935
+ hidden_states = torch.utils.checkpoint.checkpoint(
936
+ create_custom_forward(resnet),
937
+ hidden_states,
938
+ temb,
939
+ **ckpt_kwargs,
940
+ )
941
+ hidden_states = torch.utils.checkpoint.checkpoint(
942
+ create_custom_forward(attn, return_dict=False),
943
+ hidden_states,
944
+ encoder_hidden_states,
945
+ dino_feature,
946
+ None, # timestep
947
+ None, # class_labels
948
+ cross_attention_kwargs,
949
+ attention_mask,
950
+ encoder_attention_mask,
951
+ hw_ratio,
952
+ **ckpt_kwargs,
953
+ )[0]
954
+ else:
955
+ hidden_states = resnet(hidden_states, temb)
956
+ hidden_states = attn(
957
+ hidden_states,
958
+ encoder_hidden_states=encoder_hidden_states,
959
+ dino_feature=dino_feature,
960
+ cross_attention_kwargs=cross_attention_kwargs,
961
+ attention_mask=attention_mask,
962
+ encoder_attention_mask=encoder_attention_mask,
963
+ hw_ratio=hw_ratio,
964
+ return_dict=False,
965
+ )[0]
966
+
967
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
968
+ if i == len(blocks) - 1 and additional_residuals is not None:
969
+ hidden_states = hidden_states + additional_residuals
970
+
971
+ output_states = output_states + (hidden_states,)
972
+
973
+ if self.downsamplers is not None:
974
+ for downsampler in self.downsamplers:
975
+ hidden_states = downsampler(hidden_states)
976
+
977
+ output_states = output_states + (hidden_states,)
978
+
979
+ return hidden_states, output_states
980
+
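The factory functions and multiview blocks above mirror the stock diffusers 2D UNet blocks, with the cross-view behaviour supplied by the TransformerMV2DModel variant selected through `selfattn_block`. Below is a minimal sketch of driving `get_up_block` on its own; the channel sizes, view count, and tensor shapes are illustrative assumptions rather than values taken from this commit, and the import assumes the repository root is on the Python path.

import torch
from multiview.models.unet_mv2d_blocks import get_up_block

# Build a multiview cross-attention up block; "self_rowwise" selects the
# row-wise TransformerMV2DModel variant added in this commit.
up_block = get_up_block(
    "CrossAttnUpBlockMV2D",
    num_layers=2,
    in_channels=640,
    out_channels=320,
    prev_output_channel=320,
    temb_channels=1280,
    add_upsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    resnet_groups=32,
    cross_attention_dim=1024,
    num_attention_heads=8,
    num_views=6,
    selfattn_block="self_rowwise",
)

# Six views stacked along the batch axis, plus the two skip activations the
# resnets expect (the 320-channel skip is consumed first, then the 640-channel one).
hidden = torch.randn(6, 320, 32, 32)
skips = (torch.randn(6, 640, 32, 32), torch.randn(6, 320, 32, 32))
temb = torch.randn(6, 1280)
text = torch.randn(6, 77, 1024)

out = up_block(hidden, skips, temb=temb, encoder_hidden_states=text)
# add_upsample=True doubles the spatial resolution: out has shape (6, 320, 64, 64)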
multiview/models/unet_mv2d_condition.py ADDED
@@ -0,0 +1,1685 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+ import os
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.loaders import UNet2DConditionLoadersMixin
24
+ from diffusers.utils import BaseOutput, logging
25
+ from diffusers.models.activations import get_activation
26
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
27
+ from diffusers.models.embeddings import (
28
+ GaussianFourierProjection,
29
+ ImageHintTimeEmbedding,
30
+ ImageProjection,
31
+ ImageTimeEmbedding,
32
+ TextImageProjection,
33
+ TextImageTimeEmbedding,
34
+ TextTimeEmbedding,
35
+ TimestepEmbedding,
36
+ Timesteps,
37
+ )
38
+ from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model
39
+ from diffusers.models.unet_2d_blocks import (
40
+ CrossAttnDownBlock2D,
41
+ CrossAttnUpBlock2D,
42
+ DownBlock2D,
43
+ UNetMidBlock2DCrossAttn,
44
+ UNetMidBlock2DSimpleCrossAttn,
45
+ UpBlock2D,
46
+ )
47
+ from diffusers.utils import (
48
+ CONFIG_NAME,
49
+ FLAX_WEIGHTS_NAME,
50
+ SAFETENSORS_WEIGHTS_NAME,
51
+ WEIGHTS_NAME,
52
+ _add_variant,
53
+ _get_model_file,
54
+ deprecate,
55
+ is_torch_version,
56
+ logging,
57
+ )
58
+ from diffusers.utils.import_utils import is_accelerate_available
59
+ from diffusers.utils.hub_utils import HF_HUB_OFFLINE
60
+ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
61
+ DIFFUSERS_CACHE = HUGGINGFACE_HUB_CACHE
62
+
63
+ from diffusers import __version__
64
+ from .unet_mv2d_blocks import (
65
+ CrossAttnDownBlockMV2D,
66
+ CrossAttnUpBlockMV2D,
67
+ UNetMidBlockMV2DCrossAttn,
68
+ get_down_block,
69
+ get_up_block,
70
+ )
71
+ from einops import rearrange, repeat
72
+
81
+
82
+
83
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
84
+
85
+
86
+ @dataclass
87
+ class UNetMV2DConditionOutput(BaseOutput):
88
+ """
89
+ The output of [`UNet2DConditionModel`].
90
+
91
+ Args:
92
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
93
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
94
+ """
95
+
96
+ sample: torch.FloatTensor = None
97
+
98
+
99
+ class ResidualBlock(nn.Module):
100
+ def __init__(self, dim):
101
+ super(ResidualBlock, self).__init__()
102
+ self.linear1 = nn.Linear(dim, dim)
103
+ self.activation = nn.SiLU()
104
+ self.linear2 = nn.Linear(dim, dim)
105
+
106
+ def forward(self, x):
107
+ identity = x
108
+ out = self.linear1(x)
109
+ out = self.activation(out)
110
+ out = self.linear2(out)
111
+ out += identity
112
+ out = self.activation(out)
113
+ return out
114
+
115
+ class ResidualLiner(nn.Module):
116
+ def __init__(self, in_features, out_features, dim, act=None, num_block=1):
117
+ super(ResidualLiner, self).__init__()
118
+ self.linear_in = nn.Sequential(nn.Linear(in_features, dim), nn.SiLU())
119
+
120
+ blocks = nn.ModuleList()
121
+ for _ in range(num_block):
122
+ blocks.append(ResidualBlock(dim))
123
+ self.blocks = blocks
124
+
125
+ self.linear_out = nn.Linear(dim, out_features)
126
+ self.act = act
127
+
128
+ def forward(self, x):
129
+ out = self.linear_in(x)
130
+ for block in self.blocks:
131
+ out = block(out)
132
+ out = self.linear_out(out)
133
+ if self.act is not None:
134
+ out = self.act(out)
135
+ return out
136
+
137
+ class BasicConvBlock(nn.Module):
138
+ def __init__(self, in_channels, out_channels, stride=1):
139
+ super(BasicConvBlock, self).__init__()
140
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
141
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=out_channels, affine=True)
142
+ self.act = nn.SiLU()
143
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
144
+ self.norm2 = nn.GroupNorm(num_groups=8, num_channels=out_channels, affine=True)
145
+ self.downsample = nn.Sequential()
146
+ if stride != 1 or in_channels != out_channels:
147
+ self.downsample = nn.Sequential(
148
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
149
+ nn.GroupNorm(num_groups=8, num_channels=out_channels, affine=True)
150
+ )
151
+
152
+ def forward(self, x):
153
+ identity = x
154
+ out = self.conv1(x)
155
+ out = self.norm1(out)
156
+ out = self.act(out)
157
+ out = self.conv2(out)
158
+ out = self.norm2(out)
159
+ out += self.downsample(identity)
160
+ out = self.act(out)
161
+ return out
162
+
163
+ class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
164
+ r"""
165
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
166
+ shaped output.
167
+
168
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
169
+ for all models (such as downloading or saving).
170
+
171
+ Parameters:
172
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
173
+ Height and width of input/output sample.
174
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
175
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
176
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
177
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
178
+ Whether to flip the sin to cos in the time embedding.
179
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
180
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
181
+ The tuple of downsample blocks to use.
182
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
183
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
184
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
185
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
186
+ The tuple of upsample blocks to use.
187
+ only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
188
+ Whether to include self-attention in the basic transformer blocks, see
189
+ [`~models.attention.BasicTransformerBlock`].
190
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
191
+ The tuple of output channels for each block.
192
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
193
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
194
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
195
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
196
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
197
+ If `None`, normalization and activation layers are skipped in post-processing.
198
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
199
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
200
+ The dimension of the cross attention features.
201
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
202
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
203
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
204
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
205
+ encoder_hid_dim (`int`, *optional*, defaults to None):
206
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
207
+ dimension to `cross_attention_dim`.
208
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
209
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
210
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
211
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
212
+ num_attention_heads (`int`, *optional*):
213
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
214
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
215
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
216
+ class_embed_type (`str`, *optional*, defaults to `None`):
217
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
218
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
219
+ addition_embed_type (`str`, *optional*, defaults to `None`):
220
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
221
+ "text". "text" will use the `TextTimeEmbedding` layer.
222
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
223
+ Dimension for the timestep embeddings.
224
+ num_class_embeds (`int`, *optional*, defaults to `None`):
225
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
226
+ class conditioning with `class_embed_type` equal to `None`.
227
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
228
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
229
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
230
+ An optional override for the dimension of the projected time embedding.
231
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
232
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
233
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
234
+ timestep_post_act (`str`, *optional*, defaults to `None`):
235
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
236
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
237
+ The dimension of `cond_proj` layer in the timestep embedding.
238
+ conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_in` layer.
240
+ conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_out` layer.
240
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
241
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
242
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
243
+ embeddings with the class embeddings.
244
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
245
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
246
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
247
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
248
+ otherwise.
249
+ """
250
+
251
+ _supports_gradient_checkpointing = True
252
+
253
+ @register_to_config
254
+ def __init__(
255
+ self,
256
+ sample_size: Optional[int] = None,
257
+ in_channels: int = 4,
258
+ out_channels: int = 4,
259
+ center_input_sample: bool = False,
260
+ flip_sin_to_cos: bool = True,
261
+ freq_shift: int = 0,
262
+ down_block_types: Tuple[str] = (
263
+ "CrossAttnDownBlockMV2D",
264
+ "CrossAttnDownBlockMV2D",
265
+ "CrossAttnDownBlockMV2D",
266
+ "DownBlock2D",
267
+ ),
268
+ mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn",
269
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"),
270
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
271
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
272
+ layers_per_block: Union[int, Tuple[int]] = 2,
273
+ downsample_padding: int = 1,
274
+ mid_block_scale_factor: float = 1,
275
+ act_fn: str = "silu",
276
+ norm_num_groups: Optional[int] = 32,
277
+ norm_eps: float = 1e-5,
278
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
279
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
280
+ encoder_hid_dim: Optional[int] = None,
281
+ encoder_hid_dim_type: Optional[str] = None,
282
+ attention_head_dim: Union[int, Tuple[int]] = 8,
283
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
284
+ dual_cross_attention: bool = False,
285
+ use_linear_projection: bool = False,
286
+ class_embed_type: Optional[str] = None,
287
+ addition_embed_type: Optional[str] = None,
288
+ addition_time_embed_dim: Optional[int] = None,
289
+ num_class_embeds: Optional[int] = None,
290
+ upcast_attention: bool = False,
291
+ resnet_time_scale_shift: str = "default",
292
+ resnet_skip_time_act: bool = False,
293
+ resnet_out_scale_factor: int = 1.0,
294
+ time_embedding_type: str = "positional",
295
+ time_embedding_dim: Optional[int] = None,
296
+ time_embedding_act_fn: Optional[str] = None,
297
+ timestep_post_act: Optional[str] = None,
298
+ time_cond_proj_dim: Optional[int] = None,
299
+ conv_in_kernel: int = 3,
300
+ conv_out_kernel: int = 3,
301
+ projection_class_embeddings_input_dim: Optional[int] = None,
302
+ projection_camera_embeddings_input_dim: Optional[int] = None,
303
+ class_embeddings_concat: bool = False,
304
+ mid_block_only_cross_attention: Optional[bool] = None,
305
+ cross_attention_norm: Optional[str] = None,
306
+ addition_embed_type_num_heads=64,
307
+ num_views: int = 1,
308
+ cd_attention_last: bool = False,
309
+ cd_attention_mid: bool = False,
310
+ multiview_attention: bool = True,
311
+ sparse_mv_attention: bool = False,
312
+ selfattn_block: str = "custom",
313
+ mvcd_attention: bool = False,
314
+ regress_elevation: bool = False,
315
+ regress_focal_length: bool = False,
316
+ num_regress_blocks: int = 4,
317
+ use_dino: bool = False,
318
+ addition_downsample: bool = False,
319
+ addition_channels: Optional[Tuple[int]] = (1280, 1280, 1280),
320
+ ):
321
+ super().__init__()
322
+
323
+ self.sample_size = sample_size
324
+ self.num_views = num_views
325
+ self.mvcd_attention = mvcd_attention
326
+ if num_attention_heads is not None:
327
+ raise ValueError(
328
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
329
+ )
330
+
331
+ # If `num_attention_heads` is not defined (which is the case for most models)
332
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
333
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
334
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
335
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
336
+ # which is why we correct for the naming here.
337
+ num_attention_heads = num_attention_heads or attention_head_dim
338
+
339
+ # Check inputs
340
+ if len(down_block_types) != len(up_block_types):
341
+ raise ValueError(
342
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
343
+ )
344
+
345
+ if len(block_out_channels) != len(down_block_types):
346
+ raise ValueError(
347
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
348
+ )
349
+
350
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
351
+ raise ValueError(
352
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
353
+ )
354
+
355
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
356
+ raise ValueError(
357
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
358
+ )
359
+
360
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
361
+ raise ValueError(
362
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
363
+ )
364
+
365
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
366
+ raise ValueError(
367
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
368
+ )
369
+
370
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
371
+ raise ValueError(
372
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
373
+ )
374
+
375
+ # input
376
+ conv_in_padding = (conv_in_kernel - 1) // 2
377
+ self.conv_in = nn.Conv2d(
378
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
379
+ )
380
+
381
+ # time
382
+ if time_embedding_type == "fourier":
383
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
384
+ if time_embed_dim % 2 != 0:
385
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
386
+ self.time_proj = GaussianFourierProjection(
387
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
388
+ )
389
+ timestep_input_dim = time_embed_dim
390
+ elif time_embedding_type == "positional":
391
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
392
+
393
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
394
+ timestep_input_dim = block_out_channels[0]
395
+ else:
396
+ raise ValueError(
397
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
398
+ )
399
+
400
+ self.time_embedding = TimestepEmbedding(
401
+ timestep_input_dim,
402
+ time_embed_dim,
403
+ act_fn=act_fn,
404
+ post_act_fn=timestep_post_act,
405
+ cond_proj_dim=time_cond_proj_dim,
406
+ )
407
+
408
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
409
+ encoder_hid_dim_type = "text_proj"
410
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
411
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
412
+
413
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
414
+ raise ValueError(
415
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
416
+ )
417
+
418
+ if encoder_hid_dim_type == "text_proj":
419
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
420
+ elif encoder_hid_dim_type == "text_image_proj":
421
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
422
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
423
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
424
+ self.encoder_hid_proj = TextImageProjection(
425
+ text_embed_dim=encoder_hid_dim,
426
+ image_embed_dim=cross_attention_dim,
427
+ cross_attention_dim=cross_attention_dim,
428
+ )
429
+ elif encoder_hid_dim_type == "image_proj":
430
+ # Kandinsky 2.2
431
+ self.encoder_hid_proj = ImageProjection(
432
+ image_embed_dim=encoder_hid_dim,
433
+ cross_attention_dim=cross_attention_dim,
434
+ )
435
+ elif encoder_hid_dim_type is not None:
436
+ raise ValueError(
437
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
438
+ )
439
+ else:
440
+ self.encoder_hid_proj = None
441
+
442
+ # class embedding
443
+ if class_embed_type is None and num_class_embeds is not None:
444
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
445
+ elif class_embed_type == "timestep":
446
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
447
+ elif class_embed_type == "identity":
448
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
449
+ elif class_embed_type == "projection":
450
+ if projection_class_embeddings_input_dim is None:
451
+ raise ValueError(
452
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
453
+ )
454
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
455
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
456
+ # 2. it projects from an arbitrary input dimension.
457
+ #
458
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
459
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
460
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
461
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
462
+ elif class_embed_type == "simple_projection":
463
+ if projection_class_embeddings_input_dim is None:
464
+ raise ValueError(
465
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
466
+ )
467
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
468
+ else:
469
+ self.class_embedding = None
470
+
471
+ if addition_embed_type == "text":
472
+ if encoder_hid_dim is not None:
473
+ text_time_embedding_from_dim = encoder_hid_dim
474
+ else:
475
+ text_time_embedding_from_dim = cross_attention_dim
476
+
477
+ self.add_embedding = TextTimeEmbedding(
478
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
479
+ )
480
+ elif addition_embed_type == "text_image":
481
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
482
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
483
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
484
+ self.add_embedding = TextImageTimeEmbedding(
485
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
486
+ )
487
+ elif addition_embed_type == "text_time":
488
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
489
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
490
+ elif addition_embed_type == "image":
491
+ # Kandinsky 2.2
492
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
493
+ elif addition_embed_type == "image_hint":
494
+ # Kandinsky 2.2 ControlNet
495
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
496
+ elif addition_embed_type is not None:
497
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
498
+
499
+ if time_embedding_act_fn is None:
500
+ self.time_embed_act = None
501
+ else:
502
+ self.time_embed_act = get_activation(time_embedding_act_fn)
503
+
504
+ self.down_blocks = nn.ModuleList([])
505
+ self.up_blocks = nn.ModuleList([])
506
+
507
+ if isinstance(only_cross_attention, bool):
508
+ if mid_block_only_cross_attention is None:
509
+ mid_block_only_cross_attention = only_cross_attention
510
+
511
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
512
+
513
+ if mid_block_only_cross_attention is None:
514
+ mid_block_only_cross_attention = False
515
+
516
+ if isinstance(num_attention_heads, int):
517
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
518
+
519
+ if isinstance(attention_head_dim, int):
520
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
521
+
522
+ if isinstance(cross_attention_dim, int):
523
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
524
+
525
+ if isinstance(layers_per_block, int):
526
+ layers_per_block = [layers_per_block] * len(down_block_types)
527
+
528
+ if isinstance(transformer_layers_per_block, int):
529
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
530
+
531
+ if class_embeddings_concat:
532
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
533
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
534
+ # regular time embeddings
535
+ blocks_time_embed_dim = time_embed_dim * 2
536
+ else:
537
+ blocks_time_embed_dim = time_embed_dim
538
+
539
+ # down
540
+ output_channel = block_out_channels[0]
541
+ for i, down_block_type in enumerate(down_block_types):
542
+ input_channel = output_channel
543
+ output_channel = block_out_channels[i]
544
+ is_final_block = i == len(block_out_channels) - 1
545
+
546
+ down_block = get_down_block(
547
+ down_block_type,
548
+ num_layers=layers_per_block[i],
549
+ transformer_layers_per_block=transformer_layers_per_block[i],
550
+ in_channels=input_channel,
551
+ out_channels=output_channel,
552
+ temb_channels=blocks_time_embed_dim,
553
+ add_downsample=not is_final_block,
554
+ resnet_eps=norm_eps,
555
+ resnet_act_fn=act_fn,
556
+ resnet_groups=norm_num_groups,
557
+ cross_attention_dim=cross_attention_dim[i],
558
+ num_attention_heads=num_attention_heads[i],
559
+ downsample_padding=downsample_padding,
560
+ dual_cross_attention=dual_cross_attention,
561
+ use_linear_projection=use_linear_projection,
562
+ only_cross_attention=only_cross_attention[i],
563
+ upcast_attention=upcast_attention,
564
+ resnet_time_scale_shift=resnet_time_scale_shift,
565
+ resnet_skip_time_act=resnet_skip_time_act,
566
+ resnet_out_scale_factor=resnet_out_scale_factor,
567
+ cross_attention_norm=cross_attention_norm,
568
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
569
+ num_views=num_views,
570
+ cd_attention_last=cd_attention_last,
571
+ cd_attention_mid=cd_attention_mid,
572
+ multiview_attention=multiview_attention,
573
+ sparse_mv_attention=sparse_mv_attention,
574
+ selfattn_block=selfattn_block,
575
+ mvcd_attention=mvcd_attention,
576
+ use_dino=use_dino
577
+ )
578
+ self.down_blocks.append(down_block)
579
+
580
+ # mid
581
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
582
+ self.mid_block = UNetMidBlock2DCrossAttn(
583
+ transformer_layers_per_block=transformer_layers_per_block[-1],
584
+ in_channels=block_out_channels[-1],
585
+ temb_channels=blocks_time_embed_dim,
586
+ resnet_eps=norm_eps,
587
+ resnet_act_fn=act_fn,
588
+ output_scale_factor=mid_block_scale_factor,
589
+ resnet_time_scale_shift=resnet_time_scale_shift,
590
+ cross_attention_dim=cross_attention_dim[-1],
591
+ num_attention_heads=num_attention_heads[-1],
592
+ resnet_groups=norm_num_groups,
593
+ dual_cross_attention=dual_cross_attention,
594
+ use_linear_projection=use_linear_projection,
595
+ upcast_attention=upcast_attention,
596
+ )
597
+ # custom MV2D attention block
598
+ elif mid_block_type == "UNetMidBlockMV2DCrossAttn":
599
+ self.mid_block = UNetMidBlockMV2DCrossAttn(
600
+ transformer_layers_per_block=transformer_layers_per_block[-1],
601
+ in_channels=block_out_channels[-1],
602
+ temb_channels=blocks_time_embed_dim,
603
+ resnet_eps=norm_eps,
604
+ resnet_act_fn=act_fn,
605
+ output_scale_factor=mid_block_scale_factor,
606
+ resnet_time_scale_shift=resnet_time_scale_shift,
607
+ cross_attention_dim=cross_attention_dim[-1],
608
+ num_attention_heads=num_attention_heads[-1],
609
+ resnet_groups=norm_num_groups,
610
+ dual_cross_attention=dual_cross_attention,
611
+ use_linear_projection=use_linear_projection,
612
+ upcast_attention=upcast_attention,
613
+ num_views=num_views,
614
+ cd_attention_last=cd_attention_last,
615
+ cd_attention_mid=cd_attention_mid,
616
+ multiview_attention=multiview_attention,
617
+ sparse_mv_attention=sparse_mv_attention,
618
+ selfattn_block=selfattn_block,
619
+ mvcd_attention=mvcd_attention,
620
+ use_dino=use_dino
621
+ )
622
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
623
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
624
+ in_channels=block_out_channels[-1],
625
+ temb_channels=blocks_time_embed_dim,
626
+ resnet_eps=norm_eps,
627
+ resnet_act_fn=act_fn,
628
+ output_scale_factor=mid_block_scale_factor,
629
+ cross_attention_dim=cross_attention_dim[-1],
630
+ attention_head_dim=attention_head_dim[-1],
631
+ resnet_groups=norm_num_groups,
632
+ resnet_time_scale_shift=resnet_time_scale_shift,
633
+ skip_time_act=resnet_skip_time_act,
634
+ only_cross_attention=mid_block_only_cross_attention,
635
+ cross_attention_norm=cross_attention_norm,
636
+ )
637
+ elif mid_block_type is None:
638
+ self.mid_block = None
639
+ else:
640
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
641
+
642
+ self.addition_downsample = addition_downsample
643
+ if self.addition_downsample:
644
+ inc = block_out_channels[-1]
645
+ self.downsample = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
646
+ self.conv_block = nn.ModuleList()
647
+ self.conv_block.append(BasicConvBlock(inc, addition_channels[0], stride=1))
648
+ for dim_ in addition_channels[1:-1]:
649
+ self.conv_block.append(BasicConvBlock(dim_, dim_, stride=1))
650
+ self.conv_block.append(BasicConvBlock(dim_, inc))
651
+ self.addition_conv_out = nn.Conv2d(inc, inc, kernel_size=1, bias=False)
652
+ nn.init.zeros_(self.addition_conv_out.weight.data)
653
+ self.addition_act_out = nn.SiLU()
654
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
655
+
656
+ self.regress_elevation = regress_elevation
657
+ self.regress_focal_length = regress_focal_length
658
+ if regress_elevation or regress_focal_length:
659
+ self.pool = nn.AdaptiveAvgPool2d((1, 1))
660
+ self.camera_embedding = TimestepEmbedding(projection_camera_embeddings_input_dim, time_embed_dim=time_embed_dim)
661
+
662
+ regress_in_dim = block_out_channels[-1]*2 if mvcd_attention else block_out_channels[-1]
663
+
664
+ if regress_elevation:
665
+ self.elevation_regressor = ResidualLiner(regress_in_dim, 1, 1280, act=None, num_block=num_regress_blocks)
666
+ if regress_focal_length:
667
+ self.focal_regressor = ResidualLiner(regress_in_dim, 1, 1280, act=None, num_block=num_regress_blocks)
668
+ '''
669
+ self.regress_elevation = regress_elevation
670
+ self.regress_focal_length = regress_focal_length
671
+ if regress_elevation and (not regress_focal_length):
672
+ print("Regressing elevation")
673
+ cam_dim = 1
674
+ elif regress_focal_length and (not regress_elevation):
675
+ print("Regressing focal length")
676
+ cam_dim = 6
677
+ elif regress_elevation and regress_focal_length:
678
+ print("Regressing both elevation and focal length")
679
+ cam_dim = 7
680
+ else:
681
+ cam_dim = 0
682
+ assert projection_camera_embeddings_input_dim == 2*cam_dim, "projection_camera_embeddings_input_dim should be 2*cam_dim"
683
+ if regress_elevation or regress_focal_length:
684
+ self.elevation_regressor = nn.ModuleList([
685
+ nn.Linear(block_out_channels[-1], 1280),
686
+ nn.SiLU(),
687
+ nn.Linear(1280, 1280),
688
+ nn.SiLU(),
689
+ nn.Linear(1280, cam_dim)
690
+ ])
691
+ self.pool = nn.AdaptiveAvgPool2d((1, 1))
692
+ self.focal_act = nn.Softmax(dim=-1)
693
+ self.camera_embedding = TimestepEmbedding(projection_camera_embeddings_input_dim, time_embed_dim=time_embed_dim)
694
+ '''
695
+
696
+ # count how many layers upsample the images
697
+ self.num_upsamplers = 0
698
+
699
+ # up
700
+ reversed_block_out_channels = list(reversed(block_out_channels))
701
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
702
+ reversed_layers_per_block = list(reversed(layers_per_block))
703
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
704
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
705
+ only_cross_attention = list(reversed(only_cross_attention))
706
+
707
+ output_channel = reversed_block_out_channels[0]
708
+ for i, up_block_type in enumerate(up_block_types):
709
+ is_final_block = i == len(block_out_channels) - 1
710
+
711
+ prev_output_channel = output_channel
712
+ output_channel = reversed_block_out_channels[i]
713
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
714
+
715
+ # add upsample block for all BUT final layer
716
+ if not is_final_block:
717
+ add_upsample = True
718
+ self.num_upsamplers += 1
719
+ else:
720
+ add_upsample = False
721
+
722
+ up_block = get_up_block(
723
+ up_block_type,
724
+ num_layers=reversed_layers_per_block[i] + 1,
725
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
726
+ in_channels=input_channel,
727
+ out_channels=output_channel,
728
+ prev_output_channel=prev_output_channel,
729
+ temb_channels=blocks_time_embed_dim,
730
+ add_upsample=add_upsample,
731
+ resnet_eps=norm_eps,
732
+ resnet_act_fn=act_fn,
733
+ resnet_groups=norm_num_groups,
734
+ cross_attention_dim=reversed_cross_attention_dim[i],
735
+ num_attention_heads=reversed_num_attention_heads[i],
736
+ dual_cross_attention=dual_cross_attention,
737
+ use_linear_projection=use_linear_projection,
738
+ only_cross_attention=only_cross_attention[i],
739
+ upcast_attention=upcast_attention,
740
+ resnet_time_scale_shift=resnet_time_scale_shift,
741
+ resnet_skip_time_act=resnet_skip_time_act,
742
+ resnet_out_scale_factor=resnet_out_scale_factor,
743
+ cross_attention_norm=cross_attention_norm,
744
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
745
+ num_views=num_views,
746
+ cd_attention_last=cd_attention_last,
747
+ cd_attention_mid=cd_attention_mid,
748
+ multiview_attention=multiview_attention,
749
+ sparse_mv_attention=sparse_mv_attention,
750
+ selfattn_block=selfattn_block,
751
+ mvcd_attention=mvcd_attention,
752
+ use_dino=use_dino
753
+ )
754
+ self.up_blocks.append(up_block)
755
+ prev_output_channel = output_channel
756
+
757
+ # out
758
+ if norm_num_groups is not None:
759
+ self.conv_norm_out = nn.GroupNorm(
760
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
761
+ )
762
+
763
+ self.conv_act = get_activation(act_fn)
764
+
765
+ else:
766
+ self.conv_norm_out = None
767
+ self.conv_act = None
768
+
769
+ conv_out_padding = (conv_out_kernel - 1) // 2
770
+ self.conv_out = nn.Conv2d(
771
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
772
+ )
773
+
774
+ @property
775
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
776
+ r"""
777
+ Returns:
778
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
779
+ indexed by its weight name.
780
+ """
781
+ # set recursively
782
+ processors = {}
783
+
784
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
785
+ if hasattr(module, "set_processor"):
786
+ processors[f"{name}.processor"] = module.processor
787
+
788
+ for sub_name, child in module.named_children():
789
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
790
+
791
+ return processors
792
+
793
+ for name, module in self.named_children():
794
+ fn_recursive_add_processors(name, module, processors)
795
+
796
+ return processors
797
+
798
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
799
+ r"""
800
+ Sets the attention processor to use to compute attention.
801
+
802
+ Parameters:
803
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
804
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
805
+ for **all** `Attention` layers.
806
+
807
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
808
+ processor. This is strongly recommended when setting trainable attention processors.
809
+
810
+ """
811
+ count = len(self.attn_processors.keys())
812
+
813
+ if isinstance(processor, dict) and len(processor) != count:
814
+ raise ValueError(
815
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
816
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
817
+ )
818
+
819
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
820
+ if hasattr(module, "set_processor"):
821
+ if not isinstance(processor, dict):
822
+ module.set_processor(processor)
823
+ else:
824
+ module.set_processor(processor.pop(f"{name}.processor"))
825
+
826
+ for sub_name, child in module.named_children():
827
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
828
+
829
+ for name, module in self.named_children():
830
+ fn_recursive_attn_processor(name, module, processor)
831
+
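A hedged usage sketch for the two methods above; `unet` is assumed to be an already-instantiated instance of this model:

```py
from diffusers.models.attention_processor import AttnProcessor

procs = unet.attn_processors                          # {"down_blocks.0....processor": <processor>, ...}
unet.set_attn_processor({name: AttnProcessor() for name in procs})  # per-layer dict (keys must match)
unet.set_attn_processor(AttnProcessor())              # or a single processor reused for every layer
```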
832
+ def set_default_attn_processor(self):
833
+ """
834
+ Disables custom attention processors and sets the default attention implementation.
835
+ """
836
+ self.set_attn_processor(AttnProcessor())
837
+
838
+ def set_attention_slice(self, slice_size):
839
+ r"""
840
+ Enable sliced attention computation.
841
+
842
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
843
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
844
+
845
+ Args:
846
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
847
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
848
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
849
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
850
+ must be a multiple of `slice_size`.
851
+ """
852
+ sliceable_head_dims = []
853
+
854
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
855
+ if hasattr(module, "set_attention_slice"):
856
+ sliceable_head_dims.append(module.sliceable_head_dim)
857
+
858
+ for child in module.children():
859
+ fn_recursive_retrieve_sliceable_dims(child)
860
+
861
+ # retrieve number of attention layers
862
+ for module in self.children():
863
+ fn_recursive_retrieve_sliceable_dims(module)
864
+
865
+ num_sliceable_layers = len(sliceable_head_dims)
866
+
867
+ if slice_size == "auto":
868
+ # half the attention head size is usually a good trade-off between
869
+ # speed and memory
870
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
871
+ elif slice_size == "max":
872
+ # make smallest slice possible
873
+ slice_size = num_sliceable_layers * [1]
874
+
875
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
876
+
877
+ if len(slice_size) != len(sliceable_head_dims):
878
+ raise ValueError(
879
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
880
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
881
+ )
882
+
883
+ for i in range(len(slice_size)):
884
+ size = slice_size[i]
885
+ dim = sliceable_head_dims[i]
886
+ if size is not None and size > dim:
887
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
888
+
889
+ # Recursively walk through all the children.
890
+ # Any children which exposes the set_attention_slice method
891
+ # gets the message
892
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
893
+ if hasattr(module, "set_attention_slice"):
894
+ module.set_attention_slice(slice_size.pop())
895
+
896
+ for child in module.children():
897
+ fn_recursive_set_attention_slice(child, slice_size)
898
+
899
+ reversed_slice_size = list(reversed(slice_size))
900
+ for module in self.children():
901
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
902
+
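A hedged sketch of the three accepted `slice_size` forms, assuming `unet` is an instance of this model:

```py
unet.set_attention_slice("auto")  # halve each sliceable head dim: attention runs in two steps
unet.set_attention_slice("max")   # one slice at a time, maximum memory saving
unet.set_attention_slice(2)       # attention_head_dim // 2 slices per sliceable layer
```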
903
+ def _set_gradient_checkpointing(self, module, value=False):
904
+ if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)):
905
+ module.gradient_checkpointing = value
906
+
907
+ def forward(
908
+ self,
909
+ sample: torch.FloatTensor,
910
+ timestep: Union[torch.Tensor, float, int],
911
+ encoder_hidden_states: torch.Tensor,
912
+ class_labels: Optional[torch.Tensor] = None,
913
+ timestep_cond: Optional[torch.Tensor] = None,
914
+ attention_mask: Optional[torch.Tensor] = None,
915
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
916
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
917
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
918
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
919
+ encoder_attention_mask: Optional[torch.Tensor] = None,
920
+ dino_feature: Optional[torch.Tensor] = None,
921
+ return_dict: bool = True,
922
+ vis_max_min: bool = False,
923
+ ) -> Union[UNetMV2DConditionOutput, Tuple]:
924
+ r"""
925
+ The [`UNet2DConditionModel`] forward method.
926
+
927
+ Args:
928
+ sample (`torch.FloatTensor`):
929
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
930
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
931
+ encoder_hidden_states (`torch.FloatTensor`):
932
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
933
+ encoder_attention_mask (`torch.Tensor`):
934
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
935
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
936
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
937
+ return_dict (`bool`, *optional*, defaults to `True`):
938
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
939
+ tuple.
940
+ cross_attention_kwargs (`dict`, *optional*):
941
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
942
+ added_cond_kwargs: (`dict`, *optional*):
943
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
944
+ are passed along to the UNet blocks.
945
+
946
+ Returns:
947
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
948
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
949
+ a `tuple` is returned where the first element is the sample tensor.
950
+ """
951
+ record_max_min = {}
952
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
953
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
954
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
955
+ # on the fly if necessary.
956
+ default_overall_up_factor = 2**self.num_upsamplers
957
+
958
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
959
+ forward_upsample_size = False
960
+ upsample_size = None
961
+
962
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
963
+ logger.info("Forward upsample size to force interpolation output size.")
964
+ forward_upsample_size = True
965
+
966
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
967
+ # expects mask of shape:
968
+ # [batch, key_tokens]
969
+ # adds singleton query_tokens dimension:
970
+ # [batch, 1, key_tokens]
971
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
972
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
973
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
974
+ if attention_mask is not None:
975
+ # assume that mask is expressed as:
976
+ # (1 = keep, 0 = discard)
977
+ # convert mask into a bias that can be added to attention scores:
978
+ # (keep = +0, discard = -10000.0)
979
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
980
+ attention_mask = attention_mask.unsqueeze(1)
981
+
982
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
983
+ if encoder_attention_mask is not None:
984
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
985
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
986
+
987
+ # 0. center input if necessary
988
+ if self.config.center_input_sample:
989
+ sample = 2 * sample - 1.0
990
+ # 1. time
991
+ timesteps = timestep
992
+ if not torch.is_tensor(timesteps):
993
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
994
+ # This would be a good case for the `match` statement (Python 3.10+)
995
+ is_mps = sample.device.type == "mps"
996
+ if isinstance(timestep, float):
997
+ dtype = torch.float32 if is_mps else torch.float64
998
+ else:
999
+ dtype = torch.int32 if is_mps else torch.int64
1000
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
1001
+ elif len(timesteps.shape) == 0:
1002
+ timesteps = timesteps[None].to(sample.device)
1003
+
1004
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1005
+ timesteps = timesteps.expand(sample.shape[0])
1006
+
1007
+ t_emb = self.time_proj(timesteps)
1008
+
1009
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1010
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1011
+ # there might be better ways to encapsulate this.
1012
+ t_emb = t_emb.to(dtype=sample.dtype)
1013
+
1014
+ emb = self.time_embedding(t_emb, timestep_cond)
1015
+ aug_emb = None
1016
+ if self.class_embedding is not None:
1017
+ if class_labels is None:
1018
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
1019
+
1020
+ if self.config.class_embed_type == "timestep":
1021
+ class_labels = self.time_proj(class_labels)
1022
+
1023
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1024
+ # there might be better ways to encapsulate this.
1025
+ class_labels = class_labels.to(dtype=sample.dtype)
1026
+
1027
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
1028
+ if self.config.class_embeddings_concat:
1029
+ emb = torch.cat([emb, class_emb], dim=-1)
1030
+ else:
1031
+ emb = emb + class_emb
1032
+
1033
+ if self.config.addition_embed_type == "text":
1034
+ aug_emb = self.add_embedding(encoder_hidden_states)
1035
+ elif self.config.addition_embed_type == "text_image":
1036
+ # Kandinsky 2.1 - style
1037
+ if "image_embeds" not in added_cond_kwargs:
1038
+ raise ValueError(
1039
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1040
+ )
1041
+
1042
+ image_embs = added_cond_kwargs.get("image_embeds")
1043
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
1044
+ aug_emb = self.add_embedding(text_embs, image_embs)
1045
+ elif self.config.addition_embed_type == "text_time":
1046
+ # SDXL - style
1047
+ if "text_embeds" not in added_cond_kwargs:
1048
+ raise ValueError(
1049
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
1050
+ )
1051
+ text_embeds = added_cond_kwargs.get("text_embeds")
1052
+ if "time_ids" not in added_cond_kwargs:
1053
+ raise ValueError(
1054
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1055
+ )
1056
+ time_ids = added_cond_kwargs.get("time_ids")
1057
+ time_embeds = self.add_time_proj(time_ids.flatten())
1058
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
1059
+
1060
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
1061
+ add_embeds = add_embeds.to(emb.dtype)
1062
+ aug_emb = self.add_embedding(add_embeds)
1063
+ elif self.config.addition_embed_type == "image":
1064
+ # Kandinsky 2.2 - style
1065
+ if "image_embeds" not in added_cond_kwargs:
1066
+ raise ValueError(
1067
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1068
+ )
1069
+ image_embs = added_cond_kwargs.get("image_embeds")
1070
+ aug_emb = self.add_embedding(image_embs)
1071
+ elif self.config.addition_embed_type == "image_hint":
1072
+ # Kandinsky 2.2 - style
1073
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
1074
+ raise ValueError(
1075
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1076
+ )
1077
+ image_embs = added_cond_kwargs.get("image_embeds")
1078
+ hint = added_cond_kwargs.get("hint")
1079
+ aug_emb, hint = self.add_embedding(image_embs, hint)
1080
+ sample = torch.cat([sample, hint], dim=1)
1081
+
1082
+ emb = emb + aug_emb if aug_emb is not None else emb
1083
+ emb_pre_act = emb
1084
+ if self.time_embed_act is not None:
1085
+ emb = self.time_embed_act(emb)
1086
+
1087
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1088
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1089
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1090
+ # Kandinsky 2.1 - style
1091
+ if "image_embeds" not in added_cond_kwargs:
1092
+ raise ValueError(
1093
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1094
+ )
1095
+
1096
+ image_embeds = added_cond_kwargs.get("image_embeds")
1097
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1098
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1099
+ # Kandinsky 2.2 - style
1100
+ if "image_embeds" not in added_cond_kwargs:
1101
+ raise ValueError(
1102
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1103
+ )
1104
+ image_embeds = added_cond_kwargs.get("image_embeds")
1105
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1106
+ # 2. pre-process
1107
+ sample = self.conv_in(sample)
1108
+ # 3. down
1109
+
1110
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1111
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
1112
+
1113
+ down_block_res_samples = (sample,)
1114
+ for i, downsample_block in enumerate(self.down_blocks):
1115
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1116
+ # For t2i-adapter CrossAttnDownBlock2D
1117
+ additional_residuals = {}
1118
+ if is_adapter and len(down_block_additional_residuals) > 0:
1119
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
1120
+
1121
+ sample, res_samples = downsample_block(
1122
+ hidden_states=sample,
1123
+ temb=emb,
1124
+ encoder_hidden_states=encoder_hidden_states,
1125
+ dino_feature=dino_feature,
1126
+ attention_mask=attention_mask,
1127
+ cross_attention_kwargs=cross_attention_kwargs,
1128
+ encoder_attention_mask=encoder_attention_mask,
1129
+ **additional_residuals,
1130
+ )
1131
+ else:
1132
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1133
+
1134
+ if is_adapter and len(down_block_additional_residuals) > 0:
1135
+ sample += down_block_additional_residuals.pop(0)
1136
+
1137
+ down_block_res_samples += res_samples
1138
+
1139
+ if is_controlnet:
1140
+ new_down_block_res_samples = ()
1141
+
1142
+ for down_block_res_sample, down_block_additional_residual in zip(
1143
+ down_block_res_samples, down_block_additional_residuals
1144
+ ):
1145
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1146
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1147
+
1148
+ down_block_res_samples = new_down_block_res_samples
1149
+
1150
+ if self.addition_downsample:
1151
+ global_sample = sample
1152
+ global_sample = self.downsample(global_sample)
1153
+ for layer in self.conv_block:
1154
+ global_sample = layer(global_sample)
1155
+ global_sample = self.addition_act_out(self.addition_conv_out(global_sample))
1156
+ global_sample = self.upsample(global_sample)
1157
+ # 4. mid
1158
+ if self.mid_block is not None:
1159
+ sample = self.mid_block(
1160
+ sample,
1161
+ emb,
1162
+ encoder_hidden_states=encoder_hidden_states,
1163
+ dino_feature=dino_feature,
1164
+ attention_mask=attention_mask,
1165
+ cross_attention_kwargs=cross_attention_kwargs,
1166
+ encoder_attention_mask=encoder_attention_mask,
1167
+ )
1168
+ # 4.1 regress elevation and focal length
1169
+ # # predict elevation -> embed -> projection -> add to time emb
1170
+ if self.regress_elevation or self.regress_focal_length:
1171
+ pool_embeds = self.pool(sample.detach()).squeeze(-1).squeeze(-1) # (2B, C)
1172
+ if self.mvcd_attention:
1173
+ pool_embeds_normal, pool_embeds_color = torch.chunk(pool_embeds, 2, dim=0)
1174
+ pool_embeds = torch.cat([pool_embeds_normal, pool_embeds_color], dim=-1) # (B, 2C)
1175
+ pose_pred = []
1176
+ if self.regress_elevation:
1177
+ ele_pred = self.elevation_regressor(pool_embeds)
1178
+ ele_pred = rearrange(ele_pred, '(b v) c -> b v c', v=self.num_views)
1179
+ ele_pred = torch.mean(ele_pred, dim=1)
1180
+ pose_pred.append(ele_pred) # b, c
1181
+
1182
+ if self.regress_focal_length:
1183
+ focal_pred = self.focal_regressor(pool_embeds)
1184
+ focal_pred = rearrange(focal_pred, '(b v) c -> b v c', v=self.num_views)
1185
+ focal_pred = torch.mean(focal_pred, dim=1)
1186
+ pose_pred.append(focal_pred)
1187
+ pose_pred = torch.cat(pose_pred, dim=-1)
1188
+ # 'e_de_da_sincos', (B, 2)
1189
+ pose_embeds = torch.cat([
1190
+ torch.sin(pose_pred),
1191
+ torch.cos(pose_pred)
1192
+ ], dim=-1)
1193
+ pose_embeds = self.camera_embedding(pose_embeds)
1194
+ pose_embeds = torch.repeat_interleave(pose_embeds, self.num_views, 0)
1195
+ if self.mvcd_attention:
1196
+ pose_embeds = torch.cat([pose_embeds,] * 2, dim=0)
1197
+
1198
+ emb = pose_embeds + emb_pre_act
1199
+ if self.time_embed_act is not None:
1200
+ emb = self.time_embed_act(emb)
1201
+
1202
+ if is_controlnet:
1203
+ sample = sample + mid_block_additional_residual
1204
+
1205
+ if self.addition_downsample:
1206
+ sample = sample + global_sample
1207
+
1208
+ # 5. up
1209
+ for i, upsample_block in enumerate(self.up_blocks):
1210
+ is_final_block = i == len(self.up_blocks) - 1
1211
+
1212
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1213
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1214
+
1215
+ # if we have not reached the final block and need to forward the
1216
+ # upsample size, we do it here
1217
+ if not is_final_block and forward_upsample_size:
1218
+ upsample_size = down_block_res_samples[-1].shape[2:]
1219
+
1220
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1221
+ sample = upsample_block(
1222
+ hidden_states=sample,
1223
+ temb=emb,
1224
+ res_hidden_states_tuple=res_samples,
1225
+ encoder_hidden_states=encoder_hidden_states,
1226
+ dino_feature=dino_feature,
1227
+ cross_attention_kwargs=cross_attention_kwargs,
1228
+ upsample_size=upsample_size,
1229
+ attention_mask=attention_mask,
1230
+ encoder_attention_mask=encoder_attention_mask,
1231
+ )
1232
+ else:
1233
+ sample = upsample_block(
1234
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1235
+ )
1236
+ if torch.isnan(sample).any() or torch.isinf(sample).any():
1237
+ print("NAN in sample, stop training.")
1238
+ exit()
1239
+ # 6. post-process
1240
+ if self.conv_norm_out:
1241
+ sample = self.conv_norm_out(sample)
1242
+ sample = self.conv_act(sample)
1243
+ sample = self.conv_out(sample)
1244
+ if not return_dict:
1245
+ return (sample, pose_pred)
1246
+ if self.regress_elevation or self.regress_focal_length:
1247
+ return UNetMV2DConditionOutput(sample=sample), pose_pred
1248
+ else:
1249
+ return UNetMV2DConditionOutput(sample=sample)
1250
+
1251
+
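A hedged forward-pass sketch for the method above; `unet` and every shape below are illustrative assumptions (views are flattened into the batch dimension):

```py
import torch

B, V = 1, 6                                      # objects x views (V = num_views)
latents = torch.randn(B * V, 8, 32, 32)          # noisy latents concatenated with condition latents
timestep = torch.tensor([999] * (B * V))
text_embeds = torch.randn(B * V, 77, 1024)       # encoder_hidden_states; the CLIP width is an assumption

# With regress_elevation / regress_focal_length enabled, the pose prediction is returned as well.
out, pose_pred = unet(latents, timestep, encoder_hidden_states=text_embeds)
noise_pred = out.sample                          # (B * V, out_channels, 32, 32)
```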
1252
+ @classmethod
1253
+ def from_pretrained_2d(
1254
+ cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
1255
+ camera_embedding_type: str, num_views: int, sample_size: int,
1256
+ zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False,
1257
+ projection_camera_embeddings_input_dim: int=2,
1258
+ cd_attention_last: bool = False, num_regress_blocks: int = 4,
1259
+ cd_attention_mid: bool = False, multiview_attention: bool = True,
1260
+ sparse_mv_attention: bool = False, selfattn_block: str = 'custom', mvcd_attention: bool = False,
1261
+ in_channels: int = 8, out_channels: int = 4, unclip: bool = False, regress_elevation: bool = False, regress_focal_length: bool = False,
1262
+ init_mvattn_with_selfattn: bool= False, use_dino: bool = False, addition_downsample: bool = False,
1263
+ **kwargs
1264
+ ):
1265
+ r"""
1266
+ Instantiate a pretrained PyTorch model from a pretrained model configuration.
1267
+
1268
+ The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
1269
+ train the model, set it back in training mode with `model.train()`.
1270
+
1271
+ Parameters:
1272
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
1273
+ Can be either:
1274
+
1275
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
1276
+ the Hub.
1277
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
1278
+ with [`~ModelMixin.save_pretrained`].
1279
+
1280
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
1281
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
1282
+ is not used.
1283
+ torch_dtype (`str` or `torch.dtype`, *optional*):
1284
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
1285
+ dtype is automatically derived from the model's weights.
1286
+ force_download (`bool`, *optional*, defaults to `False`):
1287
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1288
+ cached versions if they exist.
1289
+ resume_download (`bool`, *optional*, defaults to `False`):
1290
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
1291
+ incompletely downloaded files are deleted.
1292
+ proxies (`Dict[str, str]`, *optional*):
1293
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
1294
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1295
+ output_loading_info (`bool`, *optional*, defaults to `False`):
1296
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
1297
+ local_files_only(`bool`, *optional*, defaults to `False`):
1298
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
1299
+ won't be downloaded from the Hub.
1300
+ use_auth_token (`str` or *bool*, *optional*):
1301
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
1302
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
1303
+ revision (`str`, *optional*, defaults to `"main"`):
1304
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
1305
+ allowed by Git.
1306
+ from_flax (`bool`, *optional*, defaults to `False`):
1307
+ Load the model weights from a Flax checkpoint save file.
1308
+ subfolder (`str`, *optional*, defaults to `""`):
1309
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
1310
+ mirror (`str`, *optional*):
1311
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
1312
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
1313
+ information.
1314
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
1315
+ A map that specifies where each submodule should go. It doesn't need to be defined for each
1316
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
1317
+ same device.
1318
+
1319
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
1320
+ more information about each option see [designing a device
1321
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
1322
+ max_memory (`Dict`, *optional*):
1323
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
1324
+ each GPU and the available CPU RAM if unset.
1325
+ offload_folder (`str` or `os.PathLike`, *optional*):
1326
+ The path to offload weights if `device_map` contains the value `"disk"`.
1327
+ offload_state_dict (`bool`, *optional*):
1328
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
1329
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
1330
+ when there is some disk offload.
1331
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
1332
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
1333
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
1334
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
1335
+ argument to `True` will raise an error.
1336
+ variant (`str`, *optional*):
1337
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
1338
+ loading `from_flax`.
1339
+ use_safetensors (`bool`, *optional*, defaults to `None`):
1340
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
1341
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
1342
+ weights. If set to `False`, `safetensors` weights are not loaded.
1343
+
1344
+ <Tip>
1345
+
1346
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
1347
+ `huggingface-cli login`. You can also activate the special
1348
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
1349
+ firewalled environment.
1350
+
1351
+ </Tip>
1352
+
1353
+ Example:
1354
+
1355
+ ```py
1356
+ from diffusers import UNet2DConditionModel
1357
+
1358
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
1359
+ ```
1360
+
1361
+ If you get the error message below, you need to finetune the weights for your downstream task:
1362
+
1363
+ ```bash
1364
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
1365
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
1366
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
1367
+ ```
1368
+ """
1369
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
1370
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
1371
+ force_download = kwargs.pop("force_download", False)
1372
+ from_flax = kwargs.pop("from_flax", False)
1373
+ resume_download = kwargs.pop("resume_download", False)
1374
+ proxies = kwargs.pop("proxies", None)
1375
+ output_loading_info = kwargs.pop("output_loading_info", False)
1376
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
1377
+ use_auth_token = kwargs.pop("use_auth_token", None)
1378
+ revision = kwargs.pop("revision", None)
1379
+ torch_dtype = kwargs.pop("torch_dtype", None)
1380
+ subfolder = kwargs.pop("subfolder", None)
1381
+ device_map = kwargs.pop("device_map", None)
1382
+ max_memory = kwargs.pop("max_memory", None)
1383
+ offload_folder = kwargs.pop("offload_folder", None)
1384
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
1385
+ variant = kwargs.pop("variant", None)
1386
+ use_safetensors = kwargs.pop("use_safetensors", None)
1387
+
1388
+ if use_safetensors:
1389
+ raise ValueError(
1390
+ "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors`."
1391
+ )
1392
+
1393
+ allow_pickle = False
1394
+ if use_safetensors is None:
1395
+ use_safetensors = True
1396
+ allow_pickle = True
1397
+
1398
+ if device_map is not None and not is_accelerate_available():
1399
+ raise NotImplementedError(
1400
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
1401
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
1402
+ )
1403
+
1404
+ # Check if we can handle device_map and dispatching the weights
1405
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
1406
+ raise NotImplementedError(
1407
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
1408
+ " `device_map=None`."
1409
+ )
1410
+
1411
+ # Load config if we don't provide a configuration
1412
+ config_path = pretrained_model_name_or_path
1413
+
1414
+ user_agent = {
1415
+ "diffusers": __version__,
1416
+ "file_type": "model",
1417
+ "framework": "pytorch",
1418
+ }
1419
+
1420
+ # load config
1421
+ config, unused_kwargs, commit_hash = cls.load_config(
1422
+ config_path,
1423
+ cache_dir=cache_dir,
1424
+ return_unused_kwargs=True,
1425
+ return_commit_hash=True,
1426
+ force_download=force_download,
1427
+ resume_download=resume_download,
1428
+ proxies=proxies,
1429
+ local_files_only=local_files_only,
1430
+ use_auth_token=use_auth_token,
1431
+ revision=revision,
1432
+ subfolder=subfolder,
1433
+ device_map=device_map,
1434
+ max_memory=max_memory,
1435
+ offload_folder=offload_folder,
1436
+ offload_state_dict=offload_state_dict,
1437
+ user_agent=user_agent,
1438
+ **kwargs,
1439
+ )
1440
+
1441
+ # modify config
1442
+ config["_class_name"] = cls.__name__
1443
+ config['in_channels'] = in_channels
1444
+ config['out_channels'] = out_channels
1445
+ config['sample_size'] = sample_size # training resolution
1446
+ config['num_views'] = num_views
1447
+ config['cd_attention_last'] = cd_attention_last
1448
+ config['cd_attention_mid'] = cd_attention_mid
1449
+ config['multiview_attention'] = multiview_attention
1450
+ config['sparse_mv_attention'] = sparse_mv_attention
1451
+ config['selfattn_block'] = selfattn_block
1452
+ config['mvcd_attention'] = mvcd_attention
1453
+ config["down_block_types"] = [
1454
+ "CrossAttnDownBlockMV2D",
1455
+ "CrossAttnDownBlockMV2D",
1456
+ "CrossAttnDownBlockMV2D",
1457
+ "DownBlock2D"
1458
+ ]
1459
+ config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn"
1460
+ config["up_block_types"] = [
1461
+ "UpBlock2D",
1462
+ "CrossAttnUpBlockMV2D",
1463
+ "CrossAttnUpBlockMV2D",
1464
+ "CrossAttnUpBlockMV2D"
1465
+ ]
1466
+
1467
+
1468
+ config['regress_elevation'] = regress_elevation # true
1469
+ config['regress_focal_length'] = regress_focal_length # true
1470
+ config['projection_camera_embeddings_input_dim'] = projection_camera_embeddings_input_dim # 2 for elevation and 10 for focal_length
1471
+ config['use_dino'] = use_dino
1472
+ config['num_regress_blocks'] = num_regress_blocks
1473
+ config['addition_downsample'] = addition_downsample
1474
+ # load model
1475
+ model_file = None
1476
+ if from_flax:
1477
+ raise NotImplementedError
1478
+ else:
1479
+ if use_safetensors:
1480
+ try:
1481
+ model_file = _get_model_file(
1482
+ pretrained_model_name_or_path,
1483
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
1484
+ cache_dir=cache_dir,
1485
+ force_download=force_download,
1486
+ resume_download=resume_download,
1487
+ proxies=proxies,
1488
+ local_files_only=local_files_only,
1489
+ use_auth_token=use_auth_token,
1490
+ revision=revision,
1491
+ subfolder=subfolder,
1492
+ user_agent=user_agent,
1493
+ commit_hash=commit_hash,
1494
+ )
1495
+ except IOError as e:
1496
+ if not allow_pickle:
1497
+ raise e
1498
+ pass
1499
+ if model_file is None:
1500
+ model_file = _get_model_file(
1501
+ pretrained_model_name_or_path,
1502
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
1503
+ cache_dir=cache_dir,
1504
+ force_download=force_download,
1505
+ resume_download=resume_download,
1506
+ proxies=proxies,
1507
+ local_files_only=local_files_only,
1508
+ use_auth_token=use_auth_token,
1509
+ revision=revision,
1510
+ subfolder=subfolder,
1511
+ user_agent=user_agent,
1512
+ commit_hash=commit_hash,
1513
+ )
1514
+
1515
+ model = cls.from_config(config, **unused_kwargs)
1516
+ import copy
1517
+ state_dict_pretrain = load_state_dict(model_file, variant=variant)
1518
+ state_dict = copy.deepcopy(state_dict_pretrain)
1519
+
1520
+ if init_mvattn_with_selfattn:
1521
+ for key in state_dict_pretrain:
1522
+ if 'attn1' in key:
1523
+ key_mv = key.replace('attn1', 'attn_mv')
1524
+ state_dict[key_mv] = state_dict_pretrain[key]
1525
+ if 'to_out.0.weight' in key:
1526
+ nn.init.zeros_(state_dict[key_mv].data)
1527
+ if 'transformer_blocks' in key and 'norm1' in key: # guard against matching the norm layers inside resnet blocks
1528
+ key_mv = key.replace('norm1', 'norm_mv')
1529
+ state_dict[key_mv] = state_dict_pretrain[key]
1530
+ # del state_dict_pretrain
1531
+
1532
+ model._convert_deprecated_attention_blocks(state_dict)
1533
+
1534
+ conv_in_weight = state_dict['conv_in.weight']
1535
+ conv_out_weight = state_dict['conv_out.weight']
1536
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d(
1537
+ model,
1538
+ state_dict,
1539
+ model_file,
1540
+ pretrained_model_name_or_path,
1541
+ ignore_mismatched_sizes=True,
1542
+ )
1543
+ if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]):
1544
+ # initialize from the original SD structure
1545
+ model.conv_in.weight.data[:,:4] = conv_in_weight
1546
+
1547
+ # whether to zero-initialize the newly added input channels
1548
+ if zero_init_conv_in:
1549
+ model.conv_in.weight.data[:,4:] = 0.
1550
+
1551
+ if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]):
1552
+ # initialize from the original SD structure
1553
+ model.conv_out.weight.data[:,:4] = conv_out_weight
1554
+ if out_channels == 8: # copy for the last 4 channels
1555
+ model.conv_out.weight.data[:, 4:] = conv_out_weight
1556
+
1557
+ if zero_init_camera_projection: # true
1558
+ params = [p for p in model.camera_embedding.parameters()]
1559
+ torch.nn.init.zeros_(params[-1].data)
1560
+
1561
+ loading_info = {
1562
+ "missing_keys": missing_keys,
1563
+ "unexpected_keys": unexpected_keys,
1564
+ "mismatched_keys": mismatched_keys,
1565
+ "error_msgs": error_msgs,
1566
+ }
1567
+
1568
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
1569
+ raise ValueError(
1570
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
1571
+ )
1572
+ elif torch_dtype is not None:
1573
+ model = model.to(torch_dtype)
1574
+
1575
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
1576
+
1577
+ # Set model in evaluation mode to deactivate DropOut modules by default
1578
+ model.eval()
1579
+ if output_loading_info:
1580
+ return model, loading_info
1581
+ return model
1582
+
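A hedged construction sketch for the loader above; the class name (assumed to be `UNetMV2DConditionModel`), the base checkpoint, and the keyword values are illustrative rather than the released configuration:

```py
unet = UNetMV2DConditionModel.from_pretrained_2d(
    "stabilityai/stable-diffusion-2-1-unclip",   # assumed SD2.1-unCLIP base weights
    subfolder="unet",
    camera_embedding_type="e_de_da_sincos",
    num_views=6,
    sample_size=32,
    in_channels=8,
    out_channels=4,
    regress_elevation=True,
    regress_focal_length=True,
    projection_camera_embeddings_input_dim=4,    # 2 * number of regressed scalars (sin/cos pairs)
    mvcd_attention=True,
)
```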
1583
+ @classmethod
1584
+ def _load_pretrained_model_2d(
1585
+ cls,
1586
+ model,
1587
+ state_dict,
1588
+ resolved_archive_file,
1589
+ pretrained_model_name_or_path,
1590
+ ignore_mismatched_sizes=False,
1591
+ ):
1592
+ # Retrieve missing & unexpected_keys
1593
+ model_state_dict = model.state_dict()
1594
+ loaded_keys = list(state_dict.keys())
1595
+
1596
+ expected_keys = list(model_state_dict.keys())
1597
+
1598
+ original_loaded_keys = loaded_keys
1599
+
1600
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
1601
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
1602
+
1603
+ # Make sure we are able to load base models as well as derived models (with heads)
1604
+ model_to_load = model
1605
+
1606
+ def _find_mismatched_keys(
1607
+ state_dict,
1608
+ model_state_dict,
1609
+ loaded_keys,
1610
+ ignore_mismatched_sizes,
1611
+ ):
1612
+ mismatched_keys = []
1613
+ if ignore_mismatched_sizes:
1614
+ for checkpoint_key in loaded_keys:
1615
+ model_key = checkpoint_key
1616
+
1617
+ if (
1618
+ model_key in model_state_dict
1619
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
1620
+ ):
1621
+ mismatched_keys.append(
1622
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
1623
+ )
1624
+ del state_dict[checkpoint_key]
1625
+ return mismatched_keys
1626
+
1627
+ if state_dict is not None:
1628
+ # Whole checkpoint
1629
+ mismatched_keys = _find_mismatched_keys(
1630
+ state_dict,
1631
+ model_state_dict,
1632
+ original_loaded_keys,
1633
+ ignore_mismatched_sizes,
1634
+ )
1635
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
1636
+
1637
+ if len(error_msgs) > 0:
1638
+ error_msg = "\n\t".join(error_msgs)
1639
+ if "size mismatch" in error_msg:
1640
+ error_msg += (
1641
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
1642
+ )
1643
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
1644
+
1645
+ if len(unexpected_keys) > 0:
1646
+ logger.warning(
1647
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
1648
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
1649
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
1650
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
1651
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
1652
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
1653
+ " identical (initializing a BertForSequenceClassification model from a"
1654
+ " BertForSequenceClassification model)."
1655
+ )
1656
+ else:
1657
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
1658
+ if len(missing_keys) > 0:
1659
+ logger.warning(
1660
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1661
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
1662
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
1663
+ )
1664
+ elif len(mismatched_keys) == 0:
1665
+ logger.info(
1666
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
1667
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
1668
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
1669
+ " without further training."
1670
+ )
1671
+ if len(mismatched_keys) > 0:
1672
+ mismatched_warning = "\n".join(
1673
+ [
1674
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
1675
+ for key, shape1, shape2 in mismatched_keys
1676
+ ]
1677
+ )
1678
+ logger.warning(
1679
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1680
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
1681
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
1682
+ " able to use it for predictions and inference."
1683
+ )
1684
+
1685
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
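The zero-initialization of the extra `conv_in` channels in `from_pretrained_2d` keeps the widened 8-channel stem equivalent to the pretrained 4-channel one at initialization; a small self-contained check of that property (not part of the repository):

```py
import torch
import torch.nn as nn

conv4 = nn.Conv2d(4, 320, kernel_size=3, padding=1)
conv8 = nn.Conv2d(8, 320, kernel_size=3, padding=1)
conv8.weight.data[:, :4] = conv4.weight.data     # copy pretrained weights
conv8.weight.data[:, 4:] = 0.                    # zero-init the new condition channels
conv8.bias.data = conv4.bias.data.clone()

x = torch.randn(1, 4, 32, 32)
cond = torch.randn(1, 4, 32, 32)
assert torch.allclose(conv4(x), conv8(torch.cat([x, cond], dim=1)), atol=1e-5)
```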
multiview/pipeline_multiclass.py ADDED
@@ -0,0 +1,656 @@
1
+ import inspect
2
+ import warnings
3
+ from typing import Callable, List, Optional, Union, Dict, Any
4
+ import PIL
5
+ import torch
6
+ import kornia
7
+ from packaging import version
8
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, CLIPFeatureExtractor, CLIPTextModel
9
+ from diffusers.utils.import_utils import is_accelerate_available
10
+ from diffusers.configuration_utils import FrozenDict
11
+ from diffusers.image_processor import VaeImageProcessor
12
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
13
+ from diffusers.models.embeddings import get_timestep_embedding
14
+ from diffusers.schedulers import KarrasDiffusionSchedulers
15
+ from diffusers.utils import deprecate, logging
16
+ from diffusers.utils.torch_utils import randn_tensor
17
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
18
+ from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
19
+ import os
20
+ import torchvision.transforms.functional as TF
21
+ from einops import rearrange
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ def CLIP_preprocess(x):
26
+ dtype = x.dtype
27
+ # following openai's implementation
28
+ # TODO HF OpenAI CLIP preprocessing issue https://github.com/huggingface/transformers/issues/22505#issuecomment-1650170741
29
+ # follow the OpenAI preprocessing exactly; the input tensor must be in [-1, 1], otherwise the preprocessing differs, https://github.com/huggingface/transformers/pull/22608
30
+ if isinstance(x, torch.Tensor):
31
+ if x.min() < -1.0 or x.max() > 1.0:
32
+ raise ValueError("Expected input tensor to have values in the range [-1, 1]")
33
+ x = kornia.geometry.resize(x.to(torch.float32), (224, 224), interpolation='bicubic', align_corners=True, antialias=False).to(dtype=dtype)
34
+ x = (x + 1.) / 2.
35
+ # renormalize according to clip
36
+ x = kornia.enhance.normalize(x, torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
37
+ torch.Tensor([0.26862954, 0.26130258, 0.27577711]))
38
+ return x
39
+
40
+
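`CLIP_preprocess` expects tensors already scaled to [-1, 1]; a short usage sketch with illustrative sizes:

```py
import torch

imgs = torch.rand(4, 3, 512, 512) * 2.0 - 1.0    # batch of RGB images in [-1, 1]
clip_in = CLIP_preprocess(imgs)                  # -> (4, 3, 224, 224), CLIP-normalized
```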
41
+ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
42
+ """
43
+ Pipeline for text-guided image to image generation using stable unCLIP.
44
+
45
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
46
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
47
+
48
+ Args:
49
+ feature_extractor ([`CLIPFeatureExtractor`]):
50
+ Feature extractor for image pre-processing before being encoded.
51
+ image_encoder ([`CLIPVisionModelWithProjection`]):
52
+ CLIP vision model for encoding images.
53
+ image_normalizer ([`StableUnCLIPImageNormalizer`]):
54
+ Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image
55
+ embeddings after the noise has been applied.
56
+ image_noising_scheduler ([`KarrasDiffusionSchedulers`]):
57
+ Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined
58
+ by `noise_level` in `StableUnCLIPPipeline.__call__`.
59
+ text_encoder ([`CLIPTextModel`]):
60
+ Frozen text-encoder.
61
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
62
+ scheduler ([`KarrasDiffusionSchedulers`]):
63
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
64
+ vae ([`AutoencoderKL`]):
65
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
66
+ """
67
+ # image encoding components
68
+ feature_extractor: CLIPFeatureExtractor
69
+ image_encoder: CLIPVisionModelWithProjection
70
+ # image noising components
71
+ image_normalizer: StableUnCLIPImageNormalizer
72
+ image_noising_scheduler: KarrasDiffusionSchedulers
73
+ # regular denoising components
74
+ text_encoder: CLIPTextModel
75
+ unet: UNet2DConditionModel
76
+ scheduler: KarrasDiffusionSchedulers
77
+ vae: AutoencoderKL
78
+
79
+ def __init__(
80
+ self,
81
+ # image encoding components
82
+ feature_extractor: CLIPFeatureExtractor,
83
+ image_encoder: CLIPVisionModelWithProjection,
84
+ # image noising components
85
+ image_normalizer: StableUnCLIPImageNormalizer,
86
+ image_noising_scheduler: KarrasDiffusionSchedulers,
87
+ # regular denoising components
88
+ text_encoder: CLIPTextModel,
89
+ unet: UNet2DConditionModel,
90
+ scheduler: KarrasDiffusionSchedulers,
91
+ # vae
92
+ vae: AutoencoderKL,
93
+ num_views: int = 4,
94
+ ):
95
+ super().__init__()
96
+
97
+ self.register_modules(
98
+ feature_extractor=feature_extractor,
99
+ image_encoder=image_encoder,
100
+ image_normalizer=image_normalizer,
101
+ image_noising_scheduler=image_noising_scheduler,
102
+ text_encoder=text_encoder,
103
+ unet=unet,
104
+ scheduler=scheduler,
105
+ vae=vae,
106
+ )
107
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
108
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
109
+ self.num_views: int = num_views
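A worked example for the scale factor above, assuming an SD-style VAE config:

```py
block_out_channels = (128, 256, 512, 512)               # assumed VAE config
vae_scale_factor = 2 ** (len(block_out_channels) - 1)   # = 8, so a 768x768 image maps to a 96x96 latent
```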
110
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
111
+ def enable_vae_slicing(self):
112
+ r"""
113
+ Enable sliced VAE decoding.
114
+
115
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
116
+ steps. This is useful to save some memory and allow larger batch sizes.
117
+ """
118
+ self.vae.enable_slicing()
119
+
120
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
121
+ def disable_vae_slicing(self):
122
+ r"""
123
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
124
+ computing decoding in one step.
125
+ """
126
+ self.vae.disable_slicing()
127
+
128
+ def enable_sequential_cpu_offload(self, gpu_id=0):
129
+ r"""
130
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
131
+ models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
132
+ when their specific submodule has its `forward` method called.
133
+ """
134
+ if is_accelerate_available():
135
+ from accelerate import cpu_offload
136
+ else:
137
+ raise ImportError("Please install accelerate via `pip install accelerate`")
138
+
139
+ device = torch.device(f"cuda:{gpu_id}")
140
+
141
+ # TODO: self.image_normalizer.{scale,unscale} are not covered by the offload hooks, so they fail if added to the list
142
+ models = [
143
+ self.image_encoder,
144
+ self.text_encoder,
145
+ self.unet,
146
+ self.vae,
147
+ ]
148
+ for cpu_offloaded_model in models:
149
+ if cpu_offloaded_model is not None:
150
+ cpu_offload(cpu_offloaded_model, device)
151
+
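A hedged usage sketch for the offload hook above (requires `accelerate`); the checkpoint path is a hypothetical placeholder:

```py
pipe = StableUnCLIPImg2ImgPipeline.from_pretrained("path/to/multiview-pipeline")  # hypothetical path
pipe.enable_sequential_cpu_offload(gpu_id=0)   # sub-models stay on CPU and hop to cuda:0 per forward call
```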
152
+ @property
153
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
154
+ def _execution_device(self):
155
+ r"""
156
+ Returns the device on which the pipeline's models will be executed. After calling
157
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
158
+ hooks.
159
+ """
160
+ if not hasattr(self.unet, "_hf_hook"):
161
+ return self.device
162
+ for module in self.unet.modules():
163
+ if (
164
+ hasattr(module, "_hf_hook")
165
+ and hasattr(module._hf_hook, "execution_device")
166
+ and module._hf_hook.execution_device is not None
167
+ ):
168
+ return torch.device(module._hf_hook.execution_device)
169
+ return self.device
170
+
171
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
172
+ def _encode_prompt(
173
+ self,
174
+ prompt,
175
+ device,
176
+ num_images_per_prompt,
177
+ do_classifier_free_guidance,
178
+ negative_prompt=None,
179
+ prompt_embeds: Optional[torch.FloatTensor] = None,
180
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
181
+ lora_scale: Optional[float] = None,
182
+ ):
183
+ r"""
184
+ Encodes the prompt into text encoder hidden states.
185
+
186
+ Args:
187
+ prompt (`str` or `List[str]`, *optional*):
188
+ prompt to be encoded
189
+ device: (`torch.device`):
190
+ torch device
191
+ num_images_per_prompt (`int`):
192
+ number of images that should be generated per prompt
193
+ do_classifier_free_guidance (`bool`):
194
+ whether to use classifier free guidance or not
195
+ negative_prompt (`str` or `List[str]`, *optional*):
196
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
197
+ `negative_prompt_embeds` instead.
198
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
199
+ prompt_embeds (`torch.FloatTensor`, *optional*):
200
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
201
+ provided, text embeddings will be generated from `prompt` input argument.
202
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
203
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
204
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
205
+ argument.
206
+ """
207
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
208
+
209
+ if do_classifier_free_guidance:
210
+ # For classifier free guidance, we need to do two forward passes.
211
+ # Here we concatenate the unconditional and text embeddings into a single batch
212
+ # to avoid doing two forward passes
213
+ normal_prompt_embeds, color_prompt_embeds = torch.chunk(prompt_embeds, 2, dim=0)
214
+
215
+ prompt_embeds = torch.cat([normal_prompt_embeds, normal_prompt_embeds, color_prompt_embeds, color_prompt_embeds], 0)
216
+
217
+ return prompt_embeds
218
+
219
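# Standalone illustration (not from the repository, sizes hypothetical) of the batch layout built
# above for classifier-free guidance: the normal/color halves of `prompt_embeds` are each
# duplicated, giving the [normal, normal, color, color] ordering that the denoising loop later
# chunks back apart. The same text embeddings serve both CFG branches; the "unconditional" part
# comes from the zeroed image embeddings/latents built in `_encode_image`.
import torch

num_views, seq_len, dim = 6, 77, 1024
prompt_embeds = torch.randn(2 * num_views, seq_len, dim)   # normal views stacked on color views
normal_embeds, color_embeds = torch.chunk(prompt_embeds, 2, dim=0)
cfg_embeds = torch.cat([normal_embeds, normal_embeds, color_embeds, color_embeds], 0)
print(cfg_embeds.shape)                                    # torch.Size([24, 77, 1024])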
+ def _encode_image(
220
+ self,
221
+ # image_pil,
222
+ image,
223
+ device,
224
+ num_images_per_prompt,
225
+ do_classifier_free_guidance,
226
+ noise_level: int=0,
227
+ class_targets: list=None,
228
+ generator: Optional[torch.Generator] = None
229
+ ):
230
+ dtype = next(self.image_encoder.parameters()).dtype
231
+ # ______________________________clip image embedding______________________________
232
+ image_ = CLIP_preprocess(image)
233
+ image_embeds = self.image_encoder(image_).image_embeds
234
+
235
+ image_embeds_ls = []
236
+
237
+ for class_target in class_targets:
238
+ image_embeds_ls.append(self.noise_image_embeddings(
239
+ image_embeds=image_embeds,
240
+ noise_level=noise_level,
241
+ class_target=class_target,
242
+ generator=generator,
243
+ ).repeat(num_images_per_prompt, 1))
244
+
245
+ if do_classifier_free_guidance:
246
+ for idx in range(len(image_embeds_ls)):
247
+ normal_image_embeds, color_image_embeds = torch.chunk(image_embeds_ls[idx], 2, dim=0)
248
+ negative_prompt_embeds = torch.zeros_like(normal_image_embeds)
249
+
250
+ # For classifier free guidance, we need to do two forward passes.
251
+ # Here we concatenate the unconditional and text embeddings into a single batch
252
+ # to avoid doing two forward passes
253
+ image_embeds_ls[idx] = torch.cat([negative_prompt_embeds, normal_image_embeds, negative_prompt_embeds, color_image_embeds], 0)
254
+
255
+ # _____________________________vae input latents__________________________________________________
256
+ image_latents = self.vae.encode(image.to(self.vae.dtype)).latent_dist.mode() * self.vae.config.scaling_factor
257
+ # Note: repeat differently from official pipelines
258
+ image_latents = image_latents.repeat(num_images_per_prompt, 1, 1, 1)
259
+
260
+ if do_classifier_free_guidance:
261
+ normal_image_latents, color_image_latents = torch.chunk(image_latents, 2, dim=0)
262
+ image_latents = torch.cat([torch.zeros_like(normal_image_latents), normal_image_latents,
263
+ torch.zeros_like(color_image_latents), color_image_latents], 0)
264
+
265
+ return image_embeds_ls, image_latents
266
+
267
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
268
+ def decode_latents(self, latents):
269
+ latents = 1 / self.vae.config.scaling_factor * latents
270
+ image = self.vae.decode(latents).sample
271
+ image = (image / 2 + 0.5).clamp(0, 1)
272
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
273
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
274
+ return image
275
+
276
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
277
+ def prepare_extra_step_kwargs(self, generator, eta):
278
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
279
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
280
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
281
+ # and should be between [0, 1]
282
+
283
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
284
+ extra_step_kwargs = {}
285
+ if accepts_eta:
286
+ extra_step_kwargs["eta"] = eta
287
+
288
+ # check if the scheduler accepts generator
289
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
290
+ if accepts_generator:
291
+ extra_step_kwargs["generator"] = generator
292
+ return extra_step_kwargs
293
+
294
+ def check_inputs(
295
+ self,
296
+ prompt,
297
+ image,
298
+ height,
299
+ width,
300
+ callback_steps,
301
+ noise_level,
302
+ ):
303
+ if height % 8 != 0 or width % 8 != 0:
304
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
305
+
306
+ if (callback_steps is None) or (
307
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
308
+ ):
309
+ raise ValueError(
310
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
311
+ f" {type(callback_steps)}."
312
+ )
313
+
314
+ if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
315
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
316
+
317
+
318
+ if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
319
+ raise ValueError(
320
+ f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive."
321
+ )
322
+
323
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
324
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
325
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
326
+ if isinstance(generator, list) and len(generator) != batch_size:
327
+ raise ValueError(
328
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
329
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
330
+ )
331
+
332
+ if latents is None:
333
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
334
+ latents = noise.clone()
335
+ else:
336
+ latents = latents.to(device)
+ noise = None  # caller-supplied latents: keep the (latents, noise) return signature valid
337
+
338
+ # scale the initial noise by the standard deviation required by the scheduler
339
+ latents = latents * self.scheduler.init_noise_sigma
340
+ return latents, noise
341
+
342
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_unclip.StableUnCLIPPipeline.noise_image_embeddings
343
+ def noise_image_embeddings(
344
+ self,
345
+ image_embeds: torch.Tensor,
346
+ noise_level: int,
347
+ class_target: torch.Tensor,
348
+ noise: Optional[torch.FloatTensor] = None,
349
+ generator: Optional[torch.Generator] = None,
350
+ ):
351
+ """
352
+ Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher
353
+ `noise_level` increases the variance in the final un-noised images.
354
+
355
+ The noise is applied in two ways
356
+ 1. A noise schedule is applied directly to the embeddings
357
+ 2. A vector of sinusoidal time embeddings are appended to the output.
358
+
359
+ In both cases, the amount of noise is controlled by the same `noise_level`.
360
+
361
+ The embeddings are normalized before the noise is applied and un-normalized after the noise is applied.
362
+ """
363
+ if noise is None:
364
+ noise = randn_tensor(
365
+ image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype
366
+ )
367
+
368
+ noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)
369
+
370
+ dtype = image_embeds.dtype
371
+
372
+ image_embeds = self.image_normalizer.scale(image_embeds)
373
+
374
+ image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)
375
+
376
+ image_embeds = self.image_normalizer.unscale(image_embeds)
377
+
378
+ noise_level = get_timestep_embedding(
379
+ timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0
380
+ )
381
+
382
+ # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors,
383
+ # but we might actually be running in fp16. so we need to cast here.
384
+ # there might be better ways to encapsulate this.
385
+ image_embeds = image_embeds.to(dtype=dtype)
386
+ noise_level = noise_level.to(image_embeds.dtype)
387
+
388
+ image_embeds = torch.cat((image_embeds, class_target.repeat(image_embeds.shape[0] // class_target.shape[0], 1)), 1)
389
+
390
+ return image_embeds
391
+
392
+
393
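# Standalone sketch (illustrative sizes, not checkpoint values) of the embedding layout built
# above: the CLIP image embedding is noised through the image_noising_scheduler, a sinusoidal
# embedding of `noise_level` is computed and, in this variant, the repeated 4x256 class one-hot
# (rather than the time embedding) is what gets appended to the embedding.
import torch
from diffusers.models.embeddings import get_timestep_embedding

B, D = 2, 768                                              # hypothetical batch size and CLIP width
image_embeds = torch.randn(B, D)
noise_level = torch.zeros(B, dtype=torch.long)             # the pipeline default is noise_level=0

# sinusoidal embedding of the noise level (computed above; not concatenated in this variant)
time_emb = get_timestep_embedding(noise_level, embedding_dim=D, flip_sin_to_cos=True, downscale_freq_shift=0)

# the level one-hot, expanded exactly like `torch.repeat_interleave(class_target, 256)` in __call__
class_target = torch.repeat_interleave(torch.tensor([1, 0, 0, 0]), 256).unsqueeze(0).float()
out = torch.cat([image_embeds, class_target.repeat(B, 1)], dim=1)
print(out.shape, time_emb.shape)                           # torch.Size([2, 1792]) torch.Size([2, 768])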
+ @torch.no_grad()
394
+ def __call__(
395
+ self,
396
+ image: Union[torch.FloatTensor, PIL.Image.Image],
397
+ prompt: Union[str, List[str]],
398
+ prompt_embeds: torch.FloatTensor = None,
399
+ dino_feature: torch.FloatTensor = None,
400
+ height: Optional[int] = None,
401
+ width: Optional[int] = None,
402
+ num_inference_steps: int = 20,
403
+ guidance_scale: float = 10,
404
+ negative_prompt: Optional[Union[str, List[str]]] = None,
405
+ num_images_per_prompt: Optional[int] = 1,
406
+ eta: float = 0.0,
407
+ generator: Optional[torch.Generator] = None,
408
+ latents: Optional[torch.FloatTensor] = None,
409
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
410
+ output_type: Optional[str] = "pil",
411
+ return_dict: bool = True,
412
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
413
+ callback_steps: int = 1,
414
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
415
+ noise_level: int = 0,
416
+ image_embeds: Optional[torch.FloatTensor] = None,
417
+ return_elevation_focal: Optional[bool] = False,
418
+ gt_img_in: Optional[torch.FloatTensor] = None,
419
+ num_levels: Optional[int] = 3,
420
+ ):
421
+ r"""
422
+ Function invoked when calling the pipeline for generation.
423
+
424
+ Args:
425
+ prompt (`str` or `List[str]`, *optional*):
426
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
427
+ instead.
428
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
429
+ `Image`, or tensor representing an image batch. The image is encoded twice: its CLIP embedding
430
+ (noised and concatenated with the level one-hot) conditions the unet through `class_labels`, and
431
+ its VAE latents are concatenated to the denoising input channels. It is not used directly as the
432
+ initial latents of the denoising process.
433
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
434
+ The height in pixels of the generated image.
435
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
436
+ The width in pixels of the generated image.
437
+ num_inference_steps (`int`, *optional*, defaults to 20):
438
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
439
+ expense of slower inference.
440
+ guidance_scale (`float`, *optional*, defaults to 10.0):
441
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
442
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
443
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
444
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
445
+ usually at the expense of lower image quality.
446
+ negative_prompt (`str` or `List[str]`, *optional*):
447
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
448
+ `negative_prompt_embeds` instead.
449
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
450
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
451
+ The number of images to generate per prompt.
452
+ eta (`float`, *optional*, defaults to 0.0):
453
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
454
+ [`schedulers.DDIMScheduler`], will be ignored for others.
455
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
456
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
457
+ to make generation deterministic.
458
+ latents (`torch.FloatTensor`, *optional*):
459
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
460
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
461
+ tensor will be generated by sampling using the supplied random `generator`.
462
+ prompt_embeds (`torch.FloatTensor`, *optional*):
463
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
464
+ provided, text embeddings will be generated from `prompt` input argument.
465
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
466
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
467
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
468
+ argument.
469
+ output_type (`str`, *optional*, defaults to `"pil"`):
470
+ The output format of the generated image. Choose between
471
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
472
+ return_dict (`bool`, *optional*, defaults to `True`):
473
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
474
+ plain tuple.
475
+ callback (`Callable`, *optional*):
476
+ A function that will be called every `callback_steps` steps during inference. The function will be
477
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
478
+ callback_steps (`int`, *optional*, defaults to 1):
479
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
480
+ called at every step.
481
+ cross_attention_kwargs (`dict`, *optional*):
482
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
483
+ `self.processor` in
484
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
485
+ noise_level (`int`, *optional*, defaults to `0`):
486
+ The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in
487
+ the final un-noised images. See `StableUnCLIPPipeline.noise_image_embeddings` for details.
488
+ image_embeds (`torch.FloatTensor`, *optional*):
489
+ Pre-generated CLIP embeddings to condition the unet on. Note that these are not latents to be used in
490
+ the denoising process. If you want to provide pre-generated latents, pass them to `__call__` as
491
+ `latents`.
492
+
493
+ Examples:
494
+
495
+ Returns:
496
+ A list of [`~pipelines.ImagePipelineOutput`], one entry per generated level (see `num_levels`);
497
+ each entry holds the images for that level. `return_dict` is accepted but a plain list is always returned.
498
+ """
499
+ # 0. Default height and width to unet
500
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
501
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
502
+
503
+ # 1. Check inputs. Raise error if not correct
504
+ self.check_inputs(
505
+ prompt=prompt,
506
+ image=image,
507
+ height=height,
508
+ width=width,
509
+ callback_steps=callback_steps,
510
+ noise_level=noise_level
511
+ )
512
+
513
+ # 2. Define call parameters
514
+ if isinstance(image, list):
515
+ batch_size = len(image)
516
+ elif isinstance(image, torch.Tensor):
517
+ batch_size = image.shape[0]
518
+ assert batch_size >= self.num_views and batch_size % self.num_views == 0
519
+ elif isinstance(image, PIL.Image.Image):
520
+ image = [image]*self.num_views*2
521
+ batch_size = self.num_views*2
522
+
523
+ if isinstance(prompt, str):
524
+ prompt = [prompt] * self.num_views * 2
525
+
526
+ device = self._execution_device
527
+
528
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
529
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
530
+ # corresponds to doing no classifier free guidance.
531
+ do_classifier_free_guidance = guidance_scale != 1.0
532
+
533
+ # 3. Encode input prompt
534
+ text_encoder_lora_scale = (
535
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
536
+ )
537
+ prompt_embeds = self._encode_prompt(
538
+ prompt=prompt,
539
+ device=device,
540
+ num_images_per_prompt=num_images_per_prompt,
541
+ do_classifier_free_guidance=do_classifier_free_guidance,
542
+ negative_prompt=negative_prompt,
543
+ prompt_embeds=prompt_embeds,
544
+ negative_prompt_embeds=negative_prompt_embeds,
545
+ lora_scale=text_encoder_lora_scale,
546
+ )
547
+
548
+
549
+ # 4. Encode input image
550
+ noise_level = torch.tensor([noise_level], device=device)
551
+
552
+ class_targets = []
553
+ for level in [0, 1, 2]:
554
+ class_target = torch.tensor([0, 0, 0, 0]).cuda()
555
+ class_target[level] = 1
556
+ class_target = torch.repeat_interleave(class_target, 256).unsqueeze(0)
557
+ class_targets.append(class_target)
558
+
559
+ image_embeds_ls, image_latents = self._encode_image(
560
+ image=image,
561
+ device=device,
562
+ num_images_per_prompt=num_images_per_prompt,
563
+ do_classifier_free_guidance=do_classifier_free_guidance,
564
+ noise_level=noise_level,
565
+ class_targets=class_targets,
566
+ generator=generator,
567
+ )
568
+
569
+ # 5. Prepare timesteps
570
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
571
+ timesteps = self.scheduler.timesteps
572
+
573
+ # 6. Prepare latent variables
574
+ num_channels_latents = self.unet.config.out_channels
575
+ if gt_img_in is not None:
576
+ latents = gt_img_in * self.scheduler.init_noise_sigma
577
+ else:
578
+ latents, noise = self.prepare_latents(
579
+ batch_size=batch_size,
580
+ num_channels_latents=num_channels_latents,
581
+ height=height,
582
+ width=width,
583
+ dtype=prompt_embeds.dtype,
584
+ device=device,
585
+ generator=generator,
586
+ latents=latents,
587
+ )
588
+
589
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
590
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
591
+
592
+ original_latents = latents.clone()
593
+ image_ls = []
594
+ now_range = range(1, 3) if num_levels == 2 else range(num_levels)
595
+ for level in now_range:
596
+ latents = original_latents.clone()
597
+ eles, focals = [], []
598
+ # 8. Denoising loop
599
+ for i, t in enumerate(self.progress_bar(timesteps)):
600
+ if do_classifier_free_guidance:
601
+ normal_latents, color_latents = torch.chunk(latents, 2, dim=0)
602
+ latent_model_input = torch.cat([normal_latents, normal_latents, color_latents, color_latents], 0)
603
+ else:
604
+ latent_model_input = latents
605
+
606
+ latent_model_input = torch.cat([
607
+ latent_model_input, image_latents
608
+ ], dim=1)
609
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
610
+
611
+ # predict the noise residual
612
+ unet_out = self.unet(
613
+ latent_model_input,
614
+ t,
615
+ encoder_hidden_states=prompt_embeds,
616
+ dino_feature=dino_feature,
617
+ class_labels=image_embeds_ls[level],
618
+ cross_attention_kwargs=cross_attention_kwargs,
619
+ return_dict=False)
620
+
621
+ noise_pred = unet_out[0]
622
+ if return_elevation_focal:
623
+ uncond_pose, pose = torch.chunk(unet_out[1], 2, 0)
624
+ pose = uncond_pose + guidance_scale * (pose - uncond_pose)
625
+ ele = pose[:, 0].detach().cpu().numpy() # b
626
+ eles.append(ele)
627
+ focal = pose[:, 1].detach().cpu().numpy()
628
+ focals.append(focal)
629
+
630
+ # perform guidance
631
+ if do_classifier_free_guidance:
632
+ normal_noise_pred_uncond, normal_noise_pred_text, color_noise_pred_uncond, color_noise_pred_text = torch.chunk(noise_pred, 4, dim=0)
633
+
634
+ noise_pred_uncond, noise_pred_text = torch.cat([normal_noise_pred_uncond, color_noise_pred_uncond], 0), torch.cat([normal_noise_pred_text, color_noise_pred_text], 0)
635
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
636
+
637
+ # compute the previous noisy sample x_t -> x_t-1
638
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
639
+
640
+ if callback is not None and i % callback_steps == 0:
641
+ callback(i, t, latents)
642
+
643
+ # 9. Post-processing
644
+ if not output_type == "latent":
645
+ if num_channels_latents == 8:
646
+ latents = torch.cat([latents[:, :4], latents[:, 4:]], dim=0)
647
+ with torch.no_grad():
648
+ image = self.vae.decode((latents / self.vae.config.scaling_factor).to(self.vae.dtype), return_dict=False)[0]
649
+ else:
650
+ image = latents
651
+
652
+ image = self.image_processor.postprocess(image, output_type=output_type)
653
+ image = ImagePipelineOutput(images=image)
654
+ image_ls.append(image)
655
+
656
+ return image_ls
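# Hedged end-to-end sketch of the __call__ defined above. Only the call signature is taken from
# the code; the pipeline loader and the conditioning inputs are placeholders, and the fixed prompt
# embeddings are assumed to already be in the stacked normal/color layout the pipeline expects.
import torch
from PIL import Image

pipe = ...  # an instance of this pipeline, constructed/loaded elsewhere
cond_image = Image.open("input_cases/1.png")
prompt_embeds = torch.load("multiview/fixed_prompt_embeds_6view/clr_embeds.pt")

outputs = pipe(
    image=cond_image,
    prompt="",                    # conditioning actually comes from prompt_embeds and the image
    prompt_embeds=prompt_embeds,
    num_inference_steps=20,
    guidance_scale=10,
    noise_level=0,
    num_levels=3,                 # one denoising pass per level -> len(outputs) == 3
)
for level, out in enumerate(outputs):
    out.images[0].save(f"level{level}_view0.png")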
refine/func.py ADDED
@@ -0,0 +1,427 @@
1
+ import torch
2
+ from pytorch3d.renderer.cameras import look_at_view_transform, OrthographicCameras, CamerasBase
3
+ from pytorch3d.renderer import (
4
+ RasterizationSettings,
5
+ TexturesVertex,
6
+ FoVPerspectiveCameras,
7
+ FoVOrthographicCameras,
8
+ )
9
+ from pytorch3d.structures import Meshes
10
+ from PIL import Image
11
+ from typing import List
12
+ from refine.render import _warmup
13
+ import pymeshlab as ml
14
+ from pymeshlab import Percentage
15
+ import nvdiffrast.torch as dr
16
+ import numpy as np
17
+
18
+
19
+ def _translation(x, y, z, device):
20
+ return torch.tensor([[1., 0, 0, x],
21
+ [0, 1, 0, y],
22
+ [0, 0, 1, z],
23
+ [0, 0, 0, 1]],device=device) #4,4
24
+
25
+ def _projection(r, device, l=None, t=None, b=None, n=1.0, f=50.0, flip_y=True):
26
+ """
27
+ see https://blog.csdn.net/wodownload2/article/details/85069240/
28
+ """
29
+ if l is None:
30
+ l = -r
31
+ if t is None:
32
+ t = r
33
+ if b is None:
34
+ b = -t
35
+ p = torch.zeros([4,4],device=device)
36
+ p[0,0] = 2*n/(r-l)
37
+ p[0,2] = (r+l)/(r-l)
38
+ p[1,1] = 2*n/(t-b) * (-1 if flip_y else 1)
39
+ p[1,2] = (t+b)/(t-b)
40
+ p[2,2] = -(f+n)/(f-n)
41
+ p[2,3] = -(2*f*n)/(f-n)
42
+ p[3,2] = -1
43
+ return p #4,4
44
+
45
+ def _orthographic(r, device, l=None, t=None, b=None, n=1.0, f=50.0, flip_y=True):
46
+ if l is None:
47
+ l = -r
48
+ if t is None:
49
+ t = r
50
+ if b is None:
51
+ b = -t
52
+ o = torch.zeros([4,4],device=device)
53
+ o[0,0] = 2/(r-l)
54
+ o[0,3] = -(r+l)/(r-l)
55
+ o[1,1] = 2/(t-b) * (-1 if flip_y else 1)
56
+ o[1,3] = -(t+b)/(t-b)
57
+ o[2,2] = -2/(f-n)
58
+ o[2,3] = -(f+n)/(f-n)
59
+ o[3,3] = 1
60
+ return o #4,4
61
+
62
+ def make_star_cameras(az_count,pol_count,distance:float=10.,r=None,image_size=[512,512],device='cuda'):
63
+ if r is None:
64
+ r = 1/distance
65
+ A = az_count
66
+ P = pol_count
67
+ C = A * P
68
+
69
+ phi = torch.arange(0,A) * (2*torch.pi/A)
70
+ phi_rot = torch.eye(3,device=device)[None,None].expand(A,1,3,3).clone()
71
+ phi_rot[:,0,2,2] = phi.cos()
72
+ phi_rot[:,0,2,0] = -phi.sin()
73
+ phi_rot[:,0,0,2] = phi.sin()
74
+ phi_rot[:,0,0,0] = phi.cos()
75
+
76
+ theta = torch.arange(1,P+1) * (torch.pi/(P+1)) - torch.pi/2
77
+ theta_rot = torch.eye(3,device=device)[None,None].expand(1,P,3,3).clone()
78
+ theta_rot[0,:,1,1] = theta.cos()
79
+ theta_rot[0,:,1,2] = -theta.sin()
80
+ theta_rot[0,:,2,1] = theta.sin()
81
+ theta_rot[0,:,2,2] = theta.cos()
82
+
83
+ mv = torch.empty((C,4,4), device=device)
84
+ mv[:] = torch.eye(4, device=device)
85
+ mv[:,:3,:3] = (theta_rot @ phi_rot).reshape(C,3,3)
86
+ mv = _translation(0, 0, -distance, device) @ mv
87
+
88
+ return mv, _projection(r,device)
89
+
90
+
91
+ def make_star_cameras_orthographic(az_count,pol_count,distance:float=10.,r=None,image_size=[512,512],device='cuda'):
92
+ mv, _ = make_star_cameras(az_count,pol_count,distance,r,image_size,device)
93
+ if r is None:
94
+ r = 1
95
+ return mv, _orthographic(r,device)
96
+
97
+
98
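# Quick shape check (not from the repository, CPU only) for the star-camera helpers above:
# 8 azimuth steps x 1 polar step gives 8 model-view matrices sharing one orthographic projection.
mv, proj = make_star_cameras_orthographic(az_count=8, pol_count=1, distance=10., device='cpu')
print(mv.shape, proj.shape)   # torch.Size([8, 4, 4]) torch.Size([4, 4])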
+ def get_camera(world_to_cam, fov_in_degrees=60, focal_length=1 / (2**0.5), cam_type='fov'):
99
+ # pytorch3d expects transforms as row-vectors, so flip rotation: https://github.com/facebookresearch/pytorch3d/issues/1183
100
+ R = world_to_cam[:3, :3].t()[None, ...]
101
+ T = world_to_cam[:3, 3][None, ...]
102
+ if cam_type == 'fov':
103
+ camera = FoVPerspectiveCameras(device=world_to_cam.device, R=R, T=T, fov=fov_in_degrees, degrees=True)
104
+ else:
105
+ focal_length = 1 / focal_length
106
+ camera = FoVOrthographicCameras(device=world_to_cam.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length)
107
+ return camera
108
+
109
+
110
+ def get_cameras_list(azim_list, device, focal=2/1.35, dist=1.1):
111
+ ret = []
112
+ for azim in azim_list:
113
+ R, T = look_at_view_transform(dist, 0, azim)
114
+ w2c = torch.cat([R[0].T, T[0, :, None]], dim=1)
115
+ cameras: OrthographicCameras = get_camera(w2c, focal_length=focal, cam_type='orthogonal').to(device)
116
+ ret.append(cameras)
117
+ return ret
118
+
119
+
120
+ def to_py3d_mesh(vertices, faces, normals=None):
121
+ from pytorch3d.structures import Meshes
122
+ from pytorch3d.renderer.mesh.textures import TexturesVertex
123
+ mesh = Meshes(verts=[vertices], faces=[faces], textures=None)
124
+ if normals is None:
125
+ normals = mesh.verts_normals_packed()
126
+ # set normals as vertex colors
127
+ mesh.textures = TexturesVertex(verts_features=[normals / 2 + 0.5])
128
+ return mesh
129
+
130
+
131
+ def from_py3d_mesh(mesh):
132
+ return mesh.verts_list()[0], mesh.faces_list()[0], mesh.textures.verts_features_packed()
133
+
134
+
135
+ class Pix2FacesRenderer:
136
+ def __init__(self, device="cuda"):
137
+ self._glctx = dr.RasterizeCudaContext(device=device)
138
+ self.device = device
139
+ _warmup(self._glctx, device)
140
+
141
+ def transform_vertices(self, meshes: Meshes, cameras: CamerasBase):
142
+ vertices = cameras.transform_points_ndc(meshes.verts_padded())
143
+
144
+ perspective_correct = cameras.is_perspective()
145
+ znear = cameras.get_znear()
146
+ if isinstance(znear, torch.Tensor):
147
+ znear = znear.min().item()
148
+ z_clip = None if not perspective_correct or znear is None else znear / 2
149
+
150
+ if z_clip:
151
+ vertices = vertices[vertices[..., 2] >= cameras.get_znear()][None] # clip
152
+ vertices = vertices * torch.tensor([-1, -1, 1]).to(vertices)
153
+ vertices = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1).to(torch.float32)
154
+ return vertices
155
+
156
+ def render_pix2faces_nvdiff(self, meshes: Meshes, cameras: CamerasBase, H=512, W=512):
157
+ meshes = meshes.to(self.device)
158
+ cameras = cameras.to(self.device)
159
+ vertices = self.transform_vertices(meshes, cameras)
160
+ faces = meshes.faces_packed().to(torch.int32)
161
+ rast_out,_ = dr.rasterize(self._glctx, vertices, faces, resolution=(H, W), grad_db=False) #C,H,W,4
162
+ pix_to_face = rast_out[..., -1].to(torch.int32) - 1
163
+ return pix_to_face
164
+
165
+ pix2faces_renderer = Pix2FacesRenderer()
166
+
167
+ def get_visible_faces(meshes: Meshes, cameras: CamerasBase, resolution=1024):
168
+ # pix_to_face = render_pix2faces_py3d(meshes, cameras, H=resolution, W=resolution)['pix_to_face']
169
+ pix_to_face = pix2faces_renderer.render_pix2faces_nvdiff(meshes, cameras, H=resolution, W=resolution)
170
+
171
+ unique_faces = torch.unique(pix_to_face.flatten())
172
+ unique_faces = unique_faces[unique_faces != -1]
173
+ return unique_faces
174
+
175
+
176
+ def project_color(meshes: Meshes, cameras: CamerasBase, pil_image: Image.Image, use_alpha=True, eps=0.05, resolution=1024, device="cuda") -> dict:
177
+ """
178
+ Projects color from a given image onto a 3D mesh.
179
+
180
+ Args:
181
+ meshes (pytorch3d.structures.Meshes): The 3D mesh object.
182
+ cameras (pytorch3d.renderer.cameras.CamerasBase): The camera object.
183
+ pil_image (PIL.Image.Image): The input image.
184
+ use_alpha (bool, optional): Whether to use the alpha channel of the image. Defaults to True.
185
+ eps (float, optional): The threshold for selecting visible faces. Defaults to 0.05.
186
+ resolution (int, optional): The resolution of the projection. Defaults to 1024.
187
+ device (str, optional): The device to use for computation. Defaults to "cuda".
188
+ debug (bool, optional): Whether to save debug images. Defaults to False.
189
+
190
+ Returns:
191
+ dict: A dictionary containing the following keys:
192
+ - "new_texture" (TexturesVertex): The updated texture with interpolated colors.
193
+ - "valid_verts" (Tensor of [M,3]): The indices of the vertices being projected.
194
+ - "valid_colors" (Tensor of [M,3]): The interpolated colors for the valid vertices.
195
+ """
196
+ meshes = meshes.to(device)
197
+ cameras = cameras.to(device)
198
+ image = torch.from_numpy(np.array(pil_image.convert("RGBA")) / 255.).permute((2, 0, 1)).float().to(device) # in CHW format of [0, 1.]
199
+ unique_faces = get_visible_faces(meshes, cameras, resolution=resolution)
200
+
201
+ # visible faces
202
+ faces_normals = meshes.faces_normals_packed()[unique_faces]
203
+ faces_normals = faces_normals / faces_normals.norm(dim=1, keepdim=True)
204
+ world_points = cameras.unproject_points(torch.tensor([[[0., 0., 0.1], [0., 0., 0.2]]]).to(device))[0]
205
+ view_direction = world_points[1] - world_points[0]
206
+ view_direction = view_direction / view_direction.norm(dim=0, keepdim=True)
207
+
208
+ # find invalid faces
209
+ cos_angles = (faces_normals * view_direction).sum(dim=1)
210
+ assert cos_angles.mean() < 0, f"The view direction is not correct. cos_angles.mean()={cos_angles.mean()}"
211
+ selected_faces = unique_faces[cos_angles < -eps]
212
+
213
+ # find verts
214
+ faces = meshes.faces_packed()[selected_faces] # [N, 3]
215
+ verts = torch.unique(faces.flatten()) # [N, 1]
216
+ verts_coordinates = meshes.verts_packed()[verts] # [N, 3]
217
+
218
+ # compute color
219
+ pt_tensor = cameras.transform_points(verts_coordinates)[..., :2] # NDC space points
220
+ valid = ~((pt_tensor.isnan()|(pt_tensor<-1)|(1<pt_tensor)).any(dim=1)) # checked, correct
221
+ valid_pt = pt_tensor[valid, :]
222
+ valid_idx = verts[valid]
223
+ valid_color = torch.nn.functional.grid_sample(image[None].flip((-1, -2)), valid_pt[None, :, None, :], align_corners=False, padding_mode="reflection", mode="bilinear")[0, :, :, 0].T.clamp(0, 1) # [N, 4], note that bicubic may give invalid value
224
+ alpha, valid_color = valid_color[:, 3:], valid_color[:, :3]
225
+ if not use_alpha:
226
+ alpha = torch.ones_like(alpha)
227
+
228
+ # modify color
229
+ old_colors = meshes.textures.verts_features_packed()
230
+ old_colors[valid_idx] = valid_color * alpha + old_colors[valid_idx] * (1 - alpha)
231
+ new_texture = TexturesVertex(verts_features=[old_colors])
232
+
233
+ valid_verts_normals = meshes.verts_normals_packed()[valid_idx]
234
+ valid_verts_normals = valid_verts_normals / valid_verts_normals.norm(dim=1, keepdim=True).clamp_min(0.001)
235
+ cos_angles = (valid_verts_normals * view_direction).sum(dim=1)
236
+ return {
237
+ "new_texture": new_texture,
238
+ "valid_verts": valid_idx,
239
+ "valid_colors": valid_color,
240
+ "valid_alpha": alpha,
241
+ "cos_angles": cos_angles,
242
+ }
243
+
244
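# Hedged usage sketch for project_color above. The mesh and the view image are placeholders:
# `vertices`/`faces` would come from the reconstruction stage, and "front_view.png" should be an
# RGBA rendering aligned with the chosen camera so its alpha channel masks the projection.
from PIL import Image

camera = get_cameras_list([0], device="cuda")[0]          # front view, via the helper above
mesh = to_py3d_mesh(vertices, faces)                      # normals baked in as initial vertex colors
ret = project_color(mesh, camera, Image.open("front_view.png"), eps=0.05, resolution=1024)
mesh.textures = ret["new_texture"]                        # colors updated only on faces this view can see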
+ def complete_unseen_vertex_color(meshes: Meshes, valid_index: torch.Tensor) -> dict:
245
+ """
246
+ meshes: the mesh with vertex color to be completed.
247
+ valid_index: the index of the valid vertices, where valid means colors are fixed. [V, 1]
248
+ """
249
+ valid_index = valid_index.to(meshes.device)
250
+ colors = meshes.textures.verts_features_packed() # [V, 3]
251
+ V = colors.shape[0]
252
+
253
+ invalid_index = torch.ones_like(colors[:, 0]).bool() # [V]
254
+ invalid_index[valid_index] = False
255
+ invalid_index = torch.arange(V).to(meshes.device)[invalid_index]
256
+
257
+ L = meshes.laplacian_packed()
258
+ E = torch.sparse_coo_tensor(torch.tensor([list(range(V))] * 2), torch.ones((V,)), size=(V, V)).to(meshes.device)
259
+ L = L + E
260
+ # import pdb; pdb.set_trace()
261
+ # E = torch.eye(V, layout=torch.sparse_coo, device=meshes.device)
262
+ # L = L + E
263
+ colored_count = torch.ones_like(colors[:, 0]) # [V]
264
+ colored_count[invalid_index] = 0
265
+ L_invalid = torch.index_select(L, 0, invalid_index) # sparse [IV, V]
266
+
267
+ total_colored = colored_count.sum()
268
+ coloring_round = 0
269
+ stage = "uncolored"
270
+ from tqdm import tqdm
271
+ pbar = tqdm(miniters=100)
272
+ while stage == "uncolored" or coloring_round > 0:
273
+ new_color = torch.matmul(L_invalid, colors * colored_count[:, None]) # [IV, 3]
274
+ new_count = torch.matmul(L_invalid, colored_count)[:, None] # [IV, 1]
275
+ colors[invalid_index] = torch.where(new_count > 0, new_color / new_count, colors[invalid_index])
276
+ colored_count[invalid_index] = (new_count[:, 0] > 0).float()
277
+
278
+ new_total_colored = colored_count.sum()
279
+ if new_total_colored > total_colored:
280
+ total_colored = new_total_colored
281
+ coloring_round += 1
282
+ else:
283
+ stage = "colored"
284
+ coloring_round -= 1
285
+ pbar.update(1)
286
+ if coloring_round > 10000:
287
+ print("coloring_round > 10000, break")
288
+ break
289
+ assert not torch.isnan(colors).any()
290
+ meshes.textures = TexturesVertex(verts_features=[colors])
291
+ return meshes
292
+
293
+
294
+ def multiview_color_projection(meshes: Meshes, image_list: List[Image.Image], cameras_list: List[CamerasBase]=None, camera_focal: float = 2 / 1.35, weights=None, eps=0.05, resolution=1024, device="cuda", reweight_with_cosangle="square", use_alpha=True, confidence_threshold=0.1, complete_unseen=False, below_confidence_strategy="smooth", distract_mask=None) -> Meshes:
295
+ """
296
+ Projects color from a given image onto a 3D mesh.
297
+
298
+ Args:
299
+ meshes (pytorch3d.structures.Meshes): The 3D mesh object, only one mesh.
300
+ image_list (PIL.Image.Image): List of images.
301
+ cameras_list (list): List of cameras.
302
+ camera_focal (float, optional): The focal length of the camera, if cameras_list is not passed. Defaults to 2 / 1.35.
303
+ weights (list, optional): List of weights for each image, for ['front', 'front_right', 'right', 'back', 'left', 'front_left']. Defaults to None.
304
+ eps (float, optional): The threshold for selecting visible faces. Defaults to 0.05.
305
+ resolution (int, optional): The resolution of the projection. Defaults to 1024.
306
+ device (str, optional): The device to use for computation. Defaults to "cuda".
307
+ reweight_with_cosangle (str, optional): Whether to reweight the color with the angle between the view direction and the vertex normal. Defaults to None.
308
+ use_alpha (bool, optional): Whether to use the alpha channel of the image. Defaults to True.
309
+ confidence_threshold (float, optional): The threshold for the confidence of the projected color, if final projection weight is less than this, we will use the original color. Defaults to 0.1.
310
+ complete_unseen (bool, optional): Whether to complete the unseen vertex color using laplacian. Defaults to False.
311
+
312
+ Returns:
313
+ Meshes: the colored mesh
314
+ """
315
+ # 1. preprocess inputs
316
+ if image_list is None:
317
+ raise ValueError("image_list is None")
318
+ if cameras_list is None:
319
+ raise ValueError("cameras_list is None")
320
+ if weights is None:
321
+ raise ValueError("weights is None, and can not be guessed from image_list")
322
+
323
+ # 2. run projection
324
+ meshes = meshes.clone().to(device)
325
+ if weights is None:
326
+ weights = [1. for _ in range(len(cameras_list))]
327
+ assert len(cameras_list) == len(image_list) == len(weights)
328
+ original_color = meshes.textures.verts_features_packed()
329
+ assert not torch.isnan(original_color).any()
330
+ texture_counts = torch.zeros_like(original_color[..., :1])
331
+ texture_values = torch.zeros_like(original_color)
332
+ max_texture_counts = torch.zeros_like(original_color[..., :1])
333
+ max_texture_values = torch.zeros_like(original_color)
334
+ for camera, image, weight in zip(cameras_list, image_list, weights):
335
+ ret = project_color(meshes, camera, image, eps=eps, resolution=resolution, device=device, use_alpha=use_alpha)
336
+ if reweight_with_cosangle == "linear":
337
+ weight = (ret['cos_angles'].abs() * weight)[:, None]
338
+ elif reweight_with_cosangle == "square":
339
+ weight = (ret['cos_angles'].abs() ** 2 * weight)[:, None]
340
+ if use_alpha:
341
+ weight = weight * ret['valid_alpha']
342
+ assert weight.min() > -0.0001
343
+ texture_counts[ret['valid_verts']] += weight
344
+ texture_values[ret['valid_verts']] += ret['valid_colors'] * weight
345
+ max_texture_values[ret['valid_verts']] = torch.where(weight > max_texture_counts[ret['valid_verts']], ret['valid_colors'], max_texture_values[ret['valid_verts']])
346
+ max_texture_counts[ret['valid_verts']] = torch.max(max_texture_counts[ret['valid_verts']], weight)
347
+
348
+ # Method2
349
+ texture_values = torch.where(texture_counts > confidence_threshold, texture_values / texture_counts, texture_values)
350
+ if below_confidence_strategy == "smooth":
351
+ texture_values = torch.where(texture_counts <= confidence_threshold, (original_color * (confidence_threshold - texture_counts) + texture_values) / confidence_threshold, texture_values)
352
+ elif below_confidence_strategy == "original":
353
+ texture_values = torch.where(texture_counts <= confidence_threshold, original_color, texture_values)
354
+ else:
355
+ raise ValueError(f"below_confidence_strategy={below_confidence_strategy} is not supported")
356
+ assert not torch.isnan(texture_values).any()
357
+ meshes.textures = TexturesVertex(verts_features=[texture_values])
358
+
359
+ if distract_mask is not None:
360
+ import cv2
361
+ pil_distract_mask = (distract_mask * 255).astype(np.uint8)
362
+ pil_distract_mask = cv2.erode(pil_distract_mask, np.ones((3, 3), np.uint8), iterations=2)
363
+ pil_distract_mask = Image.fromarray(pil_distract_mask)
364
+ ret = project_color(meshes, cameras_list[0], pil_distract_mask, eps=eps, resolution=resolution, device=device, use_alpha=use_alpha)
365
+ distract_valid_mask = ret['valid_colors'][:, 0] > 0.5
366
+ distract_invalid_index = ret['valid_verts'][~distract_valid_mask]
367
+
368
+ # neighbors of invalid indices should also be included
369
+ L = meshes.laplacian_packed()
370
+ # Convert invalid indices to a boolean mask
371
+ distract_invalid_mask = torch.zeros(meshes.verts_packed().shape[0:1], dtype=torch.bool, device=device)
372
+ distract_invalid_mask[distract_invalid_index] = True
373
+
374
+ # Find neighbors: multiply Laplacian with invalid_mask and check non-zero values
375
+ # Extract COO format (L.indices() gives [2, N] shape: row, col; L.values() gives values)
376
+ row_indices, col_indices = L.coalesce().indices()
377
+ invalid_rows = distract_invalid_mask[row_indices]
378
+ neighbor_indices = col_indices[invalid_rows]
379
+
380
+ # Combine original invalids with their neighbors
381
+ combined_invalid_mask = distract_invalid_mask.clone()
382
+ combined_invalid_mask[neighbor_indices] = True
383
+
384
+ # repeat
385
+ invalid_rows = combined_invalid_mask[row_indices]
386
+ neighbor_indices = col_indices[invalid_rows]
387
+ combined_invalid_mask[neighbor_indices] = True
388
+
389
+ # Apply to texture counts and values
390
+ texture_counts[combined_invalid_mask] = 0
391
+ texture_values[combined_invalid_mask] = 0
392
+
393
+
394
+ if complete_unseen:
395
+ meshes = complete_unseen_vertex_color(meshes, torch.arange(texture_values.shape[0]).to(device)[texture_counts[:, 0] >= confidence_threshold])
396
+ ret_mesh = meshes.detach()
397
+ del meshes
398
+ return ret_mesh
399
+
400
+
401
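# Hedged sketch of a full multi-view bake with the function above. Azimuths, file names and
# weights are illustrative, not values read from the repository's configs; `mesh` is assumed to be
# a vertex-colored pytorch3d Meshes (e.g. the output of to_py3d_mesh).
from PIL import Image

azims = [0, 90, 180, 270]
cameras = get_cameras_list(azims, device="cuda")
views = [Image.open(f"view_{a:03d}.png") for a in azims]
mesh = multiview_color_projection(
    mesh, views, cameras_list=cameras,
    weights=[2.0, 1.0, 1.0, 1.0],                 # favor the front view
    reweight_with_cosangle="square",              # down-weight grazing-angle pixels
    confidence_threshold=0.1,
    complete_unseen=True,                         # Laplacian fill for vertices no view observed
)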
+ def meshlab_mesh_to_py3dmesh(mesh: ml.Mesh) -> Meshes:
402
+ verts = torch.from_numpy(mesh.vertex_matrix()).float()
403
+ faces = torch.from_numpy(mesh.face_matrix()).long()
404
+ colors = torch.from_numpy(mesh.vertex_color_matrix()[..., :3]).float()
405
+ textures = TexturesVertex(verts_features=[colors])
406
+ return Meshes(verts=[verts], faces=[faces], textures=textures)
407
+
408
+
409
+ def to_pyml_mesh(vertices,faces):
410
+ m1 = ml.Mesh(
411
+ vertex_matrix=vertices.cpu().float().numpy().astype(np.float64),
412
+ face_matrix=faces.cpu().long().numpy().astype(np.int32),
413
+ )
414
+ return m1
415
+
416
+
417
+ def simple_clean_mesh(pyml_mesh: ml.Mesh, apply_smooth=True, stepsmoothnum=1, apply_sub_divide=False, sub_divide_threshold=0.25):
418
+ ms = ml.MeshSet()
419
+ ms.add_mesh(pyml_mesh, "cube_mesh")
420
+
421
+ if apply_smooth:
422
+ ms.apply_filter("apply_coord_laplacian_smoothing", stepsmoothnum=stepsmoothnum, cotangentweight=False)
423
+ if apply_sub_divide: # 5s, slow
424
+ ms.apply_filter("meshing_repair_non_manifold_vertices")
425
+ ms.apply_filter("meshing_repair_non_manifold_edges", method='Remove Faces')
426
+ ms.apply_filter("meshing_surface_subdivision_loop", iterations=2, threshold=Percentage(sub_divide_threshold))
427
+ return meshlab_mesh_to_py3dmesh(ms.current_mesh())