further clean!
- main.py +2 -2
- mvdream/attention.py +13 -85
- mvdream/models.py +12 -176
main.py

@@ -5,8 +5,8 @@ import argparse
 from mvdream.pipeline_mvdream import MVDreamStableDiffusionPipeline
 
 pipe = MVDreamStableDiffusionPipeline.from_pretrained(
-    "ashawkey/mvdream-sd2.1-diffusers",
+    "./weights", # local weights
+    # "ashawkey/mvdream-sd2.1-diffusers",
     torch_dtype=torch.float16
 )
 pipe = pipe.to("cuda")
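Note: the main.py change above points from_pretrained at a local "./weights" directory instead of the Hub repo. A minimal sketch of one way to populate that directory, assuming the local copy should simply mirror the same Hub repo (snapshot_download is standard huggingface_hub usage, not part of this commit):

from huggingface_hub import snapshot_download

# Download the converted diffusers weights once, so main.py can load them
# from "./weights" without hitting the Hub on every run.
snapshot_download(
    repo_id="ashawkey/mvdream-sd2.1-diffusers",
    local_dir="./weights",
)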
mvdream/attention.py

@@ -2,14 +2,14 @@
 
 import math
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
+from torch.amp.autocast_mode import autocast
 
 from inspect import isfunction
-from torch import nn, einsum
-from torch.amp.autocast_mode import autocast
 from einops import rearrange, repeat
 from typing import Optional, Any
-from .util import checkpoint
+from .util import checkpoint, zero_module
 
 try:
     import xformers  # type: ignore
@@ -25,28 +25,12 @@ import os
 _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
 
 
-def uniq(arr):
-    return {el: True for el in arr}.keys()
-
-
 def default(val, d):
     if val is not None:
         return val
     return d() if isfunction(d) else d
 
 
-def max_neg_value(t):
-    return -torch.finfo(t.dtype).max
-
-
-def init_(tensor):
-    dim = tensor.shape[-1]
-    std = 1 / math.sqrt(dim)
-    tensor.uniform_(-std, std)
-    return tensor
-
-
-# feedforward
 class GEGLU(nn.Module):
     def __init__(self, dim_in, dim_out):
         super().__init__()
@@ -76,66 +60,6 @@ class FeedForward(nn.Module):
         return self.net(x)
 
 
-def zero_module(module):
-    """
-    Zero out the parameters of a module and return it.
-    """
-    for p in module.parameters():
-        p.detach().zero_()
-    return module
-
-
-def Normalize(in_channels):
-    return torch.nn.GroupNorm(
-        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
-    )
-
-
-class SpatialSelfAttention(nn.Module):
-    def __init__(self, in_channels):
-        super().__init__()
-        self.in_channels = in_channels
-
-        self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(
-            in_channels, in_channels, kernel_size=1, stride=1, padding=0
-        )
-        self.k = torch.nn.Conv2d(
-            in_channels, in_channels, kernel_size=1, stride=1, padding=0
-        )
-        self.v = torch.nn.Conv2d(
-            in_channels, in_channels, kernel_size=1, stride=1, padding=0
-        )
-        self.proj_out = torch.nn.Conv2d(
-            in_channels, in_channels, kernel_size=1, stride=1, padding=0
-        )
-
-    def forward(self, x):
-        h_ = x
-        h_ = self.norm(h_)
-        q = self.q(h_)
-        k = self.k(h_)
-        v = self.v(h_)
-
-        # compute attention
-        b, c, h, w = q.shape
-        q = rearrange(q, "b c h w -> b (h w) c")
-        k = rearrange(k, "b c h w -> b c (h w)")
-        w_ = torch.einsum("bij,bjk->bik", q, k)
-
-        w_ = w_ * (int(c) ** (-0.5))
-        w_ = torch.nn.functional.softmax(w_, dim=2)
-
-        # attend to values
-        v = rearrange(v, "b c h w -> b c (h w)")
-        w_ = rearrange(w_, "b i j -> b j i")
-        h_ = torch.einsum("bij,bjk->bik", v, w_)
-        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
-        h_ = self.proj_out(h_)
-
-        return x + h_
-
-
 class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
         super().__init__()
@@ -167,9 +91,9 @@ class CrossAttention(nn.Module):
         if _ATTN_PRECISION == "fp32":
             with autocast(enabled=False, device_type="cuda"):
                 q, k = q.float(), k.float()
-                sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
+                sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
         else:
-            sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
+            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
 
         del q, k
 
@@ -182,7 +106,7 @@ class CrossAttention(nn.Module):
         # attention, what we cannot get enough of
         sim = sim.softmax(dim=-1)
 
-        out = einsum("b i j, b j d -> b i d", sim, v)
+        out = torch.einsum("b i j, b j d -> b i d", sim, v)
         out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
         return self.to_out(out)
 
@@ -326,7 +250,9 @@ class SpatialTransformer(nn.Module):
             context_dim = [context_dim]
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
-        self.norm = Normalize(in_channels)
+        self.norm = nn.GroupNorm(
+            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+        )
         if not use_linear:
             self.proj_in = nn.Conv2d(
                 in_channels, inner_dim, kernel_size=1, stride=1, padding=0
@@ -410,7 +336,7 @@ class SpatialTransformer3D(nn.Module):
         dropout=0.0,
         context_dim=None,
         disable_self_attn=False,
-        use_linear=False,
+        use_linear=True,
         use_checkpoint=True,
     ):
         super().__init__()
@@ -419,7 +345,9 @@ class SpatialTransformer3D(nn.Module):
             context_dim = [context_dim]
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
-        self.norm = Normalize(in_channels)
+        self.norm = nn.GroupNorm(
+            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+        )
         if not use_linear:
             self.proj_in = nn.Conv2d(
                 in_channels, inner_dim, kernel_size=1, stride=1, padding=0
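Note: the attention.py hunks above fold the removed Normalize() helper into an inline nn.GroupNorm with the same arguments, and switch the bare einsum import to torch.einsum. A small standalone sketch (not part of the commit) showing that the inlined GroupNorm is the same module the old helper used to construct, applied to a [batch, channels, height, width] feature map:

import torch
import torch.nn as nn

# Same call the removed Normalize(in_channels) helper wrapped.
norm = nn.GroupNorm(num_groups=32, num_channels=320, eps=1e-6, affine=True)

x = torch.randn(2, 320, 32, 32)  # [batch, channels, height, width]
print(norm(x).shape)             # torch.Size([2, 320, 32, 32])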
mvdream/models.py

@@ -1,8 +1,7 @@
 # obtained and modified from https://github.com/bytedance/MVDream
 
 import math
-import numpy as np
-import torch as th
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from diffusers.configuration_utils import ConfigMixin
@@ -223,7 +222,7 @@ class ResBlock(TimestepBlock):
             emb_out = emb_out[..., None]
         if self.use_scale_shift_norm:
             out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
-            scale, shift = th.chunk(emb_out, 2, dim=1)
+            scale, shift = torch.chunk(emb_out, 2, dim=1)
             h = out_norm(h) * (1 + scale) + shift
             h = out_rest(h)
         else:
@@ -232,112 +231,6 @@ class ResBlock(TimestepBlock):
         return self.skip_connection(x) + h
 
 
-class AttentionBlock(nn.Module):
-    """
-    An attention block that allows spatial positions to attend to each other.
-    Originally ported from here, but adapted to the N-d case.
-    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-    """
-
-    def __init__(
-        self,
-        channels,
-        num_heads=1,
-        num_head_channels=-1,
-        use_checkpoint=False,
-        use_new_attention_order=False,
-    ):
-        super().__init__()
-        self.channels = channels
-        if num_head_channels == -1:
-            self.num_heads = num_heads
-        else:
-            assert (
-                channels % num_head_channels == 0
-            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
-            self.num_heads = channels // num_head_channels
-        self.use_checkpoint = use_checkpoint
-        self.norm = nn.GroupNorm(32, channels)
-        self.qkv = conv_nd(1, channels, channels * 3, 1)
-        if use_new_attention_order:
-            # split qkv before split heads
-            self.attention = QKVAttention(self.num_heads)
-        else:
-            # split heads before split qkv
-            self.attention = QKVAttentionLegacy(self.num_heads)
-
-        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
-
-    def forward(self, x):
-        return checkpoint(self._forward, (x,), self.parameters(), True)
-
-    def _forward(self, x):
-        b, c, *spatial = x.shape
-        x = x.reshape(b, c, -1)
-        qkv = self.qkv(self.norm(x))
-        h = self.attention(qkv)
-        h = self.proj_out(h)
-        return (x + h).reshape(b, c, *spatial)
-
-
-class QKVAttentionLegacy(nn.Module):
-    """
-    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
-    """
-
-    def __init__(self, n_heads):
-        super().__init__()
-        self.n_heads = n_heads
-
-    def forward(self, qkv):
-        """
-        Apply QKV attention.
-        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
-        :return: an [N x (H * C) x T] tensor after attention.
-        """
-        bs, width, length = qkv.shape
-        assert width % (3 * self.n_heads) == 0
-        ch = width // (3 * self.n_heads)
-        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
-        scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = th.einsum(
-            "bct,bcs->bts", q * scale, k * scale
-        ) # More stable with f16 than dividing afterwards
-        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = th.einsum("bts,bcs->bct", weight, v)
-        return a.reshape(bs, -1, length)
-
-
-class QKVAttention(nn.Module):
-    """
-    A module which performs QKV attention and splits in a different order.
-    """
-
-    def __init__(self, n_heads):
-        super().__init__()
-        self.n_heads = n_heads
-
-    def forward(self, qkv):
-        """
-        Apply QKV attention.
-        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
-        :return: an [N x (H * C) x T] tensor after attention.
-        """
-        bs, width, length = qkv.shape
-        assert width % (3 * self.n_heads) == 0
-        ch = width // (3 * self.n_heads)
-        q, k, v = qkv.chunk(3, dim=1)
-        scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = th.einsum(
-            "bct,bcs->bts",
-            (q * scale).view(bs * self.n_heads, ch, length),
-            (k * scale).view(bs * self.n_heads, ch, length),
-        ) # More stable with f16 than dividing afterwards
-        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
-        return a.reshape(bs, -1, length)
-
-
 class MultiViewUNetModel(ModelMixin, ConfigMixin):
     """
     The full multi-view UNet model with attention, timestep embedding and camera embedding.
@@ -388,34 +281,18 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
         num_heads_upsample=-1,
         use_scale_shift_norm=False,
         resblock_updown=False,
-        use_new_attention_order=False,
-        use_spatial_transformer=False, # custom transformer support
         transformer_depth=1, # custom transformer support
         context_dim=None, # custom transformer support
         n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
-        legacy=True,
         disable_self_attentions=None,
         num_attention_blocks=None,
         disable_middle_self_attn=False,
-        use_linear_in_transformer=False,
         adm_in_channels=None,
         camera_dim=None,
     ):
         super().__init__()
-        if use_spatial_transformer:
-            assert (
-                context_dim is not None
-            ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
-
-        if context_dim is not None:
-            assert (
-                use_spatial_transformer
-            ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
-            from omegaconf.listconfig import ListConfig
-
-            if type(context_dim) == ListConfig:
-                context_dim = list(context_dim)
-
+        assert context_dim is not None
+
         if num_heads_upsample == -1:
             num_heads_upsample = num_heads
 
@@ -535,13 +412,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                     else:
                         num_heads = ch // num_head_channels
                         dim_head = num_head_channels
-                    if legacy:
-                        # num_heads = 1
-                        dim_head = (
-                            ch // num_heads
-                            if use_spatial_transformer
-                            else num_head_channels
-                        )
+
                     if disable_self_attentions is not None:
                         disabled_sa = disable_self_attentions[level]
                     else:
@@ -549,22 +420,13 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
 
                     if num_attention_blocks is None or nr < num_attention_blocks[level]:
                         layers.append(
-                            AttentionBlock(
-                                ch,
-                                use_checkpoint=use_checkpoint,
-                                num_heads=num_heads,
-                                num_head_channels=dim_head,
-                                use_new_attention_order=use_new_attention_order,
-                            )
-                            if not use_spatial_transformer
-                            else SpatialTransformer3D(
+                            SpatialTransformer3D(
                                 ch,
                                 num_heads,
                                 dim_head,
                                 depth=transformer_depth,
                                 context_dim=context_dim,
                                 disable_self_attn=disabled_sa,
-                                use_linear=use_linear_in_transformer,
                                 use_checkpoint=use_checkpoint,
                             )
                         )
@@ -601,9 +463,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
         else:
             num_heads = ch // num_head_channels
             dim_head = num_head_channels
-        if legacy:
-            # num_heads = 1
-            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+
         self.middle_block = TimestepEmbedSequential(
             ResBlock(
                 ch,
@@ -613,24 +473,15 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
             ),
-            AttentionBlock(
-                ch,
-                use_checkpoint=use_checkpoint,
-                num_heads=num_heads,
-                num_head_channels=dim_head,
-                use_new_attention_order=use_new_attention_order,
-            )
-            if not use_spatial_transformer
-            else SpatialTransformer3D(
+            SpatialTransformer3D(
                 ch,
                 num_heads,
                 dim_head,
                 depth=transformer_depth,
                 context_dim=context_dim,
                 disable_self_attn=disable_middle_self_attn,
-                use_linear=use_linear_in_transformer,
                 use_checkpoint=use_checkpoint,
-            ),
+            ),
             ResBlock(
                 ch,
                 time_embed_dim,
@@ -664,13 +515,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                     else:
                         num_heads = ch // num_head_channels
                         dim_head = num_head_channels
-                    if legacy:
-                        # num_heads = 1
-                        dim_head = (
-                            ch // num_heads
-                            if use_spatial_transformer
-                            else num_head_channels
-                        )
+
                     if disable_self_attentions is not None:
                         disabled_sa = disable_self_attentions[level]
                     else:
@@ -678,22 +523,13 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
 
                     if num_attention_blocks is None or i < num_attention_blocks[level]:
                         layers.append(
-                            AttentionBlock(
-                                ch,
-                                use_checkpoint=use_checkpoint,
-                                num_heads=num_heads_upsample,
-                                num_head_channels=dim_head,
-                                use_new_attention_order=use_new_attention_order,
-                            )
-                            if not use_spatial_transformer
-                            else SpatialTransformer3D(
+                            SpatialTransformer3D(
                                 ch,
                                 num_heads,
                                 dim_head,
                                 depth=transformer_depth,
                                 context_dim=context_dim,
                                 disable_self_attn=disabled_sa,
-                                use_linear=use_linear_in_transformer,
                                 use_checkpoint=use_checkpoint,
                             )
                         )
@@ -777,7 +613,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
             hs.append(h)
         h = self.middle_block(h, emb, context, num_frames=num_frames)
         for module in self.output_blocks:
-            h = th.cat([h, hs.pop()], dim=1)
+            h = torch.cat([h, hs.pop()], dim=1)
             h = module(h, emb, context, num_frames=num_frames)
         h = h.type(x.dtype)
         if self.predict_codebook_ids:
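Note: in the ResBlock hunk above, the scale-shift path splits the projected timestep embedding into a scale and a shift with torch.chunk and applies them after normalization (FiLM-style modulation). A small self-contained sketch of that computation, with made-up shapes purely for illustration:

import torch
import torch.nn as nn

batch, channels, length = 2, 8, 16
h = torch.randn(batch, channels, length)       # feature map (1-D case)
emb_out = torch.randn(batch, 2 * channels, 1)  # embedding projected to 2*channels
out_norm = nn.GroupNorm(4, channels)

scale, shift = torch.chunk(emb_out, 2, dim=1)  # each chunk is [batch, channels, 1]
h = out_norm(h) * (1 + scale) + shift          # scale-shift (FiLM-style) modulation
print(h.shape)                                 # torch.Size([2, 8, 16])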