2gnak dylanebert HF staff committed on
Commit
84a6427
0 Parent(s):

Duplicate from dylanebert/LGM-full

Browse files

Co-authored-by: Dylan Ebert <dylanebert@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ pipeline_tag: image-to-3d
4
+ ---
5
+
6
+ # LGM Full
7
+
8
+ This custom pipeline encapsulates the full [LGM](https://huggingface.co/ashawkey/LGM) pipeline, including [multi-view diffusion](https://huggingface.co/ashawkey/imagedream-ipmv-diffusers).
9
+
10
+ It is provided as a resource for the [ML for 3D Course](https://huggingface.co/learn/ml-for-3d-course).
11
+
12
+ Original LGM paper: [LGM: Large Multi-View Gaussian Model for High-Resolution 3D Content Creation](https://huggingface.co/papers/2402.05054).
feature_extractor/preprocessor_config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": {
3
+ "height": 224,
4
+ "width": 224
5
+ },
6
+ "do_center_crop": true,
7
+ "do_convert_rgb": true,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "feature_extractor_type": "CLIPFeatureExtractor",
12
+ "image_mean": [
13
+ 0.48145466,
14
+ 0.4578275,
15
+ 0.40821073
16
+ ],
17
+ "image_processor_type": "CLIPImageProcessor",
18
+ "image_std": [
19
+ 0.26862954,
20
+ 0.26130258,
21
+ 0.27577711
22
+ ],
23
+ "resample": 3,
24
+ "rescale_factor": 0.00392156862745098,
25
+ "size": {
26
+ "shortest_edge": 224
27
+ },
28
+ "use_square_size": false
29
+ }
image_encoder/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
3
+ "architectures": [
4
+ "CLIPVisionModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "dropout": 0.0,
8
+ "hidden_act": "gelu",
9
+ "hidden_size": 1280,
10
+ "image_size": 224,
11
+ "initializer_factor": 1.0,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 5120,
14
+ "layer_norm_eps": 1e-05,
15
+ "model_type": "clip_vision_model",
16
+ "num_attention_heads": 16,
17
+ "num_channels": 3,
18
+ "num_hidden_layers": 32,
19
+ "patch_size": 14,
20
+ "projection_dim": 1024,
21
+ "torch_dtype": "float16",
22
+ "transformers_version": "4.35.2"
23
+ }
image_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a56cfd4ffcf40be097c430324ec184cc37187f6dafef128ef9225438a3c03c4
3
+ size 1261595704
lgm/config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "_class_name": "LGM",
3
+ "_diffusers_version": "0.25.0"
4
+ }
lgm/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79e5160e1fc45559515579a7e41ffc22606cf41c3ed8581b09dae9b4ce437099
3
+ size 830126192
lgm/lgm.py ADDED
@@ -0,0 +1,808 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+ from functools import partial
4
+ from typing import Literal, Tuple
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from diff_gaussian_rasterization import (
10
+ GaussianRasterizationSettings,
11
+ GaussianRasterizer,
12
+ )
13
+ from diffusers import ConfigMixin, ModelMixin
14
+ from torch import Tensor, nn
15
+
16
+
17
def look_at(campos):
    """Build the rotation matrix of a camera at ``campos`` facing the origin.

    Args:
        campos: Camera position, shape ``(3,)`` or batched ``(..., 3)``.

    Returns:
        Array of shape ``(..., 3, 3)`` whose columns are the camera's right,
        up and forward axes. Forward is the unit vector from the camera
        towards the origin.
    """

    def _unit(vec):
        # keepdims makes the division broadcast per vector for batched input;
        # the floor keeps a degenerate zero-length vector finite.
        length = np.linalg.norm(vec, axis=-1, keepdims=True)
        return vec / np.maximum(length, 1e-20)

    campos = np.asarray(campos)
    forward_vector = _unit(-campos)
    up_vector = np.array([0, 1, 0], dtype=np.float32)
    # cross(up, forward) only has unit length when forward is horizontal, so
    # both derived axes are re-normalised to keep the basis orthonormal when
    # the camera is elevated.
    right_vector = _unit(np.cross(up_vector, forward_vector))
    up_vector = _unit(np.cross(forward_vector, right_vector))
    return np.stack([right_vector, up_vector, forward_vector], axis=-1)
24
+
25
+
26
def orbit_camera(elevation, azimuth, radius=1):
    """Return a 4x4 pose for a camera placed on a sphere around the origin.

    Args:
        elevation: Elevation angle in degrees.
        azimuth: Azimuth angle in degrees.
        radius: Distance of the camera from the origin.

    Returns:
        A float32 ``(4, 4)`` array: the upper-left 3x3 is ``look_at`` of the
        camera position and the last column holds that position.
    """
    elev_rad = np.deg2rad(elevation)
    azim_rad = np.deg2rad(azimuth)

    # Radius of the horizontal circle the camera sits on at this elevation.
    horizontal = radius * np.cos(elev_rad)
    position = np.array(
        [
            horizontal * np.sin(azim_rad),
            -radius * np.sin(elev_rad),
            horizontal * np.cos(azim_rad),
        ]
    )

    pose = np.eye(4, dtype=np.float32)
    pose[:3, 3] = position
    pose[:3, :3] = look_at(position)
    return pose
37
+
38
+
39
def get_rays(pose, h, w, fovy, opengl=True):
    """Compute one ray origin and unit direction per pixel of a pinhole camera.

    Args:
        pose: 4x4 tensor; its upper-left 3x3 rotates camera-space directions
            and its last column supplies the shared ray origin.
        h: Image height in pixels.
        w: Image width in pixels.
        fovy: Vertical field of view in degrees.
        opengl: When true, the camera-space y and z components are negated.

    Returns:
        ``(rays_o, rays_d)``, each of shape ``(h, w, 3)``; ``rays_d`` is
        normalised along the last dimension.
    """
    device = pose.device
    cols, rows = torch.meshgrid(
        torch.arange(w, device=device),
        torch.arange(h, device=device),
        indexing="xy",
    )
    cols = cols.flatten()
    rows = rows.flatten()

    center_x = w * 0.5
    center_y = h * 0.5
    focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))
    axis_sign = -1.0 if opengl else 1.0

    # The +0.5 samples the centre of each pixel.
    dir_x = (cols - center_x + 0.5) / focal
    dir_y = (rows - center_y + 0.5) / focal * axis_sign
    dir_z = torch.full_like(dir_x, axis_sign)
    camera_dirs = torch.stack([dir_x, dir_y, dir_z], dim=-1)

    rotation = pose[:3, :3]
    rays_d = camera_dirs @ rotation.transpose(0, 1)
    rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d).view(h, w, 3)
    rays_d = F.normalize(rays_d, dim=-1).view(h, w, 3)

    return rays_o, rays_d
72
+
73
+
74
class GaussianRenderer:
    """Rasterise packed Gaussians to images and export them as a PLY file.

    Every Gaussian is one row laid out as position ``[0:3]``, opacity
    ``[3:4]``, scale ``[4:7]``, rotation ``[7:11]`` and colour ``[11:]``.
    """

    def __init__(self, fovy, output_size):
        """Store the render size and build the projection matrix.

        Args:
            fovy: Field of view in degrees (used for both image axes).
            output_size: Side length, in pixels, of the square renders.
        """
        self.output_size = output_size

        # Default white background; note it is created on the "cuda" device.
        self.bg_color = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda")

        znear, zfar = 0.1, 2.5
        self.tan_half_fov = np.tan(0.5 * np.deg2rad(fovy))

        proj = torch.zeros(4, 4, dtype=torch.float32)
        proj[0, 0] = 1 / self.tan_half_fov
        proj[1, 1] = 1 / self.tan_half_fov
        proj[2, 2] = (zfar + znear) / (zfar - znear)
        proj[3, 2] = -(zfar * znear) / (zfar - znear)
        proj[2, 3] = 1
        self.proj_matrix = proj

    def render(
        self,
        gaussians,
        cam_view,
        cam_view_proj,
        cam_pos,
        bg_color=None,
        scale_modifier=1,
    ):
        """Render every (batch, view) pair.

        Args:
            gaussians: Packed Gaussians indexed as ``[batch, gaussian, field]``.
            cam_view: Per-view view matrices indexed as ``[batch, view]``.
            cam_view_proj: Per-view view-projection matrices, same indexing.
            cam_pos: Per-view camera positions, same indexing.
            bg_color: Optional background colour; defaults to ``self.bg_color``.
            scale_modifier: Passed through to the rasteriser settings.

        Returns:
            Dict with ``"image"`` of shape ``(B, V, 3, S, S)`` clamped to
            ``[0, 1]`` and ``"alpha"`` of shape ``(B, V, 1, S, S)``, where
            ``S`` is ``self.output_size``.
        """
        device = gaussians.device
        B, V = cam_view.shape[:2]
        size = self.output_size
        background = self.bg_color if bg_color is None else bg_color

        images = []
        alphas = []
        for b in range(B):
            splats = gaussians[b]
            means3D = splats[:, 0:3].contiguous().float()
            opacity = splats[:, 3:4].contiguous().float()
            scales = splats[:, 4:7].contiguous().float()
            rotations = splats[:, 7:11].contiguous().float()
            rgbs = splats[:, 11:].contiguous().float()

            for v in range(V):
                settings = GaussianRasterizationSettings(
                    image_height=size,
                    image_width=size,
                    tanfovx=self.tan_half_fov,
                    tanfovy=self.tan_half_fov,
                    bg=background,
                    scale_modifier=scale_modifier,
                    viewmatrix=cam_view[b, v].float(),
                    projmatrix=cam_view_proj[b, v].float(),
                    sh_degree=0,
                    campos=cam_pos[b, v].float(),
                    prefiltered=False,
                    debug=False,
                )
                rasterizer = GaussianRasterizer(raster_settings=settings)

                # Colours are supplied precomputed, so no SH coefficients
                # and no precomputed covariance are passed.
                color, _, _, alpha = rasterizer(
                    means3D=means3D,
                    means2D=torch.zeros_like(
                        means3D, dtype=torch.float32, device=device
                    ),
                    shs=None,
                    colors_precomp=rgbs,
                    opacities=opacity,
                    scales=scales,
                    rotations=rotations,
                    cov3D_precomp=None,
                )

                images.append(color.clamp(0, 1))
                alphas.append(alpha)

        images = torch.stack(images, dim=0).view(B, V, 3, size, size)
        alphas = torch.stack(alphas, dim=0).view(B, V, 1, size, size)

        return {"image": images, "alpha": alphas}

    def save_ply(self, gaussians, path):
        """Write a single batch of Gaussians to ``path`` as a PLY vertex list.

        Gaussians with opacity below 0.005 are dropped. Opacity is stored as
        a logit, scale as a logarithm, and colour is shifted by 0.5 and
        divided by 0.28209479177387814 (presumably the degree-0 spherical
        harmonics constant - confirm against the consumer of the file).
        """
        assert gaussians.shape[0] == 1, "only support batch size 1"

        from plyfile import PlyData, PlyElement

        splats = gaussians[0]
        means3D = splats[:, 0:3].contiguous().float()
        opacity = splats[:, 3:4].contiguous().float()
        scales = splats[:, 4:7].contiguous().float()
        rotations = splats[:, 7:11].contiguous().float()
        shs = splats[:, 11:].unsqueeze(1).contiguous().float()

        # Prune nearly transparent Gaussians.
        keep = opacity.squeeze(-1) >= 0.005
        means3D, opacity, scales, rotations, shs = (
            field[keep] for field in (means3D, opacity, scales, rotations, shs)
        )

        # Clamp first so the logit stays finite.
        opacity = opacity.clamp(1e-6, 1 - 1e-6)
        opacity = torch.log(opacity / (1 - opacity))
        scales = torch.log(scales + 1e-8)
        shs = (shs - 0.5) / 0.28209479177387814

        def to_numpy(tensor):
            return tensor.detach().cpu().numpy()

        xyzs = to_numpy(means3D)
        f_dc = to_numpy(shs.transpose(1, 2).flatten(start_dim=1).contiguous())
        opacities = to_numpy(opacity)
        scales = to_numpy(scales)
        rotations = to_numpy(rotations)

        names = ["x", "y", "z"]
        names += ["f_dc_{}".format(i) for i in range(f_dc.shape[1])]
        names.append("opacity")
        names += ["scale_{}".format(i) for i in range(scales.shape[1])]
        names += ["rot_{}".format(i) for i in range(rotations.shape[1])]

        vertex_dtype = [(name, "f4") for name in names]
        vertices = np.empty(xyzs.shape[0], dtype=vertex_dtype)
        columns = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1)
        vertices[:] = list(map(tuple, columns))

        PlyData([PlyElement.describe(vertices, "vertex")]).write(path)
209
+
210
+
211
class LGM(ModelMixin, ConfigMixin):
    """Predict packed 3D Gaussians from multi-view images.

    A ``UNet`` followed by a 1x1 convolution emits 14 channels per splat
    pixel; ``forward`` decodes them into position (3), opacity (1),
    scale (3), rotation quaternion (4) and colour (3).
    """

    def __init__(self):
        super().__init__()

        self.input_size = 256
        self.splat_size = 128
        self.output_size = 512
        self.radius = 1.5
        self.fovy = 49.1

        self.unet = UNet(
            in_channels=9,
            out_channels=14,
            down_channels=(64, 128, 256, 512, 1024, 1024),
            down_attention=(False, False, False, True, True, True),
            mid_attention=True,
            up_channels=(1024, 1024, 512, 256, 128),
            up_attention=(True, True, True, False, False),
        )

        self.conv = nn.Conv2d(14, 14, kernel_size=1)
        self.gs = GaussianRenderer(self.fovy, self.output_size)

        # Per-field activations applied to the raw network output.
        self.pos_act = lambda x: x.clamp(-1, 1)
        self.scale_act = lambda x: 0.1 * F.softplus(x)
        self.opacity_act = torch.sigmoid
        self.rot_act = F.normalize
        self.rgb_act = lambda x: 0.5 * torch.tanh(x) + 0.5

    @staticmethod
    def _multiply_quat(q1, q2):
        """Hamilton product of quaternions stored as ``(w, x, y, z)``."""
        w1, x1, y1, z1 = q1.unbind(-1)
        w2, x2, y2, z2 = q2.unbind(-1)
        return torch.stack(
            [
                w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
                w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
                w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2,
                w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2,
            ],
            dim=-1,
        )

    def prepare_default_rays(self, device, elevation=0):
        """Return Plucker ray embeddings for four cameras 90 degrees apart.

        Args:
            device: Device the result is moved to.
            elevation: Elevation in degrees shared by all four cameras.

        Returns:
            Tensor of shape ``(4, 6, input_size, input_size)`` holding
            ``cross(origin, direction)`` followed by ``direction``.
        """
        poses = np.stack(
            [
                orbit_camera(elevation, azimuth, radius=self.radius)
                for azimuth in (0, 90, 180, 270)
            ],
            axis=0,
        )
        poses = torch.from_numpy(poses)

        embeddings = []
        for pose in poses:
            origins, directions = get_rays(
                pose, self.input_size, self.input_size, self.fovy
            )
            moment = torch.cross(origins, directions, dim=-1)
            embeddings.append(torch.cat([moment, directions], dim=-1))

        stacked = torch.stack(embeddings, dim=0)
        return stacked.permute(0, 3, 1, 2).contiguous().to(device)

    def forward(self, images):
        """Decode a batch of multi-view inputs into packed Gaussians.

        Args:
            images: Tensor of shape ``(B, V, C, H, W)``. The reshape below
                hard-codes 4 views.

        Returns:
            Tensor of shape ``(B, N, 14)`` laid out as position, opacity,
            scale, rotation, colour.
        """
        B, V, C, H, W = images.shape
        features = self.conv(self.unet(images.view(B * V, C, H, W)))

        features = features.reshape(B, 4, 14, self.splat_size, self.splat_size)
        splats = features.permute(0, 1, 3, 4, 2).reshape(B, -1, 14)

        pos = self.pos_act(splats[..., 0:3])
        opacity = self.opacity_act(splats[..., 3:4])
        scale = self.scale_act(splats[..., 4:7])
        rotation = self.rot_act(splats[..., 7:11])
        rgbs = self.rgb_act(splats[..., 11:])

        # Fixed re-orientation: negate x and y of every position and
        # left-multiply every rotation by the quaternion (0, 0, 1, 0).
        flip_quat = torch.tensor([0, 0, 1, 0], dtype=pos.dtype, device=pos.device)
        flip_axes = torch.tensor(
            [
                [-1, 0, 0],
                [0, -1, 0],
                [0, 0, 1],
            ],
            dtype=pos.dtype,
            device=pos.device,
        )

        pos = torch.matmul(pos, flip_axes.T)
        # The product is elementwise per Gaussian, so the constant
        # quaternion broadcasts over the whole batch at once.
        rotation = self._multiply_quat(flip_quat, rotation)

        return torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1)
316
+
317
+
318
+ # =============================================================================
319
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
320
+ #
321
+ # This source code is licensed under the Apache License, Version 2.0
322
+ # found in the LICENSE file in the root directory of this source tree.
323
+
324
+ # References:
325
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
326
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
327
+ # =============================================================================
328
# xFormers is optional. Setting the XFORMERS_DISABLED environment variable
# (to any value) opts out of it; otherwise it is used when importable.
# XFORMERS_AVAILABLE tells the attention classes which path to take.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
XFORMERS_AVAILABLE = False
if XFORMERS_ENABLED:
    try:
        from xformers.ops import memory_efficient_attention, unbind
    except ImportError:
        warnings.warn("xFormers is not available (Attention)")
    else:
        XFORMERS_AVAILABLE = True
        warnings.warn("xFormers is available (Attention)")
else:
    # Emit only the "disabled" notice: a user who opted out should not also
    # be told that the package is missing.
    warnings.warn("xFormers is disabled (Attention)")
341
+
342
+
343
class Attention(nn.Module):
    """Multi-head self-attention over a ``(B, N, C)`` token sequence."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        """Create the fused q/k/v projection, output projection and dropouts.

        Args:
            dim: Channel width of the input and output tokens.
            num_heads: Number of attention heads.
            qkv_bias: Whether the fused q/k/v projection has a bias.
            proj_bias: Whether the output projection has a bias.
            attn_drop: Dropout probability on the attention weights.
            proj_drop: Dropout probability on the projected output.
        """
        super().__init__()
        self.num_heads = num_heads
        # Query scaling factor: 1 / sqrt(head_dim).
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        """Apply self-attention to ``x`` of shape ``(B, N, C)``."""
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        # -> three tensors of shape (B, heads, N, head_dim)
        query, key, value = packed.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (query * self.scale) @ key.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))

        merged = (weights @ value).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
381
+
382
+
383
class MemEffAttention(Attention):
    """Self-attention that uses the xFormers kernel when it is available."""

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        """Apply self-attention to ``x`` of shape ``(B, N, C)``.

        Args:
            x: Input tokens.
            attn_bias: Optional bias forwarded to xFormers; only supported
                on the xFormers path.

        Raises:
            AssertionError: If ``attn_bias`` is given without xFormers.
        """
        if XFORMERS_AVAILABLE:
            batch, tokens, channels = x.shape
            packed = self.qkv(x).reshape(
                batch, tokens, 3, self.num_heads, channels // self.num_heads
            )
            query, key, value = unbind(packed, 2)

            out = memory_efficient_attention(query, key, value, attn_bias=attn_bias)
            out = out.reshape([batch, tokens, channels])
            return self.proj_drop(self.proj(out))

        # Fallback: the plain implementation cannot honour an attention bias.
        if attn_bias is not None:
            raise AssertionError("xFormers is required for using nested tensors")
        return super().forward(x)
401
+
402
+
403
class CrossAttention(nn.Module):
    """Multi-head cross-attention with separate query/key/value widths."""

    def __init__(
        self,
        dim: int,
        dim_q: int,
        dim_k: int,
        dim_v: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        """Create the projections.

        Args:
            dim: Internal and output channel width.
            dim_q: Channel width of the query input.
            dim_k: Channel width of the key input.
            dim_v: Channel width of the value input.
            num_heads: Number of attention heads.
            qkv_bias: Whether the q/k/v projections have a bias.
            proj_bias: Whether the output projection has a bias.
            attn_drop: Dropout probability on the attention weights.
            proj_drop: Dropout probability on the projected output.
        """
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        # Query scaling factor: 1 / sqrt(head_dim).
        self.scale = (dim // num_heads) ** -0.5

        self.to_q = nn.Linear(dim_q, dim, bias=qkv_bias)
        self.to_k = nn.Linear(dim_k, dim, bias=qkv_bias)
        self.to_v = nn.Linear(dim_v, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Attend from ``q`` ``(B, N, dim_q)`` to ``k``/``v`` ``(B, M, ...)``.

        Returns:
            Tensor of shape ``(B, N, dim)``.
        """
        B, N, _ = q.shape
        M = k.shape[1]
        head_dim = self.dim // self.num_heads

        def split_heads(projected, length):
            # (B, L, dim) -> (B, heads, L, head_dim)
            return projected.reshape(B, length, self.num_heads, head_dim).permute(
                0, 2, 1, 3
            )

        query = self.scale * split_heads(self.to_q(q), N)
        key = split_heads(self.to_k(k), M)
        value = split_heads(self.to_v(v), M)

        scores = query @ key.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))

        merged = (weights @ value).transpose(1, 2).reshape(B, N, -1)
        return self.proj_drop(self.proj(merged))
456
+
457
+
458
class MemEffCrossAttention(CrossAttention):
    """Cross-attention that uses the xFormers kernel when it is available."""

    def forward(self, q: Tensor, k: Tensor, v: Tensor, attn_bias=None) -> Tensor:
        """Attend from ``q`` ``(B, N, dim_q)`` to ``k``/``v`` ``(B, M, ...)``.

        Args:
            q: Query tokens.
            k: Key tokens.
            v: Value tokens.
            attn_bias: Optional bias forwarded to xFormers; only supported
                on the xFormers path.

        Returns:
            Tensor of shape ``(B, N, dim)``.

        Raises:
            AssertionError: If ``attn_bias`` is given without xFormers.
        """
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(q, k, v)

        B, N, _ = q.shape
        M = k.shape[1]
        head_dim = self.dim // self.num_heads

        # The query is deliberately NOT pre-multiplied by self.scale here:
        # memory_efficient_attention already scales the logits by
        # 1 / sqrt(head_dim) by default, which equals self.scale. Scaling
        # the query as well would apply the factor twice and make this path
        # disagree with the fallback in CrossAttention.forward.
        q = self.to_q(q).reshape(B, N, self.num_heads, head_dim)
        k = self.to_k(k).reshape(B, M, self.num_heads, head_dim)
        v = self.to_v(v).reshape(B, M, self.num_heads, head_dim)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape(B, N, -1)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
480
+
481
+
482
+ # =============================================================================
483
+ # End of xFormers
484
+
485
+
486
class MVAttention(nn.Module):
    """Self-attention applied jointly over all views of a feature map.

    The input is ``(B * num_frames, C, H, W)``; the pixels of all
    ``num_frames`` views of one sample form a single token sequence.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        groups: int = 32,
        eps: float = 1e-5,
        residual: bool = True,
        skip_scale: float = 1,
        num_frames: int = 4,
    ):
        """Create the group norm and the attention layer.

        Args:
            dim: Channel count of the feature map.
            num_heads, qkv_bias, proj_bias, attn_drop, proj_drop: Forwarded
                to ``MemEffAttention``.
            groups: Number of groups for the group norm.
            eps: Epsilon for the group norm.
            residual: Whether to add the input back onto the output.
            skip_scale: Factor applied to the residual sum.
            num_frames: Number of views folded into the batch dimension.
        """
        super().__init__()

        self.residual = residual
        self.skip_scale = skip_scale
        self.num_frames = num_frames

        self.norm = nn.GroupNorm(groups, dim, eps=eps, affine=True)
        self.attn = MemEffAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )

    def forward(self, x):
        """Attend across views; returns a tensor shaped like ``x``."""
        BV, C, H, W = x.shape
        frames = self.num_frames
        B = BV // frames

        # (B*V, C, H, W) -> (B, V*H*W, C): one token per pixel per view.
        tokens = (
            self.norm(x)
            .reshape(B, frames, C, H, W)
            .permute(0, 1, 3, 4, 2)
            .reshape(B, -1, C)
        )
        attended = self.attn(tokens)
        # Back to the channel-first, view-folded layout.
        out = (
            attended.reshape(B, frames, H, W, C)
            .permute(0, 1, 4, 2, 3)
            .reshape(BV, C, H, W)
        )

        if not self.residual:
            return out
        return (out + x) * self.skip_scale
536
+
537
+
538
class ResnetBlock(nn.Module):
    """Pre-activation residual block with optional 2x up/down resampling."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        resample: Literal["default", "up", "down"] = "default",
        groups: int = 32,
        eps: float = 1e-5,
        skip_scale: float = 1,
    ):
        """Create the two norm/conv stages, the resampler and the shortcut.

        Args:
            in_channels: Channels of the input.
            out_channels: Channels of the output.
            resample: ``"up"`` for nearest 2x upsampling, ``"down"`` for 2x
                average pooling, anything else for no resampling.
            groups: Number of groups for both group norms.
            eps: Epsilon for both group norms.
            skip_scale: Factor applied to the residual sum.
        """
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.skip_scale = skip_scale

        self.norm1 = nn.GroupNorm(groups, in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

        self.norm2 = nn.GroupNorm(groups, out_channels, eps=eps, affine=True)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

        self.act = F.silu

        if resample == "up":
            self.resample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
        elif resample == "down":
            self.resample = nn.AvgPool2d(kernel_size=2, stride=2)
        else:
            self.resample = None

        # A 1x1 convolution matches channel counts on the skip path.
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Run the block on ``x`` of shape ``(B, in_channels, H, W)``."""
        skip = x
        hidden = self.act(self.norm1(x))
        if self.resample is not None:
            # Resample both branches so they can be summed at the end.
            skip = self.resample(skip)
            hidden = self.resample(hidden)
        hidden = self.conv1(hidden)
        hidden = self.conv2(self.act(self.norm2(hidden)))
        return (hidden + self.shortcut(skip)) * self.skip_scale
595
+
596
+
597
class DownBlock(nn.Module):
    """Encoder stage: residual layers (optionally with attention) and an
    optional stride-2 convolution."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        downsample: bool = True,
        attention: bool = True,
        attention_heads: int = 16,
        skip_scale: float = 1,
    ):
        """Build ``num_layers`` residual layers and the optional downsampler.

        Args:
            in_channels: Channels entering the first layer.
            out_channels: Channels produced by every layer.
            num_layers: Number of residual layers.
            downsample: Whether to append a stride-2 3x3 convolution.
            attention: Whether each layer is followed by ``MVAttention``.
            attention_heads: Head count for the attention layers.
            skip_scale: Residual scale forwarded to the sub-blocks.
        """
        super().__init__()

        nets = []
        attns = []
        for layer_idx in range(num_layers):
            layer_in = in_channels if layer_idx == 0 else out_channels
            nets.append(ResnetBlock(layer_in, out_channels, skip_scale=skip_scale))
            # None keeps the attention list aligned with the layer list.
            attns.append(
                MVAttention(out_channels, attention_heads, skip_scale=skip_scale)
                if attention
                else None
            )
        self.nets = nn.ModuleList(nets)
        self.attns = nn.ModuleList(attns)

        self.downsample = None
        if downsample:
            self.downsample = nn.Conv2d(
                out_channels, out_channels, kernel_size=3, stride=2, padding=1
            )

    def forward(self, x):
        """Return ``(x, skips)``: the stage output and every intermediate
        feature map (including the downsampled one, if any)."""
        skips = []
        for net, attn in zip(self.nets, self.attns):
            x = net(x)
            if attn is not None:
                x = attn(x)
            skips.append(x)
        if self.downsample is not None:
            x = self.downsample(x)
            skips.append(x)
        return x, skips
641
+
642
+
643
class MidBlock(nn.Module):
    """Bottleneck stage: a residual layer followed by ``num_layers``
    (attention, residual layer) pairs."""

    def __init__(
        self,
        in_channels: int,
        num_layers: int = 1,
        attention: bool = True,
        attention_heads: int = 16,
        skip_scale: float = 1,
    ):
        """Build the layers; the channel count is unchanged throughout.

        Args:
            in_channels: Channels of the input and output.
            num_layers: Number of (attention, residual) pairs after the
                first residual layer.
            attention: Whether each pair includes ``MVAttention``.
            attention_heads: Head count for the attention layers.
            skip_scale: Residual scale forwarded to the sub-blocks.
        """
        super().__init__()

        nets = [ResnetBlock(in_channels, in_channels, skip_scale=skip_scale)]
        attns = []
        for _ in range(num_layers):
            nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
            # None keeps the attention list aligned with nets[1:].
            attns.append(
                MVAttention(in_channels, attention_heads, skip_scale=skip_scale)
                if attention
                else None
            )
        self.nets = nn.ModuleList(nets)
        self.attns = nn.ModuleList(attns)

    def forward(self, x):
        """Run the bottleneck; the output has the same shape as ``x``."""
        remaining = iter(self.nets)
        x = next(remaining)(x)
        # Attention (when present) runs before its paired residual layer.
        for attn, net in zip(self.attns, remaining):
            if attn is not None:
                x = attn(x)
            x = net(x)
        return x
675
+
676
+
677
class UpBlock(nn.Module):
    """Decoder stage: residual layers fed with skip connections, optional
    attention, and an optional 2x upsampling."""

    def __init__(
        self,
        in_channels: int,
        prev_out_channels: int,
        out_channels: int,
        num_layers: int = 1,
        upsample: bool = True,
        attention: bool = True,
        attention_heads: int = 16,
        skip_scale: float = 1,
    ):
        """Build the layers.

        Args:
            in_channels: Channels entering the first layer.
            prev_out_channels: Channels of the skip tensor consumed by the
                last layer.
            out_channels: Channels produced by every layer; also the skip
                width assumed for all layers but the last.
            num_layers: Number of residual layers.
            upsample: Whether to finish with nearest 2x upsampling plus a
                3x3 convolution.
            attention: Whether each layer is followed by ``MVAttention``.
            attention_heads: Head count for the attention layers.
            skip_scale: Residual scale forwarded to the sub-blocks.
        """
        super().__init__()

        nets = []
        attns = []
        last_layer = num_layers - 1
        for layer_idx in range(num_layers):
            layer_in = in_channels if layer_idx == 0 else out_channels
            skip_in = prev_out_channels if layer_idx == last_layer else out_channels

            # Each layer sees its input concatenated with one skip tensor.
            nets.append(
                ResnetBlock(layer_in + skip_in, out_channels, skip_scale=skip_scale)
            )
            attns.append(
                MVAttention(out_channels, attention_heads, skip_scale=skip_scale)
                if attention
                else None
            )
        self.nets = nn.ModuleList(nets)
        self.attns = nn.ModuleList(attns)

        self.upsample = None
        if upsample:
            self.upsample = nn.Conv2d(
                out_channels, out_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x, xs):
        """Run the stage.

        Args:
            x: Feature map from the previous stage.
            xs: Sequence of skip tensors; they are consumed from the end,
                one per layer.
        """
        for layer_idx, (attn, net) in enumerate(zip(self.attns, self.nets)):
            skip = xs[-1 - layer_idx]
            x = net(torch.cat([x, skip], dim=1))
            if attn is not None:
                x = attn(x)

        if self.upsample is not None:
            x = F.interpolate(x, scale_factor=2.0, mode="nearest")
            x = self.upsample(x)
        return x
725
+
726
+
727
class UNet(nn.Module):
    """U-Net built from ``DownBlock``, ``MidBlock`` and ``UpBlock`` stages."""

    def __init__(
        self,
        in_channels: int = 9,
        out_channels: int = 14,
        down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024, 1024),
        down_attention: Tuple[bool, ...] = (False, False, False, True, True, True),
        mid_attention: bool = True,
        up_channels: Tuple[int, ...] = (1024, 1024, 512, 256, 128),
        up_attention: Tuple[bool, ...] = (True, True, True, False, False),
        layers_per_block: int = 2,
        skip_scale: float = np.sqrt(0.5),
    ):
        """Assemble the encoder, bottleneck, decoder and output head.

        Args:
            in_channels: Channels of the input image tensor.
            out_channels: Channels of the prediction.
            down_channels: Output width of each encoder stage.
            down_attention: Per encoder stage, whether it uses attention.
            mid_attention: Whether the bottleneck uses attention.
            up_channels: Output width of each decoder stage.
            up_attention: Per decoder stage, whether it uses attention.
            layers_per_block: Residual layers per encoder stage; decoder
                stages get one more.
            skip_scale: Residual scale forwarded to every stage.
        """
        super().__init__()

        self.conv_in = nn.Conv2d(
            in_channels, down_channels[0], kernel_size=3, stride=1, padding=1
        )

        num_down = len(down_channels)
        down_blocks = []
        stage_in = down_channels[0]
        for i, stage_out in enumerate(down_channels):
            down_blocks.append(
                DownBlock(
                    stage_in,
                    stage_out,
                    num_layers=layers_per_block,
                    # Every encoder stage but the last halves the resolution.
                    downsample=(i != num_down - 1),
                    attention=down_attention[i],
                    skip_scale=skip_scale,
                )
            )
            stage_in = stage_out
        self.down_blocks = nn.ModuleList(down_blocks)

        self.mid_block = MidBlock(
            down_channels[-1], attention=mid_attention, skip_scale=skip_scale
        )

        num_up = len(up_channels)
        up_blocks = []
        stage_in = up_channels[0]
        for i, stage_out in enumerate(up_channels):
            # Skip width comes from the encoder, walking backwards from its
            # second-to-last stage and clamped at the first stage.
            skip_width = down_channels[max(-2 - i, -num_down)]

            up_blocks.append(
                UpBlock(
                    stage_in,
                    skip_width,
                    stage_out,
                    num_layers=layers_per_block + 1,
                    # Every decoder stage but the last doubles the resolution.
                    upsample=(i != num_up - 1),
                    attention=up_attention[i],
                    skip_scale=skip_scale,
                )
            )
            stage_in = stage_out
        self.up_blocks = nn.ModuleList(up_blocks)

        self.norm_out = nn.GroupNorm(
            num_channels=up_channels[-1], num_groups=32, eps=1e-5
        )
        self.conv_out = nn.Conv2d(
            up_channels[-1], out_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        """Map ``(B, in_channels, H, W)`` to a ``out_channels`` prediction."""
        x = self.conv_in(x)

        # Stack of skip tensors; the decoder pops them from the end.
        skips = [x]
        for block in self.down_blocks:
            x, block_skips = block(x)
            skips.extend(block_skips)

        x = self.mid_block(x)

        for block in self.up_blocks:
            needed = len(block.nets)
            block_skips = skips[-needed:]
            skips = skips[:-needed]
            x = block(x, block_skips)

        return self.conv_out(F.silu(self.norm_out(x)))
model_index.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "LGMFullPipeline",
3
+ "_diffusers_version": "0.25.0",
4
+ "feature_extractor": ["transformers", "CLIPImageProcessor"],
5
+ "image_encoder": ["transformers", "CLIPVisionModel"],
6
+ "requires_safety_checker": false,
7
+ "scheduler": ["diffusers", "DDIMScheduler"],
8
+ "text_encoder": ["transformers", "CLIPTextModel"],
9
+ "tokenizer": ["transformers", "CLIPTokenizer"],
10
+ "unet": ["mv_unet", "MultiViewUNetModel"],
11
+ "vae": ["diffusers", "AutoencoderKL"],
12
+ "lgm": ["lgm", "LGM"]
13
+ }
pipeline.py ADDED
@@ -0,0 +1,1620 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import math
3
+ from inspect import isfunction
4
+ from typing import Any, Callable, List, Optional, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torchvision.transforms.functional as TF
11
+
12
+ # require xformers!
13
+ import xformers
14
+ import xformers.ops
15
+ from diffusers import AutoencoderKL, DiffusionPipeline
16
+ from diffusers.configuration_utils import ConfigMixin, FrozenDict
17
+ from diffusers.models.modeling_utils import ModelMixin
18
+ from diffusers.schedulers import DDIMScheduler
19
+ from diffusers.utils import (
20
+ deprecate,
21
+ is_accelerate_available,
22
+ is_accelerate_version,
23
+ logging,
24
+ )
25
+ from diffusers.utils.torch_utils import randn_tensor
26
+ from einops import rearrange, repeat
27
+ from kiui.cam import orbit_camera
28
+ from transformers import (
29
+ CLIPImageProcessor,
30
+ CLIPTextModel,
31
+ CLIPTokenizer,
32
+ CLIPVisionModel,
33
+ )
34
+
35
+
36
def get_camera(
    num_frames,
    elevation=15,
    azimuth_start=0,
    azimuth_span=360,
    blender_coord=True,
    extra_view=False,
):
    """
    Build flattened camera poses evenly spaced in azimuth on an orbit.

    :param num_frames: number of orbit views to generate (must be >= 1).
    :param elevation: elevation angle passed (negated) to ``orbit_camera``.
    :param azimuth_start: azimuth of the first view.
    :param azimuth_span: total azimuth range covered; the views are spaced
        ``azimuth_span / num_frames`` apart, the end point is excluded.
    :param blender_coord: if True, convert each pose from the OpenGL to the
        Blender convention (negate row 2, then swap rows 1 and 2).
    :param extra_view: if True, append one all-zero pose after the orbit views.
    :return: a float tensor with one flattened pose per row, i.e.
        ``[num_frames, 16]`` (one more row when ``extra_view`` is set).
    :raises ValueError: if ``num_frames`` is smaller than 1.
    """
    if num_frames < 1:
        raise ValueError(f"num_frames must be a positive integer, got {num_frames}")

    angle_gap = azimuth_span / num_frames
    # Derive the angles from integer indices: np.arange with a fractional step
    # can produce one element too many or too few because of rounding, which
    # would silently change the number of views.
    azimuths = azimuth_start + angle_gap * np.arange(num_frames)

    cameras = []
    for azimuth in azimuths:
        # kiui's elevation is negated; the pose is a [4, 4] matrix
        pose = orbit_camera(-elevation, azimuth, radius=1)

        # opengl to blender
        if blender_coord:
            pose[2] *= -1
            pose[[1, 2]] = pose[[2, 1]]

        cameras.append(pose.flatten())

    if extra_view:
        cameras.append(np.zeros_like(cameras[0]))

    return torch.from_numpy(np.stack(cameras, axis=0)).float()
63
+
64
+
65
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoids and simply repeat each
                        timestep value ``dim`` times.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return repeat(timesteps, "b -> b d", d=dim)

    half = dim // 2
    exponents = torch.arange(start=0, end=half, dtype=torch.float32)
    freqs = torch.exp(-math.log(max_period) * exponents / half)
    freqs = freqs.to(device=timesteps.device)

    args = timesteps[:, None] * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # odd width: cos/sin only fill 2 * half columns, pad the last with zeros
        padding = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, padding], dim=-1)
    return embedding
91
+
92
+
93
def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place and return the module.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
100
+
101
+
102
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.

    Positional and keyword arguments are forwarded unchanged to the matching
    ``nn.ConvNd`` constructor; any other ``dims`` raises ``ValueError``.
    """
    for rank, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
113
+
114
+
115
def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.

    Positional and keyword arguments are forwarded unchanged to the matching
    ``nn.AvgPoolNd`` constructor; any other ``dims`` raises ``ValueError``.
    """
    for rank, pool_cls in ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d)):
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
126
+
127
+
128
def default(val, d):
    """
    Return ``val`` unless it is None; otherwise fall back to ``d``.

    If ``d`` is a plain Python function it is called and its result is used,
    so expensive defaults can be supplied lazily.
    """
    if val is None:
        return d() if isfunction(d) else d
    return val
132
+
133
+
134
class GEGLU(nn.Module):
    """
    GELU-gated linear unit: one linear layer produces both a value half and a
    gate half, and the output is ``value * gelu(gate)``.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # a single projection yields value and gate side by side
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = projected.chunk(2, dim=-1)
        return value * F.gelu(gate)
142
+
143
+
144
class FeedForward(nn.Module):
    """
    Position-wise feed-forward network: input projection (plain GELU or GEGLU),
    dropout, then an output projection.

    :param dim: input feature size.
    :param dim_out: output feature size; falls back to ``dim`` when None.
    :param mult: hidden width multiplier (hidden size is ``int(dim * mult)``).
    :param glu: if True use a ``GEGLU`` input projection, otherwise Linear + GELU.
    :param dropout: dropout probability applied after the input projection.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)

        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out),
        )

    def forward(self, x):
        return self.net(x)
161
+
162
+
163
class MemoryEfficientCrossAttention(nn.Module):
    """
    Multi-head (cross-)attention evaluated with
    ``xformers.ops.memory_efficient_attention``.

    When ``ip_dim > 0`` the last ``ip_dim`` tokens of the context are treated as
    image-prompt tokens: they get their own key/value projections, are attended
    to separately, and the result is added to the regular attention output
    scaled by ``ip_weight``.

    :param query_dim: feature size of the query input (and of the output).
    :param context_dim: feature size of the context; defaults to ``query_dim``.
    :param heads: number of attention heads.
    :param dim_head: feature size per head.
    :param dropout: dropout probability applied after the output projection.
    :param ip_dim: number of trailing image-prompt tokens in the context.
    :param ip_weight: scale of the image-prompt attention branch.
    """

    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    def __init__(
        self,
        query_dim,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        ip_dim=0,
        ip_weight=1,
    ):
        super().__init__()

        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.ip_dim = ip_dim
        self.ip_weight = ip_weight

        if self.ip_dim > 0:
            # dedicated key/value projections for the image-prompt tokens
            self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
            self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout),
        )
        self.attention_op: Optional[Any] = None

    def _split_heads(self, t, batch):
        """Reshape [B, N, heads * dim_head] to a contiguous [B * heads, N, dim_head]."""
        tokens = t.shape[1]
        return (
            t.reshape(batch, tokens, self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(batch * self.heads, tokens, self.dim_head)
            .contiguous()
        )

    def _merge_heads(self, t, batch):
        """Inverse of ``_split_heads``: [B * heads, N, dim_head] to [B, N, heads * dim_head]."""
        tokens = t.shape[1]
        return (
            t.reshape(batch, self.heads, tokens, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(batch, tokens, self.heads * self.dim_head)
        )

    def forward(self, x, context=None):
        q = self.to_q(x)
        # without a context this degenerates to self-attention over x
        context = default(context, x)

        use_ip = self.ip_dim > 0
        if use_ip:
            # the context is [text tokens | ip tokens], e.g. [B, 77 + 16(ip), 1024]
            token_len = context.shape[1]
            context_ip = context[:, -self.ip_dim :, :]
            k_ip = self.to_k_ip(context_ip)
            v_ip = self.to_v_ip(context_ip)
            context = context[:, : (token_len - self.ip_dim), :]

        k = self.to_k(context)
        v = self.to_v(context)

        b = q.shape[0]
        q = self._split_heads(q, b)
        k = self._split_heads(k, b)
        v = self._split_heads(v, b)

        # attention over the regular context tokens
        out = xformers.ops.memory_efficient_attention(
            q, k, v, attn_bias=None, op=self.attention_op
        )

        if use_ip:
            k_ip = self._split_heads(k_ip, b)
            v_ip = self._split_heads(v_ip, b)
            # second attention pass over the image-prompt tokens only
            out_ip = xformers.ops.memory_efficient_attention(
                q, k_ip, v_ip, attn_bias=None, op=self.attention_op
            )
            out = out + self.ip_weight * out_ip

        return self.to_out(self._merge_heads(out, b))
251
+
252
+
253
class BasicTransformerBlock3D(nn.Module):
    """
    Transformer block with cross-view self-attention, per-view cross-attention
    and a feed-forward layer, each wrapped in a pre-LayerNorm residual.

    Self-attention runs over the tokens of all ``num_frames`` views of a sample
    at once; cross-attention and the feed-forward layer run per view.

    :param dim: token feature size.
    :param n_heads: number of attention heads.
    :param d_head: feature size per head.
    :param context_dim: feature size of the cross-attention context.
    :param dropout: dropout probability for the attention and feed-forward layers.
    :param gated_ff: if True the feed-forward layer uses a GEGLU projection.
    :param ip_dim: number of image-prompt tokens (cross-attention only).
    :param ip_weight: scale of the image-prompt attention branch.
    """

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        context_dim,
        dropout=0.0,
        gated_ff=True,
        ip_dim=0,
        ip_weight=1,
    ):
        super().__init__()

        # self-attention: no context dimension
        self.attn1 = MemoryEfficientCrossAttention(
            query_dim=dim,
            context_dim=None,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        # cross-attention: the image prompt only applies here
        self.attn2 = MemoryEfficientCrossAttention(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            ip_dim=ip_dim,
            ip_weight=ip_weight,
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, x, context=None, num_frames=1):
        # fold the views into the token axis so self-attention spans all views
        tokens = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
        tokens = tokens + self.attn1(self.norm1(tokens), context=None)

        # back to one sequence per view for cross-attention and feed-forward
        x = rearrange(tokens, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
        x = x + self.attn2(self.norm2(x), context=context)
        x = x + self.ff(self.norm3(x))
        return x
297
+
298
+
299
class SpatialTransformer3D(nn.Module):
    """
    Transformer applied to image-like feature maps.

    The input ``[B, C, H, W]`` is group-normalised, flattened to ``H * W``
    tokens, projected to the attention width, passed through ``depth``
    ``BasicTransformerBlock3D`` blocks, projected back to ``C`` channels,
    reshaped to an image and added to the input (residual connection).

    :param in_channels: channels of the input (and output) feature map.
    :param n_heads: number of attention heads.
    :param d_head: feature size per head; the attention width is ``n_heads * d_head``.
    :param context_dim: cross-attention context feature size; either a single
        value shared by all blocks or a list with one entry per block.
    :param depth: number of transformer blocks.
    :param dropout: dropout probability inside the blocks.
    :param ip_dim: number of image-prompt tokens appended to the context.
    :param ip_weight: scale of the image-prompt attention branch.
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        context_dim,  # cross attention input dim
        depth=1,
        dropout=0.0,
        ip_dim=0,
        ip_weight=1,
    ):
        super().__init__()

        if not isinstance(context_dim, list):
            context_dim = [context_dim]
        if len(context_dim) == 1:
            # a single context size is shared by every block; without this,
            # depth > 1 fails with an IndexError below
            context_dim = context_dim * depth

        self.in_channels = in_channels

        inner_dim = n_heads * d_head
        self.norm = nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
        self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock3D(
                    inner_dim,
                    n_heads,
                    d_head,
                    context_dim=context_dim[d],
                    dropout=dropout,
                    ip_dim=ip_dim,
                    ip_weight=ip_weight,
                )
                for d in range(depth)
            ]
        )

        # Maps the attention width back to the feature-map channels. The
        # arguments used to be swapped, which only worked while
        # inner_dim == in_channels; in that case the weight shape is unchanged,
        # so existing checkpoints still load.
        self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))

    def forward(self, x, context=None, num_frames=1):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context]
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            # a single context is reused by every block
            block_context = context[i] if len(context) > 1 else context[0]
            x = block(x, context=block_context, num_frames=num_frames)
        x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()

        return x + x_in
357
+
358
+
359
class PerceiverAttention(nn.Module):
    """
    Perceiver-style attention: latent tokens query the concatenation of the
    image features and the latents themselves.

    :param dim: feature size of both inputs and of the output.
    :param dim_head: feature size per attention head.
    :param heads: number of attention heads.
    """

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def _to_heads(self, t, batch):
        """Reshape [b, n, heads * d] to a contiguous [b, heads, n, d]."""
        return (
            t.reshape(batch, t.shape[1], self.heads, -1).transpose(1, 2).contiguous()
        )

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)

        Returns:
            torch.Tensor: updated latent features, shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        batch, num_latents, _ = latents.shape

        q = self.to_q(latents)
        # keys/values cover the image features followed by the latents
        k, v = self.to_kv(torch.cat((x, latents), dim=-2)).chunk(2, dim=-1)

        q = self._to_heads(q, batch)
        k = self._to_heads(k, batch)
        v = self._to_heads(v, batch)

        # attention
        # Scaling q and k separately is more stable with f16 than dividing
        # the product afterwards.
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        logits = (q * scale) @ (k * scale).transpose(-2, -1)
        # softmax in float32, then back to the working dtype
        weight = torch.softmax(logits.float(), dim=-1).type(logits.dtype)
        attended = weight @ v

        merged = attended.permute(0, 2, 1, 3).reshape(batch, num_latents, -1)
        return self.to_out(merged)
410
+
411
+
412
class Resampler(nn.Module):
    """
    Perceiver resampler: a fixed set of learned latent queries repeatedly
    attends to the input features and is returned as ``num_queries`` tokens.

    :param dim: working feature size of the latents.
    :param depth: number of (attention, feed-forward) layers.
    :param dim_head: feature size per attention head.
    :param heads: number of attention heads.
    :param num_queries: number of learned latent tokens (= output tokens).
    :param embedding_dim: feature size of the input features.
    :param output_dim: feature size of the output tokens.
    :param ff_mult: hidden width multiplier of the feed-forward layers.
    """

    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
    ):
        super().__init__()
        # learned queries, initialised with standard deviation 1 / sqrt(dim)
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
        self.proj_in = nn.Linear(embedding_dim, dim)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        hidden_dim = dim * ff_mult
        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        nn.Sequential(
                            nn.LayerNorm(dim),
                            nn.Linear(dim, hidden_dim, bias=False),
                            nn.GELU(),
                            nn.Linear(hidden_dim, dim, bias=False),
                        ),
                    ]
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x):
        batch_size = x.size(0)
        latents = self.latents.repeat(batch_size, 1, 1)
        features = self.proj_in(x)

        for attention, feed_forward in self.layers:
            latents = latents + attention(features, latents)
            latents = latents + feed_forward(latents)

        return self.norm_out(self.proj_out(latents))
455
+
456
+
457
class CondSequential(nn.Sequential):
    """
    A sequential container that routes conditioning to the children that take it:
    ``ResBlock`` layers receive the embedding, ``SpatialTransformer3D`` layers
    receive the context and the frame count, all other layers only get ``x``.
    """

    def forward(self, x, emb, context=None, num_frames=1):
        for layer in self:
            if isinstance(layer, ResBlock):
                x = layer(x, emb)
                continue
            if isinstance(layer, SpatialTransformer3D):
                x = layer(x, context, num_frames=num_frames)
                continue
            x = layer(x)
        return x
472
+
473
+
474
class Upsample(nn.Module):
    """
    A 2x nearest-neighbour upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    :param out_channels: output channels of the convolution; defaults to ``channels``.
    :param padding: padding of the 3-wide convolution.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if not use_conv:
            return
        self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # keep the first spatial axis, double only the last two
            target_size = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            x = F.interpolate(x, target_size, mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x) if self.use_conv else x
505
+
506
+
507
class Downsample(nn.Module):
    """
    A 2x downsampling layer: a strided convolution when ``use_conv`` is set,
    average pooling otherwise.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    :param out_channels: output channels; must equal ``channels`` when pooling.
    :param padding: padding of the 3-wide convolution.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims

        # for 3D signals the first spatial axis keeps its resolution
        stride = (1, 2, 2) if dims == 3 else 2

        if not use_conv:
            # pooling cannot change the channel count
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = conv_nd(
                dims,
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding,
            )

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
539
+
540
+
541
class ResBlock(nn.Module):
    """
    A residual block conditioned on an embedding, which can optionally change
    the number of channels and the spatial resolution.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param use_scale_shift_norm: if True, the embedding is projected to a
        scale and a shift that modulate the normalised features; otherwise it
        is simply added to the features.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        # the same resampling is applied to the hidden path and the skip path
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            identity = nn.Identity()
            self.h_upd = identity
            self.x_upd = identity

        # scale-shift conditioning needs twice the channels (scale and shift)
        emb_out_channels = self.out_channels
        if use_scale_shift_norm:
            emb_out_channels = 2 * self.out_channels
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_channels, emb_out_channels),
        )

        # the last convolution starts at zero, so the block starts as its skip path
        self.out_layers = nn.Sequential(
            nn.GroupNorm(32, self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        else:
            kernel_size, skip_padding = (3, 1) if use_conv else (1, 0)
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, kernel_size, padding=skip_padding
            )

    def forward(self, x, emb):
        if not self.updown:
            h = self.in_layers(x)
        else:
            # resample between the norm/activation and the convolution
            pre_conv, conv = self.in_layers[:-1], self.in_layers[-1]
            h = conv(self.h_upd(pre_conv(x)))
            x = self.x_upd(x)

        emb_out = self.emb_layers(emb).type(h.dtype)
        # append singleton axes so the embedding broadcasts over spatial dims
        missing_axes = max(len(h.shape) - len(emb_out.shape), 0)
        emb_out = emb_out[(...,) + (None,) * missing_axes]

        if not self.use_scale_shift_norm:
            h = self.out_layers(h + emb_out)
        else:
            norm, remainder = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = remainder(norm(h) * (1 + scale) + shift)

        return self.skip_connection(x) + h
639
+
640
+
641
class MultiViewUNetModel(ModelMixin, ConfigMixin):
    """
    The full multi-view UNet model with attention, timestep embedding and camera embedding.

    :param image_size: stored on the model; not otherwise used by this class.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample, either a
        single int or one value per entry of ``channel_mult``.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified, the model is class-conditional: an int
        selects an embedding table with that many classes, ``"continuous"`` a
        linear embedding of a scalar, ``"sequential"`` an MLP over
        ``adm_in_channels`` features.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_head_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param transformer_depth: number of transformer blocks per attention layer.
    :param context_dim: feature size of the cross-attention context (required).
    :param n_embed: if specified, the model predicts codebook ids with this
        many entries instead of the regular output.
    :param num_attention_blocks: optional per-level limit on how many residual
        blocks are followed by attention.
    :param adm_in_channels: input size of the ``"sequential"`` label embedding.
    :param camera_dim: dimensionality of camera input.
    :param ip_dim: number of image-prompt tokens; the imagedream variant uses
        ``ip_dim > 0``.
    :param ip_weight: scale of the image-prompt attention branch.
    :param kwargs: accepted and ignored.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        transformer_depth=1,
        context_dim=None,
        n_embed=None,
        num_attention_blocks=None,
        adm_in_channels=None,
        camera_dim=None,
        ip_dim=0,  # imagedream uses ip_dim > 0
        ip_weight=1.0,
        **kwargs,
    ):
        super().__init__()
        assert context_dim is not None

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        assert (
            num_heads != -1 or num_head_channels != -1
        ), "Either num_heads or num_head_channels has to be set"

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError(
                    "provide num_res_blocks either as an int (globally constant) or "
                    "as a list/tuple (per-level) with the same length as channel_mult"
                )
            self.num_res_blocks = num_res_blocks

        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(
                self.num_res_blocks[i] >= num_attention_blocks[i]
                for i in range(len(num_attention_blocks))
            )
            print(
                f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                f"attention will still not be set."
            )

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        self.ip_dim = ip_dim
        self.ip_weight = ip_weight

        if self.ip_dim > 0:
            # turns image-encoder features into ip_dim context tokens
            self.image_embed = Resampler(
                dim=context_dim,
                depth=4,
                dim_head=64,
                heads=12,
                num_queries=ip_dim,  # num token
                embedding_dim=1280,
                output_dim=context_dim,
                ff_mult=4,
            )

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

        if camera_dim is not None:
            self.camera_embed = nn.Sequential(
                nn.Linear(camera_dim, time_embed_dim),
                nn.SiLU(),
                nn.Linear(time_embed_dim, time_embed_dim),
            )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        nn.Linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        nn.Linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError(f"unsupported num_classes: {self.num_classes!r}")

        def make_res_block(channels, **block_kwargs):
            """ResBlock with the settings shared by every block of this model."""
            return ResBlock(
                channels,
                time_embed_dim,
                dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
                **block_kwargs,
            )

        def make_transformer(channels):
            """Attention layer for ``channels`` features, head layout from the config."""
            if num_head_channels == -1:
                heads, dim_head = num_heads, channels // num_heads
            else:
                heads, dim_head = channels // num_head_channels, num_head_channels
            return SpatialTransformer3D(
                channels,
                heads,
                dim_head,
                context_dim=context_dim,
                depth=transformer_depth,
                ip_dim=self.ip_dim,
                ip_weight=self.ip_weight,
            )

        def wants_attention(level, block_index):
            """Whether the given residual block of ``level`` is followed by attention."""
            if ds not in attention_resolutions:
                return False
            return (
                num_attention_blocks is None
                or block_index < num_attention_blocks[level]
            )

        # ---- encoder -------------------------------------------------------
        self.input_blocks = nn.ModuleList(
            [CondSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]  # skip-connection channel stack
        ch = model_channels
        ds = 1  # current downsample rate
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers: List[Any] = [
                    make_res_block(ch, out_channels=mult * model_channels)
                ]
                ch = mult * model_channels
                if wants_attention(level, nr):
                    layers.append(make_transformer(ch))
                self.input_blocks.append(CondSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    CondSequential(
                        make_res_block(ch, out_channels=out_ch, down=True)
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # ---- bottleneck ----------------------------------------------------
        self.middle_block = CondSequential(
            make_res_block(ch),
            make_transformer(ch),
            make_res_block(ch),
        )
        self._feature_size += ch

        # ---- decoder -------------------------------------------------------
        self.output_blocks = nn.ModuleList([])
        for level, mult in reversed(list(enumerate(channel_mult))):
            for i in range(self.num_res_blocks[level] + 1):
                ich = input_block_chans.pop()
                layers = [
                    make_res_block(ch + ich, out_channels=model_channels * mult)
                ]
                ch = model_channels * mult
                if wants_attention(level, i):
                    layers.append(make_transformer(ch))
                if level and i == self.num_res_blocks[level]:
                    out_ch = ch
                    layers.append(
                        make_res_block(ch, out_channels=out_ch, up=True)
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(CondSequential(*layers))
                self._feature_size += ch

        # The decoder ends with `ch` channels (model_channels * channel_mult[0]).
        # The head convolutions must consume exactly that many; using
        # model_channels here only worked while channel_mult[0] == 1, in which
        # case the layer shapes are unchanged.
        self.out = nn.Sequential(
            nn.GroupNorm(32, ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, ch, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            # produces non-normalized logits (no LogSoftmax)
            self.id_predictor = nn.Sequential(
                nn.GroupNorm(32, ch),
                conv_nd(dims, ch, n_embed, 1),
            )

    def forward(
        self,
        x,
        timesteps=None,
        context=None,
        y=None,
        camera=None,
        num_frames=1,
        ip=None,
        ip_img=None,
        **kwargs,
    ):
        """
        Apply the model to an input batch.

        :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
            NOTE: when ``ip_dim > 0`` this tensor is modified in place.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :param camera: camera input fed to the camera embedding, if given.
        :param num_frames: a integer indicating number of frames for tensor reshaping.
        :param ip: image-prompt features fed to the resampler (used when ``ip_dim > 0``).
        :param ip_img: image-prompt latent written into the last frame of every
            group of ``num_frames`` inputs (used when ``ip_dim > 0``).
        :param kwargs: accepted and ignored.
        :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
        """
        assert (
            x.shape[0] % num_frames == 0
        ), "input batch size must be dividable by num_frames!"
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        t_emb = timestep_embedding(
            timesteps, self.model_channels, repeat_only=False
        ).to(x.dtype)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y is not None
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        # Add camera embeddings
        if camera is not None:
            emb = emb + self.camera_embed(camera)

        # imagedream variant
        if self.ip_dim > 0:
            # overwrite the last view of each group with the image-prompt latent
            # (e.g. rows 4 and 9 for two groups of five views)
            x[(num_frames - 1) :: num_frames, :, :, :] = ip_img
            ip_emb = self.image_embed(ip)
            # image-prompt tokens go after the regular context tokens
            context = torch.cat((context, ip_emb), 1)

        skips = []
        h = x
        for module in self.input_blocks:
            h = module(h, emb, context, num_frames=num_frames)
            skips.append(h)
        h = self.middle_block(h, emb, context, num_frames=num_frames)
        for module in self.output_blocks:
            # concatenate the matching encoder activation (last in, first out)
            h = torch.cat([h, skips.pop()], dim=1)
            h = module(h, emb, context, num_frames=num_frames)
        h = h.type(x.dtype)

        if self.predict_codebook_ids:
            return self.id_predictor(h)
        return self.out(h)
1030
+
1031
+
1032
# Module-level logger; used by LGMFullPipeline._encode_prompt to warn about truncated prompts.
# NOTE(review): `logging.get_logger` is presumably diffusers' logging helper, not the stdlib module - confirm.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
1033
+
1034
+
1035
class LGMFullPipeline(DiffusionPipeline):
    """Multi-view diffusion pipeline whose decoded views are fed to an LGM model.

    ``__call__`` denoises ``num_frames`` view latents with the multi-view UNet
    (optionally conditioned on an input image, the "imagedream variant"),
    decodes them with the VAE and returns the output of ``self.lgm`` applied to
    the first four decoded views.
    """

    _optional_components = ["feature_extractor", "image_encoder"]

    def __init__(
        self,
        vae: AutoencoderKL,
        unet: MultiViewUNetModel,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModel,
        scheduler: DDIMScheduler,
        # imagedream variant
        feature_extractor: CLIPImageProcessor,
        image_encoder: CLIPVisionModel,
        lgm,
        requires_safety_checker: bool = False,
    ):
        """Register all sub-models and patch outdated scheduler configs.

        The scheduler config is rewritten (with a deprecation warning) when
        ``steps_offset != 1`` or ``clip_sample`` is ``True``.
        """
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:  # type: ignore
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "  # type: ignore
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate(
                "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
            )
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:  # type: ignore
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate(
                "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
            )
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        # Statistics used to normalize the decoded views before they enter LGM.
        self.imagenet_default_mean = (0.485, 0.456, 0.406)
        self.imagenet_default_std = (0.229, 0.224, 0.225)

        # NOTE(review): LGM is unconditionally moved to half precision on the
        # current CUDA device, so constructing the pipeline requires CUDA.
        lgm = lgm.half().cuda()

        self.register_modules(
            vae=vae,
            unet=unet,
            scheduler=scheduler,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
            lgm=lgm,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    def save_ply(self, gaussians, path):
        """Write ``gaussians`` to ``path`` by delegating to ``self.lgm.gs.save_ply``."""
        self.lgm.gs.save_ply(gaussians, path)

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding.

        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding.

        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
        several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads the unet, text_encoder and vae to CPU using accelerate's `cpu_offload`, significantly reducing
        memory usage. Offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.

        Raises:
            ImportError: if `accelerate>=0.14.0` is not available.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
            from accelerate import cpu_offload
        else:
            raise ImportError(
                "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher"
            )

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
            cpu_offload(cpu_offloaded_model, device)

    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads the text_encoder, unet and vae to CPU using accelerate, reducing memory usage with a low impact on
        performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the
        GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory
        savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the
        iterative execution of the `unet`.

        Raises:
            ImportError: if `accelerate>=0.17.0.dev0` is not available.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError(
                "`enable_model_offload` requires `accelerate v0.17.0` or higher."
            )

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        # Chain the hooks so that loading one model offloads the previous one.
        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
            _, hook = cpu_offload_with_hook(
                cpu_offloaded_model, device, prev_module_hook=hook
            )

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance: bool,
        negative_prompt=None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Must have the same type (and, for lists,
                the same length) as `prompt`; empty strings are used when it is `None`. Ignored when
                `do_classifier_free_guidance` is `False`.

        Returns:
            The prompt embeddings. When `do_classifier_free_guidance` is `True`, the negative embeddings are
            concatenated in front of the positive ones along the batch dimension.

        Raises:
            ValueError: if `prompt` is neither a string nor a list, or if `negative_prompt` and `prompt` are lists
                of different lengths.
            TypeError: if `negative_prompt` and `prompt` have different types.
        """
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(
                f"`prompt` should be either a string or a list of strings, but got {type(prompt)}."
            )

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(
            prompt, padding="longest", return_tensors="pt"
        ).input_ids

        # Warn the user about the part of the prompt that did not fit.
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if (
            hasattr(self.text_encoder.config, "use_attention_mask")
            and self.text_encoder.config.use_attention_mask
        ):
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(
            bs_embed * num_images_per_prompt, seq_len, -1
        )

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if (
                hasattr(self.text_encoder.config, "use_attention_mask")
                and self.text_encoder.config.use_attention_mask
            ):
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(
                dtype=self.text_encoder.dtype, device=device
            )

            negative_prompt_embeds = negative_prompt_embeds.repeat(
                1, num_images_per_prompt, 1
            )
            negative_prompt_embeds = negative_prompt_embeds.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds

    def decode_latents(self, latents):
        """Decode VAE latents into a channel-last float32 numpy array with values in [0, 1]."""
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def prepare_extra_step_kwargs(self, generator, eta):
        """Return the subset of ``{"eta", "generator"}`` accepted by ``self.scheduler.step``."""
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        step_params = inspect.signature(self.scheduler.step).parameters

        extra_step_kwargs = {}
        if "eta" in step_params:
            extra_step_kwargs["eta"] = eta
        if "generator" in step_params:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        """Sample (or move to ``device``) the initial latents and scale them for the scheduler.

        Raises:
            ValueError: if ``generator`` is a list whose length differs from ``batch_size``.
        """
        shape = (
            batch_size,
            num_channels_latents,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(
                shape, generator=generator, device=device, dtype=dtype
            )
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def encode_image(self, image, device, num_images_per_prompt):
        """Encode ``image`` with the image encoder.

        Returns:
            ``(negative, positive)`` where ``positive`` is the penultimate hidden state of the
            image encoder (repeated ``num_images_per_prompt`` times) and ``negative`` is an
            all-zero tensor of the same shape.
        """
        dtype = next(self.image_encoder.parameters()).dtype

        # float32 input is converted to uint8 before it reaches the feature extractor
        # (presumably the input is expected in [0, 1] - confirm with the caller).
        if image.dtype == np.float32:
            image = (image * 255).astype(np.uint8)

        image = self.feature_extractor(image, return_tensors="pt").pixel_values
        image = image.to(device=device, dtype=dtype)

        image_embeds = self.image_encoder(
            image, output_hidden_states=True
        ).hidden_states[-2]
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)

        return torch.zeros_like(image_embeds), image_embeds

    def encode_image_latents(self, image, device, num_images_per_prompt):
        """Encode ``image`` (a channel-last numpy array) into scaled VAE latents.

        Returns:
            ``(negative, positive)`` where ``positive`` holds the sampled latents (repeated
            ``num_images_per_prompt`` times) and ``negative`` is an all-zero tensor of the
            same shape.
        """
        # The tensor is consumed by the VAE, so it must match the VAE's dtype
        # (the image encoder may be loaded with a different one).
        dtype = self.vae.dtype

        image = (
            torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device)
        )  # [1, 3, H, W]
        image = 2 * image - 1
        image = F.interpolate(image, (256, 256), mode="bilinear", align_corners=False)
        image = image.to(dtype=dtype)

        posterior = self.vae.encode(image).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor  # [B, C, H, W]
        latents = latents.repeat_interleave(num_images_per_prompt, dim=0)

        return torch.zeros_like(latents), latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: str = "",
        image: Optional[np.ndarray] = None,
        height: int = 256,
        width: int = 256,
        elevation: float = 0,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.0,
        negative_prompt: str = "",
        num_images_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "numpy",  # pil, numpy, latents
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        num_frames: int = 4,
        device=torch.device("cuda:0"),
    ):
        """Generate multi-view images and run LGM on them.

        Args:
            prompt: text prompt forwarded to `_encode_prompt`.
            image: optional float32 channel-last array. When given, one extra frame is
                denoised and the UNet is additionally conditioned on the image.
            height, width: size used to derive the latent resolution.
            elevation: camera elevation forwarded to `get_camera`.
            num_inference_steps: number of scheduler steps.
            guidance_scale: classifier-free guidance weight; guidance is disabled when
                it is `<= 1.0`.
            negative_prompt: negative text prompt (only used with guidance).
            num_images_per_prompt: repetition factor for the conditioning tensors.
                NOTE(review): only the first four decoded frames reach LGM and the
                timestep batch does not account for this factor - values other than 1
                look unsupported; confirm before relying on them.
            eta: forwarded to the scheduler step when it accepts `eta`.
            generator: forwarded to latent sampling and the scheduler step.
            output_type: kept for interface compatibility. The views are always decoded
                to float arrays because that is what the LGM stage consumes.
            callback, callback_steps: `callback(i, t, latents)` is invoked on progress
                updates where `i % callback_steps == 0`.
            num_frames: number of generated views (at least 4 are required by the LGM stage).
            device: device on which the diffusion models are run.

        Returns:
            The output of `self.lgm` for the decoded views.
        """
        self.unet = self.unet.to(device=device)
        self.vae = self.vae.to(device=device)
        self.text_encoder = self.text_encoder.to(device=device)

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # the batch is doubled (unconditional + conditional) only under guidance
        multiplier = 2 if do_classifier_free_guidance else 1

        # Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # imagedream variant
        if image is not None:
            assert isinstance(image, np.ndarray) and image.dtype == np.float32
            self.image_encoder = self.image_encoder.to(device=device)
            image_embeds_neg, image_embeds_pos = self.encode_image(
                image, device, num_images_per_prompt
            )
            image_latents_neg, image_latents_pos = self.encode_image_latents(
                image, device, num_images_per_prompt
            )

        prompt_embeds = self._encode_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
        )  # type: ignore

        # Prepare latent variables
        actual_num_frames = num_frames if image is None else num_frames + 1
        latents: torch.Tensor = self.prepare_latents(
            actual_num_frames * num_images_per_prompt,
            4,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            None,
        )

        # Get camera
        camera = get_camera(
            num_frames, elevation=elevation, extra_view=(image is not None)
        ).to(dtype=latents.dtype, device=device)
        camera = camera.repeat_interleave(num_images_per_prompt, dim=0)

        # Conditioning tensors do not change between steps, so build them once.
        # Without guidance `_encode_prompt` returns only the positive embeddings,
        # hence they must not be split into a (negative, positive) pair.
        if do_classifier_free_guidance:
            prompt_embeds_neg, prompt_embeds_pos = prompt_embeds.chunk(2)
            context = torch.cat(
                [prompt_embeds_neg] * actual_num_frames
                + [prompt_embeds_pos] * actual_num_frames
            )
        else:
            context = torch.cat([prompt_embeds] * actual_num_frames)
        camera_input = torch.cat([camera] * multiplier)

        ip = None
        ip_img = None
        if image is not None:
            if do_classifier_free_guidance:
                ip = torch.cat(
                    [image_embeds_neg] * actual_num_frames
                    + [image_embeds_pos] * actual_num_frames
                )
                ip_img = torch.cat(
                    [image_latents_neg] + [image_latents_pos]
                )  # no repeat
            else:
                ip = torch.cat([image_embeds_pos] * actual_num_frames)
                ip_img = image_latents_pos

        # Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * multiplier)
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )

                unet_inputs = {
                    "x": latent_model_input,
                    "timesteps": torch.tensor(
                        [t] * actual_num_frames * multiplier,
                        dtype=latent_model_input.dtype,
                        device=device,
                    ),
                    "context": context,
                    "num_frames": actual_num_frames,
                    "camera": camera_input,
                }

                if image is not None:
                    unet_inputs["ip"] = ip
                    unet_inputs["ip_img"] = ip_img

                # predict the noise residual
                noise_pred = self.unet.forward(**unet_inputs)

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False
                )[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)  # type: ignore

        # Post-processing: the LGM stage below needs decoded float views in [0, 1],
        # so the latents are always decoded regardless of `output_type` (raw latents
        # or PIL images cannot be stacked/normalized correctly).
        views = self.decode_latents(latents)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        # Feed the first four views to LGM in the order [1, 2, 3, 0]; an extra
        # image-conditioning frame (index 4), if present, is not used.
        # NOTE(review): the tensors go to the current CUDA device because LGM was
        # placed there in `__init__`, independently of the `device` argument.
        images = np.stack([views[1], views[2], views[3], views[0]], axis=0)
        images = torch.from_numpy(images).permute(0, 3, 1, 2).float().cuda()
        images = F.interpolate(
            images,
            size=(256, 256),
            mode="bilinear",
            align_corners=False,
        )
        images = TF.normalize(
            images, self.imagenet_default_mean, self.imagenet_default_std
        )

        # NOTE(review): rays are always prepared for elevation 0, whatever the
        # `elevation` argument is - confirm this is intended.
        rays_embeddings = self.lgm.prepare_default_rays("cuda", elevation=0)
        images = torch.cat([images, rays_embeddings], dim=1).unsqueeze(0)
        images = images.half().cuda()

        result = self.lgm(images)
        return result
requirements.txt ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wheel
2
+ numpy
3
+ tyro
4
+ diffusers
5
+ dearpygui
6
+ einops
7
+ accelerate
8
+ gradio
9
+ imageio
10
+ imageio-ffmpeg
11
+ lpips
12
+ matplotlib
13
+ packaging
14
+ Pillow
15
+ pygltflib
16
+ rembg[gpu,cli]
17
+ rich
18
+ safetensors
19
+ scikit-image
20
+ scikit-learn
21
+ scipy
22
+ spaces
23
+ tqdm
24
+ transformers
25
+ trimesh
26
+ kiui >= 0.2.3
27
+ xatlas
28
+ roma
29
+ plyfile
30
+ torch == 2.2.0
31
+ torchvision == 0.17.0
32
+ torchaudio == 2.2.0
33
+ xformers
34
+ ushlex
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DDIMScheduler",
3
+ "_diffusers_version": "0.25.0",
4
+ "beta_end": 0.012,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.00085,
7
+ "clip_sample": false,
8
+ "clip_sample_range": 1.0,
9
+ "dynamic_thresholding_ratio": 0.995,
10
+ "num_train_timesteps": 1000,
11
+ "prediction_type": "epsilon",
12
+ "rescale_betas_zero_snr": false,
13
+ "sample_max_value": 1.0,
14
+ "set_alpha_to_one": false,
15
+ "steps_offset": 1,
16
+ "thresholding": false,
17
+ "timestep_spacing": "leading",
18
+ "trained_betas": null
19
+ }
text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "stabilityai/stable-diffusion-2-1",
3
+ "architectures": [
4
+ "CLIPTextModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dropout": 0.0,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "gelu",
11
+ "hidden_size": 1024,
12
+ "initializer_factor": 1.0,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 4096,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 77,
17
+ "model_type": "clip_text_model",
18
+ "num_attention_heads": 16,
19
+ "num_hidden_layers": 23,
20
+ "pad_token_id": 1,
21
+ "projection_dim": 512,
22
+ "torch_dtype": "float16",
23
+ "transformers_version": "4.35.2",
24
+ "vocab_size": 49408
25
+ }
text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc1827c465450322616f06dea41596eac7d493f4e95904dcb51f0fc745c4e13f
3
+ size 680820392
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "!",
17
+ "unk_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "!",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "49406": {
13
+ "content": "<|startoftext|>",
14
+ "lstrip": false,
15
+ "normalized": true,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "49407": {
21
+ "content": "<|endoftext|>",
22
+ "lstrip": false,
23
+ "normalized": true,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ }
28
+ },
29
+ "bos_token": "<|startoftext|>",
30
+ "clean_up_tokenization_spaces": true,
31
+ "do_lower_case": true,
32
+ "eos_token": "<|endoftext|>",
33
+ "errors": "replace",
34
+ "model_max_length": 77,
35
+ "pad_token": "!",
36
+ "tokenizer_class": "CLIPTokenizer",
37
+ "unk_token": "<|endoftext|>"
38
+ }
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
unet/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "MultiViewUNetModel",
3
+ "_diffusers_version": "0.25.0",
4
+ "attention_resolutions": [
5
+ 4,
6
+ 2,
7
+ 1
8
+ ],
9
+ "camera_dim": 16,
10
+ "channel_mult": [
11
+ 1,
12
+ 2,
13
+ 4,
14
+ 4
15
+ ],
16
+ "context_dim": 1024,
17
+ "image_size": 32,
18
+ "in_channels": 4,
19
+ "ip_dim": 16,
20
+ "model_channels": 320,
21
+ "num_head_channels": 64,
22
+ "num_res_blocks": 2,
23
+ "out_channels": 4,
24
+ "transformer_depth": 1
25
+ }
unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:28d8b241a54125fa0a041c1818a5dcdb717e6f5270eea1268172acd3ab0238e0
3
+ size 1883435904
unet/mv_unet.py ADDED
@@ -0,0 +1,1005 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ from inspect import isfunction
4
+ from typing import Optional, Any, List
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+
11
+ from diffusers.configuration_utils import ConfigMixin
12
+ from diffusers.models.modeling_utils import ModelMixin
13
+
14
+ # require xformers!
15
+ import xformers
16
+ import xformers.ops
17
+
18
+ from kiui.cam import orbit_camera
19
+
20
def get_camera(
    num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False,
):
    """
    Build flattened camera poses evenly spaced in azimuth.

    :param num_frames: number of orbit views to generate.
    :param elevation: elevation passed (negated) to ``orbit_camera``.
    :param azimuth_start: azimuth of the first view.
    :param azimuth_span: total azimuth range divided evenly between the views.
    :param blender_coord: if True, convert each pose from the OpenGL to the Blender convention.
    :param extra_view: if True, append one all-zero pose after the orbit views.
    :return: a float tensor of shape [num_frames (+1 with extra_view), 16].
    """
    angle_gap = azimuth_span / num_frames
    # Derive azimuths from an integer index: np.arange with a fractional step
    # sizes its output as ceil((stop - start) / step), which rounding errors can
    # push to num_frames + 1 elements.
    azimuths = azimuth_start + angle_gap * np.arange(num_frames)

    cameras = []
    for azimuth in azimuths:
        pose = orbit_camera(-elevation, azimuth, radius=1)  # kiui's elevation is negated, [4, 4]

        # opengl to blender
        if blender_coord:
            pose[2] *= -1
            pose[[1, 2]] = pose[[2, 1]]

        cameras.append(pose.flatten())

    if extra_view:
        cameras.append(np.zeros_like(cameras[0]))

    return torch.from_numpy(np.stack(cameras, axis=0)).float()  # [num_frames, 16]
40
+
41
+
42
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Build sinusoidal embeddings for a batch of timesteps.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoids and simply repeat each
                        timestep ``dim`` times.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return repeat(timesteps, "b -> b d", d=dim)

    n_freqs = dim // 2
    exponents = torch.arange(start=0, end=n_freqs, dtype=torch.float32)
    freqs = torch.exp(-math.log(max_period) * exponents / n_freqs)
    freqs = freqs.to(device=timesteps.device)

    phases = timesteps[:, None] * freqs[None]
    table = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
    if dim % 2:
        # odd dim: pad with one zero column so the width equals ``dim``
        zero_col = torch.zeros_like(table[:, :1])
        table = torch.cat([table, zero_col], dim=-1)
    return table
68
+
69
+
70
def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place and hand the same
    module back.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
77
+
78
+
79
def conv_nd(dims, *args, **kwargs):
    """
    Instantiate the convolution layer matching ``dims`` (1, 2 or 3),
    forwarding all remaining arguments to its constructor.

    Raises ValueError for any other ``dims``.
    """
    candidates = ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d))
    for rank, conv_cls in candidates:
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
90
+
91
+
92
def avg_pool_nd(dims, *args, **kwargs):
    """
    Instantiate the average pooling layer matching ``dims`` (1, 2 or 3),
    forwarding all remaining arguments to its constructor.

    Raises ValueError for any other ``dims``.
    """
    candidates = ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d))
    for rank, pool_cls in candidates:
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
103
+
104
+
105
def default(val, d):
    """
    Return ``val`` unless it is None; otherwise fall back to ``d``, calling it
    first when ``d`` is a plain Python function.
    """
    if val is None:
        return d() if isfunction(d) else d
    return val
109
+
110
+
111
class GEGLU(nn.Module):
    """
    GELU-gated linear unit: one linear layer produces a value half and a gate
    half, and the output is ``value * gelu(gate)``.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # single projection yielding both halves along the last dimension
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = torch.chunk(projected, 2, dim=-1)
        return value * F.gelu(gate)
119
+
120
+
121
class FeedForward(nn.Module):
    """
    Position-wise feed-forward network: input projection (plain GELU or GEGLU),
    dropout, then an output projection.

    :param dim: width of the input features.
    :param dim_out: width of the output features; falls back to ``dim``.
    :param mult: the hidden width is ``int(dim * mult)``.
    :param glu: if True, use a GEGLU input projection instead of Linear + GELU.
    :param dropout: dropout probability applied after the input projection.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        hidden_dim = int(dim * mult)
        dim_out = default(dim_out, dim)

        if glu:
            project_in = GEGLU(dim, hidden_dim)
        else:
            project_in = nn.Sequential(nn.Linear(dim, hidden_dim), nn.GELU())

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim_out),
        )

    def forward(self, x):
        return self.net(x)
138
+
139
+
140
class MemoryEfficientCrossAttention(nn.Module):
    """
    Multi-head attention computed with xformers' memory efficient kernel.

    Without a context the layer attends over its own input. When ``ip_dim`` is
    positive, the last ``ip_dim`` context tokens go through separate key/value
    projections and their attention result, scaled by ``ip_weight``, is added
    to the regular result.
    """

    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    def __init__(
        self,
        query_dim,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        ip_dim=0,
        ip_weight=1,
    ):
        super().__init__()

        context_dim = default(context_dim, query_dim)
        inner_dim = heads * dim_head

        self.heads = heads
        self.dim_head = dim_head
        self.ip_dim = ip_dim
        self.ip_weight = ip_weight

        if self.ip_dim > 0:
            # dedicated key/value projections for the trailing ip tokens
            self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
            self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout),
        )
        self.attention_op: Optional[Any] = None

    def _split_heads(self, tensor, batch):
        """Reshape [batch, tokens, heads * dim_head] to [batch * heads, tokens, dim_head]."""
        tokens = tensor.shape[1]
        tensor = tensor.reshape(batch, tokens, self.heads, self.dim_head)
        tensor = tensor.permute(0, 2, 1, 3)
        return tensor.reshape(batch * self.heads, tokens, self.dim_head).contiguous()

    def forward(self, x, context=None):
        q = self.to_q(x)
        context = default(context, x)

        use_ip = self.ip_dim > 0
        if use_ip:
            # context: [B, 77 + 16(ip), 1024]
            token_len = context.shape[1]
            context_ip = context[:, -self.ip_dim :, :]
            k_ip = self.to_k_ip(context_ip)
            v_ip = self.to_v_ip(context_ip)
            # the remaining leading tokens form the regular context
            context = context[:, : (token_len - self.ip_dim), :]

        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape
        q = self._split_heads(q, b)
        k = self._split_heads(k, b)
        v = self._split_heads(v, b)

        # actually compute the attention, what we cannot get enough of
        out = xformers.ops.memory_efficient_attention(
            q, k, v, attn_bias=None, op=self.attention_op
        )

        if use_ip:
            k_ip = self._split_heads(k_ip, b)
            v_ip = self._split_heads(v_ip, b)
            out_ip = xformers.ops.memory_efficient_attention(
                q, k_ip, v_ip, attn_bias=None, op=self.attention_op
            )
            out = out + self.ip_weight * out_ip

        # merge the heads back: [b * heads, tokens, dim_head] -> [b, tokens, heads * dim_head]
        tokens = out.shape[1]
        out = out.reshape(b, self.heads, tokens, self.dim_head)
        out = out.permute(0, 2, 1, 3)
        out = out.reshape(b, tokens, self.heads * self.dim_head)
        return self.to_out(out)
228
+
229
+
230
class BasicTransformerBlock3D(nn.Module):
    """
    Transformer block with three pre-normed residual stages: self-attention
    over the tokens of all frames jointly, per-frame cross-attention with the
    context, and a feed-forward network.
    """

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        context_dim,
        dropout=0.0,
        gated_ff=True,
        ip_dim=0,
        ip_weight=1,
    ):
        super().__init__()

        # context_dim=None turns attn1 into self-attention
        self.attn1 = MemoryEfficientCrossAttention(
            query_dim=dim,
            context_dim=None,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        # ip only applies to cross-attention
        self.attn2 = MemoryEfficientCrossAttention(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            ip_dim=ip_dim,
            ip_weight=ip_weight,
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, x, context=None, num_frames=1):
        # fold the frames into the token axis so self-attention spans all views
        joint = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
        joint = self.attn1(self.norm1(joint), context=None) + joint

        # back to one sequence per frame for cross-attention and feed-forward
        per_frame = rearrange(joint, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
        per_frame = self.attn2(self.norm2(per_frame), context=context) + per_frame
        return self.ff(self.norm3(per_frame)) + per_frame
274
+
275
+
276
class SpatialTransformer3D(nn.Module):
    """
    Transformer stack for image-like feature maps.

    The input is group-normalised, flattened to a token sequence, projected to
    ``n_heads * d_head`` channels, run through ``depth`` transformer blocks,
    projected back to ``in_channels`` and added to the input as a residual.

    :param in_channels: channels of the input (and output) feature map.
    :param n_heads: number of attention heads.
    :param d_head: channels per attention head.
    :param context_dim: cross attention input dim; either a single value used
        for every block or a list with one entry per block.
    :param depth: number of transformer blocks.
    :param dropout: dropout probability forwarded to the blocks.
    :param ip_dim: forwarded to the blocks' cross-attention.
    :param ip_weight: forwarded to the blocks' cross-attention.
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        context_dim,  # cross attention input dim
        depth=1,
        dropout=0.0,
        ip_dim=0,
        ip_weight=1,
    ):
        super().__init__()

        if not isinstance(context_dim, list):
            context_dim = [context_dim]
        if len(context_dim) == 1:
            # A single context width applies to every block; without this
            # broadcast, depth > 1 would index past the end of the list.
            context_dim = context_dim * depth

        self.in_channels = in_channels

        inner_dim = n_heads * d_head
        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock3D(
                    inner_dim,
                    n_heads,
                    d_head,
                    context_dim=context_dim[d],
                    dropout=dropout,
                    ip_dim=ip_dim,
                    ip_weight=ip_weight,
                )
                for d in range(depth)
            ]
        )

        # Maps the block width back to the feature-map width, so the input
        # size is inner_dim and the output size is in_channels. The weight
        # shape is unchanged whenever the two are equal, which is the only
        # configuration the previous (swapped) argument order could run with.
        self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))

    def forward(self, x, context=None, num_frames=1):
        """
        :param x: feature map of shape [b, c, h, w].
        :param context: a single context (shared by all blocks) or a list with
            one context per block.
        :param num_frames: forwarded to every transformer block.
        :return: Tensor with the same shape as ``x``.
        """
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context]
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            # a lone context is reused by every block
            block_context = context[0] if len(context) == 1 else context[i]
            x = block(x, context=block_context, num_frames=num_frames)
        x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()

        return x + x_in
333
+
334
+
335
class PerceiverAttention(nn.Module):
    """
    Attention layer in which latent tokens query the concatenation of the
    image features and the latents themselves.
    """

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        inner_dim = dim_head * heads

        self.scale = dim_head ** -0.5
        self.dim_head = dim_head
        self.heads = heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)
        Returns:
            torch.Tensor of shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        def split_heads(t):
            # (b, n, heads * d) -> (b, heads, n, d)
            n = t.shape[1]
            return t.reshape(b, n, self.heads, -1).transpose(1, 2).contiguous()

        q = split_heads(self.to_q(latents))
        keys_values = self.to_kv(torch.cat((x, latents), dim=-2))
        k, v = (split_heads(t) for t in keys_values.chunk(2, dim=-1))

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        # More stable with f16 than dividing afterwards
        weight = (q * scale) @ (k * scale).transpose(-2, -1)
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
        return self.to_out(out)
384
+
385
+
386
class Resampler(nn.Module):
    """
    Resamples a sequence of input embeddings into ``num_queries`` output tokens
    using learned latents refined by ``depth`` rounds of perceiver attention
    and feed-forward layers.

    :param dim: working width of the latents.
    :param depth: number of attention + feed-forward layers.
    :param dim_head: channels per attention head.
    :param heads: number of attention heads.
    :param num_queries: number of learned latent tokens (output tokens).
    :param embedding_dim: width of the input embeddings.
    :param output_dim: width of the output tokens.
    :param ff_mult: hidden width multiplier of the feed-forward layers.
    """

    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
    ):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
        self.proj_in = nn.Linear(embedding_dim, dim)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        def build_layer():
            # one round = perceiver attention followed by a feed-forward net
            attention = PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads)
            feed_forward = nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * ff_mult, bias=False),
                nn.GELU(),
                nn.Linear(dim * ff_mult, dim, bias=False),
            )
            return nn.ModuleList([attention, feed_forward])

        self.layers = nn.ModuleList([build_layer() for _ in range(depth)])

    def forward(self, x):
        # one copy of the learned latents per batch element
        batch = x.size(0)
        latents = self.latents.repeat(batch, 1, 1)

        x = self.proj_in(x)
        for attention, feed_forward in self.layers:
            latents = attention(x, latents) + latents
            latents = feed_forward(latents) + latents

        return self.norm_out(self.proj_out(latents))
429
+
430
+
431
class CondSequential(nn.Sequential):
    """
    Sequential container that routes conditioning inputs by layer type:
    ResBlock children receive the embedding, SpatialTransformer3D children
    receive the context and frame count, all other children only the features.
    """

    def forward(self, x, emb, context=None, num_frames=1):
        for layer in self:
            if isinstance(layer, ResBlock):
                x = layer(x, emb)
                continue
            if isinstance(layer, SpatialTransformer3D):
                x = layer(x, context, num_frames=num_frames)
                continue
            x = layer(x)
        return x
446
+
447
+
448
class Upsample(nn.Module):
    """
    Nearest-neighbour 2x upsampling, optionally followed by a 3x3 convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    :param out_channels: output channels of the convolution; falls back to
                 ``channels``.
    :param padding: padding of the convolution.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(
                dims, self.channels, self.out_channels, 3, padding=padding
            )

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # keep the first spatial axis, double only the last two
            depth, height, width = x.shape[2], x.shape[3], x.shape[4]
            x = F.interpolate(x, (depth, height * 2, width * 2), mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x) if self.use_conv else x
479
+
480
+
481
class Downsample(nn.Module):
    """
    2x downsampling done either by a strided 3x3 convolution or by average
    pooling.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    :param out_channels: output channels of the convolution; falls back to
                 ``channels`` and must equal it when pooling is used.
    :param padding: padding of the convolution.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims

        # for 3D signals the first spatial axis keeps its size
        stride = (1, 2, 2) if dims == 3 else 2

        if not use_conv:
            # pooling cannot change the channel count
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = conv_nd(
                dims,
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding,
            )

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
513
+
514
+
515
class ResBlock(nn.Module):
    """
    Residual block conditioned on an embedding, optionally changing the channel
    count and/or the spatial resolution.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param use_scale_shift_norm: if True, the embedding is projected to twice
        the output channels and applied as ``norm(h) * (1 + scale) + shift``;
        otherwise it is added to the features.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        # resampling is applied to both the hidden path (h) and the skip path (x)
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            identity = nn.Identity()
            self.h_upd = identity
            self.x_upd = identity

        # scale-shift conditioning needs one scale and one shift per channel
        if use_scale_shift_norm:
            emb_out_features = 2 * self.out_channels
        else:
            emb_out_features = self.out_channels
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_channels, emb_out_features),
        )

        self.out_layers = nn.Sequential(
            nn.GroupNorm(32, self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        if not self.updown:
            h = self.in_layers(x)
        else:
            # resample between the norm/activation and the input convolution
            pre_conv, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = self.h_upd(pre_conv(x))
            x = self.x_upd(x)
            h = in_conv(h)

        emb_out = self.emb_layers(emb).type(h.dtype)
        # append trailing singleton axes so the embedding broadcasts over h
        missing_dims = h.dim() - emb_out.dim()
        if missing_dims > 0:
            emb_out = emb_out[(...,) + (None,) * missing_dims]

        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_rest(out_norm(h) * (1 + scale) + shift)
        else:
            h = self.out_layers(h + emb_out)

        return self.skip_connection(x) + h
613
+
614
+
615
class MultiViewUNetModel(ModelMixin, ConfigMixin):
    """
    The full multi-view UNet model with attention, timestep embedding and camera embedding.

    :param image_size: stored on the model; not otherwise used by this class.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample, either an
        int (same for every level) or a list/tuple with one entry per level.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified, the model is class-conditional: an int
        selects an embedding table with that many classes, "continuous" a
        linear layer on a single value, and "sequential" an MLP on
        ``adm_in_channels`` features.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_head_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param transformer_depth: number of transformer blocks per attention layer.
    :param context_dim: width of the cross-attention context; required.
    :param n_embed: if specified, build an ``id_predictor`` head with this many
        output channels and return its output instead of the regular one.
    :param num_attention_blocks: optional per-level cap on how many residual
        blocks get an attention layer.
    :param adm_in_channels: input width of the "sequential" label embedding.
    :param camera_dim: dimensionality of camera input.
    :param ip_dim: number of image-prompt tokens; when positive, an image
        embedding Resampler is built and used (imagedream variant).
    :param ip_weight: weight of the image-prompt attention branch.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        transformer_depth=1,
        context_dim=None,
        n_embed=None,
        num_attention_blocks=None,
        adm_in_channels=None,
        camera_dim=None,
        ip_dim=0,  # imagedream uses ip_dim > 0
        ip_weight=1.0,
        **kwargs,
    ):
        super().__init__()
        assert context_dim is not None

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert (
                num_head_channels != -1
            ), "Either num_heads or num_head_channels has to be set"

        if num_head_channels == -1:
            assert (
                num_heads != -1
            ), "Either num_heads or num_head_channels has to be set"

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError(
                    "provide num_res_blocks either as an int (globally constant) or "
                    "as a list/tuple (per-level) with the same length as channel_mult"
                )
            self.num_res_blocks = num_res_blocks

        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(
                map(
                    lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
                    range(len(num_attention_blocks)),
                )
            )
            print(
                f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                f"attention will still not be set."
            )

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        self.ip_dim = ip_dim
        self.ip_weight = ip_weight

        if self.ip_dim > 0:
            # turns image features into ip_dim context tokens
            self.image_embed = Resampler(
                dim=context_dim,
                depth=4,
                dim_head=64,
                heads=12,
                num_queries=ip_dim,  # num token
                embedding_dim=1280,
                output_dim=context_dim,
                ff_mult=4,
            )

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

        if camera_dim is not None:
            # camera embedding shares the width of the timestep embedding
            self.camera_embed = nn.Sequential(
                nn.Linear(camera_dim, time_embed_dim),
                nn.SiLU(),
                nn.Linear(time_embed_dim, time_embed_dim),
            )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                # print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        nn.Linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        nn.Linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError(
                    f"unsupported num_classes: {self.num_classes!r} "
                    "(expected an int, 'continuous' or 'sequential')"
                )

        # ---- encoder ----
        self.input_blocks = nn.ModuleList(
            [
                CondSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        # channel count of every encoder output, consumed by the decoder skips
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1  # current downsample rate
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers: List[Any] = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels

                    if num_attention_blocks is None or nr < num_attention_blocks[level]:
                        layers.append(
                            SpatialTransformer3D(
                                ch,
                                num_heads,
                                dim_head,
                                context_dim=context_dim,
                                depth=transformer_depth,
                                ip_dim=self.ip_dim,
                                ip_weight=self.ip_weight,
                            )
                        )
                self.input_blocks.append(CondSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    CondSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # ---- bottleneck ----
        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels

        self.middle_block = CondSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            SpatialTransformer3D(
                ch,
                num_heads,
                dim_head,
                context_dim=context_dim,
                depth=transformer_depth,
                ip_dim=self.ip_dim,
                ip_weight=self.ip_weight,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # ---- decoder ----
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(self.num_res_blocks[level] + 1):
                # each decoder block consumes one encoder skip connection
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels

                    if num_attention_blocks is None or i < num_attention_blocks[level]:
                        layers.append(
                            SpatialTransformer3D(
                                ch,
                                num_heads,
                                dim_head,
                                context_dim=context_dim,
                                depth=transformer_depth,
                                ip_dim=self.ip_dim,
                                ip_weight=self.ip_weight,
                            )
                        )
                if level and i == self.num_res_blocks[level]:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(CondSequential(*layers))
                self._feature_size += ch

        # The heads consume the decoder output, which has `ch` channels
        # (model_channels * channel_mult[0]). Using `ch` for the convolutions
        # keeps them consistent with the GroupNorm in front of them; the
        # shapes are identical to before whenever channel_mult[0] == 1.
        self.out = nn.Sequential(
            nn.GroupNorm(32, ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, ch, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
                nn.GroupNorm(32, ch),
                conv_nd(dims, ch, n_embed, 1),
                # nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
            )

    def forward(
        self,
        x,
        timesteps=None,
        context=None,
        y=None,
        camera=None,
        num_frames=1,
        ip=None,
        ip_img=None,
        **kwargs,
    ):
        """
        Apply the model to an input batch.

        :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :param camera: optional camera input; its embedding is added to the
            timestep embedding.
        :param num_frames: a integer indicating number of frames for tensor reshaping.
        :param ip: image-prompt features fed to the image embedding Resampler
            (used only when the model was built with ip_dim > 0).
        :param ip_img: image-prompt input written into the last frame of every
            group of ``num_frames`` inputs (used only when ip_dim > 0).
        :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
        """
        assert (
            x.shape[0] % num_frames == 0
        ), "input batch size must be dividable by num_frames!"
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        hs = []

        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)

        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y is not None
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        # Add camera embeddings
        if camera is not None:
            emb = emb + self.camera_embed(camera)

        # imagedream variant
        if self.ip_dim > 0:
            # NOTE: this writes into the caller's tensor in place.
            x[(num_frames - 1) :: num_frames, :, :, :] = ip_img  # place at [4, 9]
            ip_emb = self.image_embed(ip)
            # image-prompt tokens are appended after the regular context tokens
            context = torch.cat((context, ip_emb), 1)

        h = x
        for module in self.input_blocks:
            h = module(h, emb, context, num_frames=num_frames)
            hs.append(h)
        h = self.middle_block(h, emb, context, num_frames=num_frames)
        for module in self.output_blocks:
            # concatenate the matching encoder output along the channel axis
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context, num_frames=num_frames)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)
vae/config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.25.0",
4
+ "act_fn": "silu",
5
+ "block_out_channels": [
6
+ 128,
7
+ 256,
8
+ 512,
9
+ 512
10
+ ],
11
+ "down_block_types": [
12
+ "DownEncoderBlock2D",
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D"
16
+ ],
17
+ "force_upcast": true,
18
+ "in_channels": 3,
19
+ "latent_channels": 4,
20
+ "layers_per_block": 2,
21
+ "norm_num_groups": 32,
22
+ "out_channels": 3,
23
+ "sample_size": 256,
24
+ "scaling_factor": 0.18215,
25
+ "up_block_types": [
26
+ "UpDecoderBlock2D",
27
+ "UpDecoderBlock2D",
28
+ "UpDecoderBlock2D",
29
+ "UpDecoderBlock2D"
30
+ ]
31
+ }
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e4c08995484ee61270175e9e7a072b66a6e4eeb5f0c266667fe1f45b90daf9a
3
+ size 167335342