fbnnb committed
Commit 9b32e18 · verified · 1 Parent(s): 095edeb

Update lvdm/models/autoencoder_dualref.py

Files changed (1)
  1. lvdm/models/autoencoder_dualref.py +1177 -1176
lvdm/models/autoencoder_dualref.py CHANGED
@@ -1,1177 +1,1178 @@
1
- #### https://github.com/Stability-AI/generative-models
2
- from einops import rearrange, repeat
3
- import logging
4
- from typing import Any, Callable, Optional, Iterable, Union
5
-
6
- import numpy as np
7
- import torch
8
- import torch.nn as nn
9
- from packaging import version
10
- logpy = logging.getLogger(__name__)
11
-
12
- try:
13
- import xformers
14
- import xformers.ops
15
-
16
- XFORMERS_IS_AVAILABLE = True
17
- except:
18
- XFORMERS_IS_AVAILABLE = False
19
- logpy.warning("no module 'xformers'. Processing without...")
20
-
21
- from lvdm.modules.attention_svd import LinearAttention, MemoryEfficientCrossAttention
22
-
23
-
24
- def nonlinearity(x):
25
- # swish
26
- return x * torch.sigmoid(x)
27
-
28
-
29
- def Normalize(in_channels, num_groups=32):
30
- return torch.nn.GroupNorm(
31
- num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
32
- )
33
-
34
-
35
- class ResnetBlock(nn.Module):
36
- def __init__(
37
- self,
38
- *,
39
- in_channels,
40
- out_channels=None,
41
- conv_shortcut=False,
42
- dropout,
43
- temb_channels=512,
44
- ):
45
- super().__init__()
46
- self.in_channels = in_channels
47
- out_channels = in_channels if out_channels is None else out_channels
48
- self.out_channels = out_channels
49
- self.use_conv_shortcut = conv_shortcut
50
-
51
- self.norm1 = Normalize(in_channels)
52
- self.conv1 = torch.nn.Conv2d(
53
- in_channels, out_channels, kernel_size=3, stride=1, padding=1
54
- )
55
- if temb_channels > 0:
56
- self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
57
- self.norm2 = Normalize(out_channels)
58
- self.dropout = torch.nn.Dropout(dropout)
59
- self.conv2 = torch.nn.Conv2d(
60
- out_channels, out_channels, kernel_size=3, stride=1, padding=1
61
- )
62
- if self.in_channels != self.out_channels:
63
- if self.use_conv_shortcut:
64
- self.conv_shortcut = torch.nn.Conv2d(
65
- in_channels, out_channels, kernel_size=3, stride=1, padding=1
66
- )
67
- else:
68
- self.nin_shortcut = torch.nn.Conv2d(
69
- in_channels, out_channels, kernel_size=1, stride=1, padding=0
70
- )
71
-
72
- def forward(self, x, temb):
73
- h = x
74
- h = self.norm1(h)
75
- h = nonlinearity(h)
76
- h = self.conv1(h)
77
-
78
- if temb is not None:
79
- h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
80
-
81
- h = self.norm2(h)
82
- h = nonlinearity(h)
83
- h = self.dropout(h)
84
- h = self.conv2(h)
85
-
86
- if self.in_channels != self.out_channels:
87
- if self.use_conv_shortcut:
88
- x = self.conv_shortcut(x)
89
- else:
90
- x = self.nin_shortcut(x)
91
-
92
- return x + h
93
-
94
-
95
- class LinAttnBlock(LinearAttention):
96
- """to match AttnBlock usage"""
97
-
98
- def __init__(self, in_channels):
99
- super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
100
-
101
-
102
- class AttnBlock(nn.Module):
103
- def __init__(self, in_channels):
104
- super().__init__()
105
- self.in_channels = in_channels
106
-
107
- self.norm = Normalize(in_channels)
108
- self.q = torch.nn.Conv2d(
109
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
110
- )
111
- self.k = torch.nn.Conv2d(
112
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
113
- )
114
- self.v = torch.nn.Conv2d(
115
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
116
- )
117
- self.proj_out = torch.nn.Conv2d(
118
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
119
- )
120
-
121
- def attention(self, h_: torch.Tensor) -> torch.Tensor:
122
- h_ = self.norm(h_)
123
- q = self.q(h_)
124
- k = self.k(h_)
125
- v = self.v(h_)
126
-
127
- b, c, h, w = q.shape
128
- q, k, v = map(
129
- lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)
130
- )
131
- h_ = torch.nn.functional.scaled_dot_product_attention(
132
- q, k, v
133
- ) # scale is dim ** -0.5 per default
134
- # compute attention
135
-
136
- return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
137
-
138
- def forward(self, x, **kwargs):
139
- h_ = x
140
- h_ = self.attention(h_)
141
- h_ = self.proj_out(h_)
142
- return x + h_
143
-
144
-
145
- class MemoryEfficientAttnBlock(nn.Module):
146
- """
147
- Uses xformers efficient implementation,
148
- see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
149
- Note: this is a single-head self-attention operation
150
- """
151
-
152
- #
153
- def __init__(self, in_channels):
154
- super().__init__()
155
- self.in_channels = in_channels
156
-
157
- self.norm = Normalize(in_channels)
158
- self.q = torch.nn.Conv2d(
159
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
160
- )
161
- self.k = torch.nn.Conv2d(
162
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
163
- )
164
- self.v = torch.nn.Conv2d(
165
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
166
- )
167
- self.proj_out = torch.nn.Conv2d(
168
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
169
- )
170
- self.attention_op: Optional[Any] = None
171
-
172
- def attention(self, h_: torch.Tensor) -> torch.Tensor:
173
- h_ = self.norm(h_)
174
- q = self.q(h_)
175
- k = self.k(h_)
176
- v = self.v(h_)
177
-
178
- # compute attention
179
- B, C, H, W = q.shape
180
- q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
181
-
182
- q, k, v = map(
183
- lambda t: t.unsqueeze(3)
184
- .reshape(B, t.shape[1], 1, C)
185
- .permute(0, 2, 1, 3)
186
- .reshape(B * 1, t.shape[1], C)
187
- .contiguous(),
188
- (q, k, v),
189
- )
190
- out = xformers.ops.memory_efficient_attention(
191
- q, k, v, attn_bias=None, op=self.attention_op
192
- )
193
-
194
- out = (
195
- out.unsqueeze(0)
196
- .reshape(B, 1, out.shape[1], C)
197
- .permute(0, 2, 1, 3)
198
- .reshape(B, out.shape[1], C)
199
- )
200
- return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
201
-
202
- def forward(self, x, **kwargs):
203
- h_ = x
204
- h_ = self.attention(h_)
205
- h_ = self.proj_out(h_)
206
- return x + h_
207
-
208
-
209
- class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
210
- def forward(self, x, context=None, mask=None, **unused_kwargs):
211
- b, c, h, w = x.shape
212
- x = rearrange(x, "b c h w -> b (h w) c")
213
- out = super().forward(x, context=context, mask=mask)
214
- out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
215
- return x + out
216
-
217
-
218
- def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
219
- assert attn_type in [
220
- "vanilla",
221
- "vanilla-xformers",
222
- "memory-efficient-cross-attn",
223
- "linear",
224
- "none",
225
- "memory-efficient-cross-attn-fusion",
226
- ], f"attn_type {attn_type} unknown"
227
- if (
228
- version.parse(torch.__version__) < version.parse("2.0.0")
229
- and attn_type != "none"
230
- ):
231
- assert XFORMERS_IS_AVAILABLE, (
232
- f"We do not support vanilla attention in {torch.__version__} anymore, "
233
- f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
234
- )
235
- # attn_type = "vanilla-xformers"
236
- logpy.info(f"making attention of type '{attn_type}' with {in_channels} in_channels")
237
- if attn_type == "vanilla":
238
- assert attn_kwargs is None
239
- return AttnBlock(in_channels)
240
- elif attn_type == "vanilla-xformers":
241
- logpy.info(
242
- f"building MemoryEfficientAttnBlock with {in_channels} in_channels..."
243
- )
244
- return MemoryEfficientAttnBlock(in_channels)
245
- elif attn_type == "memory-efficient-cross-attn":
246
- attn_kwargs["query_dim"] = in_channels
247
- return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
248
- elif attn_type == "memory-efficient-cross-attn-fusion":
249
- attn_kwargs["query_dim"] = in_channels
250
- return MemoryEfficientCrossAttentionWrapperFusion(**attn_kwargs)
251
- elif attn_type == "none":
252
- return nn.Identity(in_channels)
253
- else:
254
- return LinAttnBlock(in_channels)
255
-
256
- class MemoryEfficientCrossAttentionWrapperFusion(MemoryEfficientCrossAttention):
257
- # print('x.shape: ',x.shape, 'context.shape: ',context.shape) ##torch.Size([8, 128, 256, 256]) torch.Size([1, 128, 2, 256, 256])
258
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0, **kwargs):
259
- super().__init__(query_dim, context_dim, heads, dim_head, dropout, **kwargs)
260
- self.norm = Normalize(query_dim)
261
- nn.init.zeros_(self.to_out[0].weight)
262
- nn.init.zeros_(self.to_out[0].bias)
263
-
264
- def forward(self, x, context=None, mask=None):
265
- if self.training:
266
- return checkpoint(self._forward, x, context, mask, use_reentrant=False)
267
- else:
268
- return self._forward(x, context, mask)
269
-
270
- def _forward(
271
- self,
272
- x,
273
- context=None,
274
- mask=None,
275
- ):
276
- bt, c, h, w = x.shape
277
- h_ = self.norm(x)
278
- h_ = rearrange(h_, "b c h w -> b (h w) c")
279
- q = self.to_q(h_)
280
-
281
-
282
- b, c, l, h, w = context.shape
283
- context = rearrange(context, "b c l h w -> (b l) (h w) c")
284
- k = self.to_k(context)
285
- v = self.to_v(context)
286
- k = rearrange(k, "(b l) d c -> b l d c", l=l)
287
- k = torch.cat([k[:, [0] * (bt//b)], k[:, [1]*(bt//b)]], dim=2)
288
- k = rearrange(k, "b l d c -> (b l) d c")
289
-
290
- v = rearrange(v, "(b l) d c -> b l d c", l=l)
291
- v = torch.cat([v[:, [0] * (bt//b)], v[:, [1]*(bt//b)]], dim=2)
292
- v = rearrange(v, "b l d c -> (b l) d c")
293
-
294
-
295
- b, _, _ = q.shape ##actually bt
296
- q, k, v = map(
297
- lambda t: t.unsqueeze(3)
298
- .reshape(b, t.shape[1], self.heads, self.dim_head)
299
- .permute(0, 2, 1, 3)
300
- .reshape(b * self.heads, t.shape[1], self.dim_head)
301
- .contiguous(),
302
- (q, k, v),
303
- )
304
-
305
- # actually compute the attention, what we cannot get enough of
306
- if version.parse(xformers.__version__) >= version.parse("0.0.21"):
307
- # NOTE: workaround for
308
- # https://github.com/facebookresearch/xformers/issues/845
309
- max_bs = 32768
310
- N = q.shape[0]
311
- n_batches = math.ceil(N / max_bs)
312
- out = list()
313
- for i_batch in range(n_batches):
314
- batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
315
- out.append(
316
- xformers.ops.memory_efficient_attention(
317
- q[batch],
318
- k[batch],
319
- v[batch],
320
- attn_bias=None,
321
- op=self.attention_op,
322
- )
323
- )
324
- out = torch.cat(out, 0)
325
- else:
326
- out = xformers.ops.memory_efficient_attention(
327
- q, k, v, attn_bias=None, op=self.attention_op
328
- )
329
-
330
- # TODO: Use this directly in the attention operation, as a bias
331
- if exists(mask):
332
- raise NotImplementedError
333
- out = (
334
- out.unsqueeze(0)
335
- .reshape(b, self.heads, out.shape[1], self.dim_head)
336
- .permute(0, 2, 1, 3)
337
- .reshape(b, out.shape[1], self.heads * self.dim_head)
338
- )
339
- out = self.to_out(out)
340
- out = rearrange(out, "bt (h w) c -> bt c h w", h=h, w=w, c=c)
341
- return x + out
342
-
343
- class Combiner(nn.Module):
344
- def __init__(self, ch) -> None:
345
- super().__init__()
346
- self.conv = nn.Conv2d(ch,ch,1,padding=0)
347
-
348
- nn.init.zeros_(self.conv.weight)
349
- nn.init.zeros_(self.conv.bias)
350
-
351
- def forward(self, x, context):
352
- if self.training:
353
- return checkpoint(self._forward, x, context, use_reentrant=False)
354
- else:
355
- return self._forward(x, context)
356
-
357
- def _forward(self, x, context):
358
- ## x: b c h w, context: b c 2 h w
359
- b, c, l, h, w = context.shape
360
- bt, c, h, w = x.shape
361
- context = rearrange(context, "b c l h w -> (b l) c h w")
362
- context = self.conv(context)
363
- context = rearrange(context, "(b l) c h w -> b c l h w", l=l)
364
- x = rearrange(x, "(b t) c h w -> b c t h w", t=bt//b)
365
- x[:,:,0] = x[:,:,0] + context[:,:,0]
366
- x[:,:,-1] = x[:,:,-1] + context[:,:,1]
367
- x = rearrange(x, "b c t h w -> (b t) c h w")
368
- return x
369
-
370
-
371
- class Decoder(nn.Module):
372
- def __init__(
373
- self,
374
- *,
375
- ch,
376
- out_ch,
377
- ch_mult=(1, 2, 4, 8),
378
- num_res_blocks,
379
- attn_resolutions,
380
- dropout=0.0,
381
- resamp_with_conv=True,
382
- in_channels,
383
- resolution,
384
- z_channels,
385
- give_pre_end=False,
386
- tanh_out=False,
387
- use_linear_attn=False,
388
- attn_type="vanilla-xformers",
389
- attn_level=[2,3],
390
- **ignorekwargs,
391
- ):
392
- super().__init__()
393
- if use_linear_attn:
394
- attn_type = "linear"
395
- self.ch = ch
396
- self.temb_ch = 0
397
- self.num_resolutions = len(ch_mult)
398
- self.num_res_blocks = num_res_blocks
399
- self.resolution = resolution
400
- self.in_channels = in_channels
401
- self.give_pre_end = give_pre_end
402
- self.tanh_out = tanh_out
403
- self.attn_level = attn_level
404
- # compute in_ch_mult, block_in and curr_res at lowest res
405
- in_ch_mult = (1,) + tuple(ch_mult)
406
- block_in = ch * ch_mult[self.num_resolutions - 1]
407
- curr_res = resolution // 2 ** (self.num_resolutions - 1)
408
- self.z_shape = (1, z_channels, curr_res, curr_res)
409
- logpy.info(
410
- "Working with z of shape {} = {} dimensions.".format(
411
- self.z_shape, np.prod(self.z_shape)
412
- )
413
- )
414
-
415
- make_attn_cls = self._make_attn()
416
- make_resblock_cls = self._make_resblock()
417
- make_conv_cls = self._make_conv()
418
- # z to block_in
419
- self.conv_in = torch.nn.Conv2d(
420
- z_channels, block_in, kernel_size=3, stride=1, padding=1
421
- )
422
-
423
- # middle
424
- self.mid = nn.Module()
425
- self.mid.block_1 = make_resblock_cls(
426
- in_channels=block_in,
427
- out_channels=block_in,
428
- temb_channels=self.temb_ch,
429
- dropout=dropout,
430
- )
431
- self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
432
- self.mid.block_2 = make_resblock_cls(
433
- in_channels=block_in,
434
- out_channels=block_in,
435
- temb_channels=self.temb_ch,
436
- dropout=dropout,
437
- )
438
-
439
- # upsampling
440
- self.up = nn.ModuleList()
441
- self.attn_refinement = nn.ModuleList()
442
- for i_level in reversed(range(self.num_resolutions)):
443
- block = nn.ModuleList()
444
- attn = nn.ModuleList()
445
- block_out = ch * ch_mult[i_level]
446
- for i_block in range(self.num_res_blocks + 1):
447
- block.append(
448
- make_resblock_cls(
449
- in_channels=block_in,
450
- out_channels=block_out,
451
- temb_channels=self.temb_ch,
452
- dropout=dropout,
453
- )
454
- )
455
- block_in = block_out
456
- if curr_res in attn_resolutions:
457
- attn.append(make_attn_cls(block_in, attn_type=attn_type))
458
- up = nn.Module()
459
- up.block = block
460
- up.attn = attn
461
- if i_level != 0:
462
- up.upsample = Upsample(block_in, resamp_with_conv)
463
- curr_res = curr_res * 2
464
- self.up.insert(0, up) # prepend to get consistent order
465
-
466
- if i_level in self.attn_level:
467
- self.attn_refinement.insert(0, make_attn_cls(block_in, attn_type='memory-efficient-cross-attn-fusion', attn_kwargs={}))
468
- else:
469
- self.attn_refinement.insert(0, Combiner(block_in))
470
- # end
471
- self.norm_out = Normalize(block_in)
472
- self.attn_refinement.append(Combiner(block_in))
473
- self.conv_out = make_conv_cls(
474
- block_in, out_ch, kernel_size=3, stride=1, padding=1
475
- )
476
-
477
- def _make_attn(self) -> Callable:
478
- return make_attn
479
-
480
- def _make_resblock(self) -> Callable:
481
- return ResnetBlock
482
-
483
- def _make_conv(self) -> Callable:
484
- return torch.nn.Conv2d
485
-
486
- def get_last_layer(self, **kwargs):
487
- return self.conv_out.weight
488
-
489
- def forward(self, z, ref_context=None, **kwargs):
490
- ## ref_context: b c 2 h w, 2 means starting and ending frame
491
- # assert z.shape[1:] == self.z_shape[1:]
492
- self.last_z_shape = z.shape
493
- # timestep embedding
494
- temb = None
495
-
496
- # z to block_in
497
- h = self.conv_in(z)
498
-
499
- # middle
500
- h = self.mid.block_1(h, temb, **kwargs)
501
- h = self.mid.attn_1(h, **kwargs)
502
- h = self.mid.block_2(h, temb, **kwargs)
503
-
504
- # upsampling
505
- for i_level in reversed(range(self.num_resolutions)):
506
- for i_block in range(self.num_res_blocks + 1):
507
- h = self.up[i_level].block[i_block](h, temb, **kwargs)
508
- if len(self.up[i_level].attn) > 0:
509
- h = self.up[i_level].attn[i_block](h, **kwargs)
510
- if ref_context:
511
- h = self.attn_refinement[i_level](x=h, context=ref_context[i_level])
512
- if i_level != 0:
513
- h = self.up[i_level].upsample(h)
514
-
515
- # end
516
- if self.give_pre_end:
517
- return h
518
-
519
- h = self.norm_out(h)
520
- h = nonlinearity(h)
521
- if ref_context:
522
- # print(h.shape, ref_context[i_level].shape) #torch.Size([8, 128, 256, 256]) torch.Size([1, 128, 2, 256, 256])
523
- h = self.attn_refinement[-1](x=h, context=ref_context[-1])
524
- h = self.conv_out(h, **kwargs)
525
- if self.tanh_out:
526
- h = torch.tanh(h)
527
- return h
528
-
529
- #####
530
-
531
-
532
- from abc import abstractmethod
533
- from lvdm.models.utils_diffusion import timestep_embedding
534
-
535
- from torch.utils.checkpoint import checkpoint
536
- from lvdm.basics import (
537
- zero_module,
538
- conv_nd,
539
- linear,
540
- normalization,
541
- )
542
- from lvdm.modules.networks.openaimodel3d import Upsample, Downsample
543
- class TimestepBlock(nn.Module):
544
- """
545
- Any module where forward() takes timestep embeddings as a second argument.
546
- """
547
-
548
- @abstractmethod
549
- def forward(self, x: torch.Tensor, emb: torch.Tensor):
550
- """
551
- Apply the module to `x` given `emb` timestep embeddings.
552
- """
553
-
554
- class ResBlock(TimestepBlock):
555
- """
556
- A residual block that can optionally change the number of channels.
557
- :param channels: the number of input channels.
558
- :param emb_channels: the number of timestep embedding channels.
559
- :param dropout: the rate of dropout.
560
- :param out_channels: if specified, the number of out channels.
561
- :param use_conv: if True and out_channels is specified, use a spatial
562
- convolution instead of a smaller 1x1 convolution to change the
563
- channels in the skip connection.
564
- :param dims: determines if the signal is 1D, 2D, or 3D.
565
- :param use_checkpoint: if True, use gradient checkpointing on this module.
566
- :param up: if True, use this block for upsampling.
567
- :param down: if True, use this block for downsampling.
568
- """
569
-
570
- def __init__(
571
- self,
572
- channels: int,
573
- emb_channels: int,
574
- dropout: float,
575
- out_channels: Optional[int] = None,
576
- use_conv: bool = False,
577
- use_scale_shift_norm: bool = False,
578
- dims: int = 2,
579
- use_checkpoint: bool = False,
580
- up: bool = False,
581
- down: bool = False,
582
- kernel_size: int = 3,
583
- exchange_temb_dims: bool = False,
584
- skip_t_emb: bool = False,
585
- ):
586
- super().__init__()
587
- self.channels = channels
588
- self.emb_channels = emb_channels
589
- self.dropout = dropout
590
- self.out_channels = out_channels or channels
591
- self.use_conv = use_conv
592
- self.use_checkpoint = use_checkpoint
593
- self.use_scale_shift_norm = use_scale_shift_norm
594
- self.exchange_temb_dims = exchange_temb_dims
595
-
596
- if isinstance(kernel_size, Iterable):
597
- padding = [k // 2 for k in kernel_size]
598
- else:
599
- padding = kernel_size // 2
600
-
601
- self.in_layers = nn.Sequential(
602
- normalization(channels),
603
- nn.SiLU(),
604
- conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
605
- )
606
-
607
- self.updown = up or down
608
-
609
- if up:
610
- self.h_upd = Upsample(channels, False, dims)
611
- self.x_upd = Upsample(channels, False, dims)
612
- elif down:
613
- self.h_upd = Downsample(channels, False, dims)
614
- self.x_upd = Downsample(channels, False, dims)
615
- else:
616
- self.h_upd = self.x_upd = nn.Identity()
617
-
618
- self.skip_t_emb = skip_t_emb
619
- self.emb_out_channels = (
620
- 2 * self.out_channels if use_scale_shift_norm else self.out_channels
621
- )
622
- if self.skip_t_emb:
623
- # print(f"Skipping timestep embedding in {self.__class__.__name__}")
624
- assert not self.use_scale_shift_norm
625
- self.emb_layers = None
626
- self.exchange_temb_dims = False
627
- else:
628
- self.emb_layers = nn.Sequential(
629
- nn.SiLU(),
630
- linear(
631
- emb_channels,
632
- self.emb_out_channels,
633
- ),
634
- )
635
-
636
- self.out_layers = nn.Sequential(
637
- normalization(self.out_channels),
638
- nn.SiLU(),
639
- nn.Dropout(p=dropout),
640
- zero_module(
641
- conv_nd(
642
- dims,
643
- self.out_channels,
644
- self.out_channels,
645
- kernel_size,
646
- padding=padding,
647
- )
648
- ),
649
- )
650
-
651
- if self.out_channels == channels:
652
- self.skip_connection = nn.Identity()
653
- elif use_conv:
654
- self.skip_connection = conv_nd(
655
- dims, channels, self.out_channels, kernel_size, padding=padding
656
- )
657
- else:
658
- self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
659
-
660
- def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
661
- """
662
- Apply the block to a Tensor, conditioned on a timestep embedding.
663
- :param x: an [N x C x ...] Tensor of features.
664
- :param emb: an [N x emb_channels] Tensor of timestep embeddings.
665
- :return: an [N x C x ...] Tensor of outputs.
666
- """
667
- if self.use_checkpoint:
668
- return checkpoint(self._forward, x, emb, use_reentrant=False)
669
- else:
670
- return self._forward(x, emb)
671
-
672
- def _forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
673
- if self.updown:
674
- in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
675
- h = in_rest(x)
676
- h = self.h_upd(h)
677
- x = self.x_upd(x)
678
- h = in_conv(h)
679
- else:
680
- h = self.in_layers(x)
681
-
682
- if self.skip_t_emb:
683
- emb_out = torch.zeros_like(h)
684
- else:
685
- emb_out = self.emb_layers(emb).type(h.dtype)
686
- while len(emb_out.shape) < len(h.shape):
687
- emb_out = emb_out[..., None]
688
- if self.use_scale_shift_norm:
689
- out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
690
- scale, shift = torch.chunk(emb_out, 2, dim=1)
691
- h = out_norm(h) * (1 + scale) + shift
692
- h = out_rest(h)
693
- else:
694
- if self.exchange_temb_dims:
695
- emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
696
- h = h + emb_out
697
- h = self.out_layers(h)
698
- return self.skip_connection(x) + h
699
- #####
700
-
701
- #####
702
- from lvdm.modules.attention_svd import *
703
- class VideoTransformerBlock(nn.Module):
704
- ATTENTION_MODES = {
705
- "softmax": CrossAttention,
706
- "softmax-xformers": MemoryEfficientCrossAttention,
707
- }
708
-
709
- def __init__(
710
- self,
711
- dim,
712
- n_heads,
713
- d_head,
714
- dropout=0.0,
715
- context_dim=None,
716
- gated_ff=True,
717
- checkpoint=True,
718
- timesteps=None,
719
- ff_in=False,
720
- inner_dim=None,
721
- attn_mode="softmax",
722
- disable_self_attn=False,
723
- disable_temporal_crossattention=False,
724
- switch_temporal_ca_to_sa=False,
725
- ):
726
- super().__init__()
727
-
728
- attn_cls = self.ATTENTION_MODES[attn_mode]
729
-
730
- self.ff_in = ff_in or inner_dim is not None
731
- if inner_dim is None:
732
- inner_dim = dim
733
-
734
- assert int(n_heads * d_head) == inner_dim
735
-
736
- self.is_res = inner_dim == dim
737
-
738
- if self.ff_in:
739
- self.norm_in = nn.LayerNorm(dim)
740
- self.ff_in = FeedForward(
741
- dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff
742
- )
743
-
744
- self.timesteps = timesteps
745
- self.disable_self_attn = disable_self_attn
746
- if self.disable_self_attn:
747
- self.attn1 = attn_cls(
748
- query_dim=inner_dim,
749
- heads=n_heads,
750
- dim_head=d_head,
751
- context_dim=context_dim,
752
- dropout=dropout,
753
- ) # is a cross-attention
754
- else:
755
- self.attn1 = attn_cls(
756
- query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
757
- ) # is a self-attention
758
-
759
- self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff)
760
-
761
- if disable_temporal_crossattention:
762
- if switch_temporal_ca_to_sa:
763
- raise ValueError
764
- else:
765
- self.attn2 = None
766
- else:
767
- self.norm2 = nn.LayerNorm(inner_dim)
768
- if switch_temporal_ca_to_sa:
769
- self.attn2 = attn_cls(
770
- query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
771
- ) # is a self-attention
772
- else:
773
- self.attn2 = attn_cls(
774
- query_dim=inner_dim,
775
- context_dim=context_dim,
776
- heads=n_heads,
777
- dim_head=d_head,
778
- dropout=dropout,
779
- ) # is self-attn if context is none
780
-
781
- self.norm1 = nn.LayerNorm(inner_dim)
782
- self.norm3 = nn.LayerNorm(inner_dim)
783
- self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
784
-
785
- self.checkpoint = checkpoint
786
- if self.checkpoint:
787
- print(f"====>{self.__class__.__name__} is using checkpointing")
788
- else:
789
- print(f"====>{self.__class__.__name__} is NOT using checkpointing")
790
-
791
- def forward(
792
- self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None
793
- ) -> torch.Tensor:
794
- if self.checkpoint:
795
- return checkpoint(self._forward, x, context, timesteps, use_reentrant=False)
796
- else:
797
- return self._forward(x, context, timesteps=timesteps)
798
-
799
- def _forward(self, x, context=None, timesteps=None):
800
- assert self.timesteps or timesteps
801
- assert not (self.timesteps and timesteps) or self.timesteps == timesteps
802
- timesteps = self.timesteps or timesteps
803
- B, S, C = x.shape
804
- x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps)
805
-
806
- if self.ff_in:
807
- x_skip = x
808
- x = self.ff_in(self.norm_in(x))
809
- if self.is_res:
810
- x += x_skip
811
-
812
- if self.disable_self_attn:
813
- x = self.attn1(self.norm1(x), context=context) + x
814
- else:
815
- x = self.attn1(self.norm1(x)) + x
816
-
817
- if self.attn2 is not None:
818
- if self.switch_temporal_ca_to_sa:
819
- x = self.attn2(self.norm2(x)) + x
820
- else:
821
- x = self.attn2(self.norm2(x), context=context) + x
822
- x_skip = x
823
- x = self.ff(self.norm3(x))
824
- if self.is_res:
825
- x += x_skip
826
-
827
- x = rearrange(
828
- x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
829
- )
830
- return x
831
-
832
- def get_last_layer(self):
833
- return self.ff.net[-1].weight
834
-
835
- #####
836
-
837
- #####
838
- import functools
839
- def partialclass(cls, *args, **kwargs):
840
- class NewCls(cls):
841
- __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
842
-
843
- return NewCls
844
- ######
845
-
846
- class VideoResBlock(ResnetBlock):
847
- def __init__(
848
- self,
849
- out_channels,
850
- *args,
851
- dropout=0.0,
852
- video_kernel_size=3,
853
- alpha=0.0,
854
- merge_strategy="learned",
855
- **kwargs,
856
- ):
857
- super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
858
- if video_kernel_size is None:
859
- video_kernel_size = [3, 1, 1]
860
- self.time_stack = ResBlock(
861
- channels=out_channels,
862
- emb_channels=0,
863
- dropout=dropout,
864
- dims=3,
865
- use_scale_shift_norm=False,
866
- use_conv=False,
867
- up=False,
868
- down=False,
869
- kernel_size=video_kernel_size,
870
- use_checkpoint=True,
871
- skip_t_emb=True,
872
- )
873
-
874
- self.merge_strategy = merge_strategy
875
- if self.merge_strategy == "fixed":
876
- self.register_buffer("mix_factor", torch.Tensor([alpha]))
877
- elif self.merge_strategy == "learned":
878
- self.register_parameter(
879
- "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
880
- )
881
- else:
882
- raise ValueError(f"unknown merge strategy {self.merge_strategy}")
883
-
884
- def get_alpha(self, bs):
885
- if self.merge_strategy == "fixed":
886
- return self.mix_factor
887
- elif self.merge_strategy == "learned":
888
- return torch.sigmoid(self.mix_factor)
889
- else:
890
- raise NotImplementedError()
891
-
892
- def forward(self, x, temb, skip_video=False, timesteps=None):
893
- if timesteps is None:
894
- timesteps = self.timesteps
895
-
896
- b, c, h, w = x.shape
897
-
898
- x = super().forward(x, temb)
899
-
900
- if not skip_video:
901
- x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
902
-
903
- x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
904
-
905
- x = self.time_stack(x, temb)
906
-
907
- alpha = self.get_alpha(bs=b // timesteps)
908
- x = alpha * x + (1.0 - alpha) * x_mix
909
-
910
- x = rearrange(x, "b c t h w -> (b t) c h w")
911
- return x
912
-
913
-
914
- class AE3DConv(torch.nn.Conv2d):
915
- def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
916
- super().__init__(in_channels, out_channels, *args, **kwargs)
917
- if isinstance(video_kernel_size, Iterable):
918
- padding = [int(k // 2) for k in video_kernel_size]
919
- else:
920
- padding = int(video_kernel_size // 2)
921
-
922
- self.time_mix_conv = torch.nn.Conv3d(
923
- in_channels=out_channels,
924
- out_channels=out_channels,
925
- kernel_size=video_kernel_size,
926
- padding=padding,
927
- )
928
-
929
- def forward(self, input, timesteps, skip_video=False):
930
- x = super().forward(input)
931
- if skip_video:
932
- return x
933
- x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
934
- x = self.time_mix_conv(x)
935
- return rearrange(x, "b c t h w -> (b t) c h w")
936
-
937
-
938
- class VideoBlock(AttnBlock):
939
- def __init__(
940
- self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
941
- ):
942
- super().__init__(in_channels)
943
- # no context, single headed, as in base class
944
- self.time_mix_block = VideoTransformerBlock(
945
- dim=in_channels,
946
- n_heads=1,
947
- d_head=in_channels,
948
- checkpoint=True,
949
- ff_in=True,
950
- attn_mode="softmax",
951
- )
952
-
953
- time_embed_dim = self.in_channels * 4
954
- self.video_time_embed = torch.nn.Sequential(
955
- torch.nn.Linear(self.in_channels, time_embed_dim),
956
- torch.nn.SiLU(),
957
- torch.nn.Linear(time_embed_dim, self.in_channels),
958
- )
959
-
960
- self.merge_strategy = merge_strategy
961
- if self.merge_strategy == "fixed":
962
- self.register_buffer("mix_factor", torch.Tensor([alpha]))
963
- elif self.merge_strategy == "learned":
964
- self.register_parameter(
965
- "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
966
- )
967
- else:
968
- raise ValueError(f"unknown merge strategy {self.merge_strategy}")
969
-
970
- def forward(self, x, timesteps, skip_video=False):
971
- if skip_video:
972
- return super().forward(x)
973
-
974
- x_in = x
975
- x = self.attention(x)
976
- h, w = x.shape[2:]
977
- x = rearrange(x, "b c h w -> b (h w) c")
978
-
979
- x_mix = x
980
- num_frames = torch.arange(timesteps, device=x.device)
981
- num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
982
- num_frames = rearrange(num_frames, "b t -> (b t)")
983
- t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
984
- emb = self.video_time_embed(t_emb) # b, n_channels
985
- emb = emb[:, None, :]
986
- x_mix = x_mix + emb
987
-
988
- alpha = self.get_alpha()
989
- x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
990
- x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
991
-
992
- x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
993
- x = self.proj_out(x)
994
-
995
- return x_in + x
996
-
997
- def get_alpha(
998
- self,
999
- ):
1000
- if self.merge_strategy == "fixed":
1001
- return self.mix_factor
1002
- elif self.merge_strategy == "learned":
1003
- return torch.sigmoid(self.mix_factor)
1004
- else:
1005
- raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
1006
-
1007
-
1008
- class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock):
1009
- def __init__(
1010
- self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
1011
- ):
1012
- super().__init__(in_channels)
1013
- # no context, single headed, as in base class
1014
- self.time_mix_block = VideoTransformerBlock(
1015
- dim=in_channels,
1016
- n_heads=1,
1017
- d_head=in_channels,
1018
- checkpoint=True,
1019
- ff_in=True,
1020
- attn_mode="softmax-xformers",
1021
- )
1022
-
1023
- time_embed_dim = self.in_channels * 4
1024
- self.video_time_embed = torch.nn.Sequential(
1025
- torch.nn.Linear(self.in_channels, time_embed_dim),
1026
- torch.nn.SiLU(),
1027
- torch.nn.Linear(time_embed_dim, self.in_channels),
1028
- )
1029
-
1030
- self.merge_strategy = merge_strategy
1031
- if self.merge_strategy == "fixed":
1032
- self.register_buffer("mix_factor", torch.Tensor([alpha]))
1033
- elif self.merge_strategy == "learned":
1034
- self.register_parameter(
1035
- "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
1036
- )
1037
- else:
1038
- raise ValueError(f"unknown merge strategy {self.merge_strategy}")
1039
-
1040
- def forward(self, x, timesteps, skip_time_block=False):
1041
- if skip_time_block:
1042
- return super().forward(x)
1043
-
1044
- x_in = x
1045
- x = self.attention(x)
1046
- h, w = x.shape[2:]
1047
- x = rearrange(x, "b c h w -> b (h w) c")
1048
-
1049
- x_mix = x
1050
- num_frames = torch.arange(timesteps, device=x.device)
1051
- num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
1052
- num_frames = rearrange(num_frames, "b t -> (b t)")
1053
- t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
1054
- emb = self.video_time_embed(t_emb) # b, n_channels
1055
- emb = emb[:, None, :]
1056
- x_mix = x_mix + emb
1057
-
1058
- alpha = self.get_alpha()
1059
- x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
1060
- x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
1061
-
1062
- x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
1063
- x = self.proj_out(x)
1064
-
1065
- return x_in + x
1066
-
1067
- def get_alpha(
1068
- self,
1069
- ):
1070
- if self.merge_strategy == "fixed":
1071
- return self.mix_factor
1072
- elif self.merge_strategy == "learned":
1073
- return torch.sigmoid(self.mix_factor)
1074
- else:
1075
- raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
1076
-
1077
-
1078
- def make_time_attn(
1079
- in_channels,
1080
- attn_type="vanilla",
1081
- attn_kwargs=None,
1082
- alpha: float = 0,
1083
- merge_strategy: str = "learned",
1084
- ):
1085
- assert attn_type in [
1086
- "vanilla",
1087
- "vanilla-xformers",
1088
- ], f"attn_type {attn_type} not supported for spatio-temporal attention"
1089
- print(
1090
- f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels"
1091
- )
1092
- if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
1093
- print(
1094
- f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
1095
- f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
1096
- )
1097
- attn_type = "vanilla"
1098
-
1099
- if attn_type == "vanilla":
1100
- assert attn_kwargs is None
1101
- return partialclass(
1102
- VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
1103
- )
1104
- elif attn_type == "vanilla-xformers":
1105
- print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
1106
- return partialclass(
1107
- MemoryEfficientVideoBlock,
1108
- in_channels,
1109
- alpha=alpha,
1110
- merge_strategy=merge_strategy,
1111
- )
1112
- else:
1113
- return NotImplementedError()
1114
-
1115
-
1116
- class Conv2DWrapper(torch.nn.Conv2d):
1117
- def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
1118
- return super().forward(input)
1119
-
1120
-
1121
- class VideoDecoder(Decoder):
1122
- available_time_modes = ["all", "conv-only", "attn-only"]
1123
-
1124
- def __init__(
1125
- self,
1126
- *args,
1127
- video_kernel_size: Union[int, list] = [3,1,1],
1128
- alpha: float = 0.0,
1129
- merge_strategy: str = "learned",
1130
- time_mode: str = "conv-only",
1131
- **kwargs,
1132
- ):
1133
- self.video_kernel_size = video_kernel_size
1134
- self.alpha = alpha
1135
- self.merge_strategy = merge_strategy
1136
- self.time_mode = time_mode
1137
- assert (
1138
- self.time_mode in self.available_time_modes
1139
- ), f"time_mode parameter has to be in {self.available_time_modes}"
1140
- super().__init__(*args, **kwargs)
1141
-
1142
- def get_last_layer(self, skip_time_mix=False, **kwargs):
1143
- if self.time_mode == "attn-only":
1144
- raise NotImplementedError("TODO")
1145
- else:
1146
- return (
1147
- self.conv_out.time_mix_conv.weight
1148
- if not skip_time_mix
1149
- else self.conv_out.weight
1150
- )
1151
-
1152
- def _make_attn(self) -> Callable:
1153
- if self.time_mode not in ["conv-only", "only-last-conv"]:
1154
- return partialclass(
1155
- make_time_attn,
1156
- alpha=self.alpha,
1157
- merge_strategy=self.merge_strategy,
1158
- )
1159
- else:
1160
- return super()._make_attn()
1161
-
1162
- def _make_conv(self) -> Callable:
1163
- if self.time_mode != "attn-only":
1164
- return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
1165
- else:
1166
- return Conv2DWrapper
1167
-
1168
- def _make_resblock(self) -> Callable:
1169
- if self.time_mode not in ["attn-only", "only-last-conv"]:
1170
- return partialclass(
1171
- VideoResBlock,
1172
- video_kernel_size=self.video_kernel_size,
1173
- alpha=self.alpha,
1174
- merge_strategy=self.merge_strategy,
1175
- )
1176
- else:
 
1177
  return super()._make_resblock()
 
1
+ #### https://github.com/Stability-AI/generative-models
2
+ from einops import rearrange, repeat
3
+ import logging
4
+ from typing import Any, Callable, Optional, Iterable, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from packaging import version
10
+ logpy = logging.getLogger(__name__)
11
+
12
+ try:
13
+ import xformers
14
+ import xformers.ops
15
+
16
+ XFORMERS_IS_AVAILABLE = True
17
+ except:
18
+ XFORMERS_IS_AVAILABLE = False
19
+ logpy.warning("no module 'xformers'. Processing without...")
20
+
21
+ from lvdm.modules.attention_svd import LinearAttention, MemoryEfficientCrossAttention
22
+
23
+
24
+ def nonlinearity(x):
25
+ # swish
26
+ return x * torch.sigmoid(x)
27
+
28
+
29
+ def Normalize(in_channels, num_groups=32):
30
+ return torch.nn.GroupNorm(
31
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
32
+ )
33
+
34
+
35
+ class ResnetBlock(nn.Module):
36
+ def __init__(
37
+ self,
38
+ *,
39
+ in_channels,
40
+ out_channels=None,
41
+ conv_shortcut=False,
42
+ dropout,
43
+ temb_channels=512,
44
+ ):
45
+ super().__init__()
46
+ self.in_channels = in_channels
47
+ out_channels = in_channels if out_channels is None else out_channels
48
+ self.out_channels = out_channels
49
+ self.use_conv_shortcut = conv_shortcut
50
+
51
+ self.norm1 = Normalize(in_channels)
52
+ self.conv1 = torch.nn.Conv2d(
53
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
54
+ )
55
+ if temb_channels > 0:
56
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
57
+ self.norm2 = Normalize(out_channels)
58
+ self.dropout = torch.nn.Dropout(dropout)
59
+ self.conv2 = torch.nn.Conv2d(
60
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
61
+ )
62
+ if self.in_channels != self.out_channels:
63
+ if self.use_conv_shortcut:
64
+ self.conv_shortcut = torch.nn.Conv2d(
65
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
66
+ )
67
+ else:
68
+ self.nin_shortcut = torch.nn.Conv2d(
69
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
70
+ )
71
+
72
+ def forward(self, x, temb):
73
+ h = x
74
+ h = self.norm1(h)
75
+ h = nonlinearity(h)
76
+ h = self.conv1(h)
77
+
78
+ if temb is not None:
79
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
80
+
81
+ h = self.norm2(h)
82
+ h = nonlinearity(h)
83
+ h = self.dropout(h)
84
+ h = self.conv2(h)
85
+
86
+ if self.in_channels != self.out_channels:
87
+ if self.use_conv_shortcut:
88
+ x = self.conv_shortcut(x)
89
+ else:
90
+ x = self.nin_shortcut(x)
91
+
92
+ return x + h
93
+
94
+
95
+ class LinAttnBlock(LinearAttention):
96
+ """to match AttnBlock usage"""
97
+
98
+ def __init__(self, in_channels):
99
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
100
+
101
+
102
+ class AttnBlock(nn.Module):
103
+ def __init__(self, in_channels):
104
+ super().__init__()
105
+ self.in_channels = in_channels
106
+
107
+ self.norm = Normalize(in_channels)
108
+ self.q = torch.nn.Conv2d(
109
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
110
+ )
111
+ self.k = torch.nn.Conv2d(
112
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
113
+ )
114
+ self.v = torch.nn.Conv2d(
115
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
116
+ )
117
+ self.proj_out = torch.nn.Conv2d(
118
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
119
+ )
120
+
121
+ def attention(self, h_: torch.Tensor) -> torch.Tensor:
122
+ h_ = self.norm(h_)
123
+ q = self.q(h_)
124
+ k = self.k(h_)
125
+ v = self.v(h_)
126
+
127
+ b, c, h, w = q.shape
128
+ q, k, v = map(
129
+ lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)
130
+ )
131
+ h_ = torch.nn.functional.scaled_dot_product_attention(
132
+ q, k, v
133
+ ) # scale is dim ** -0.5 per default
134
+ # compute attention
135
+
136
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
137
+
138
+ def forward(self, x, **kwargs):
139
+ h_ = x
140
+ h_ = self.attention(h_)
141
+ h_ = self.proj_out(h_)
142
+ return x + h_
143
+
144
+
145
+ class MemoryEfficientAttnBlock(nn.Module):
146
+ """
147
+ Uses xformers efficient implementation,
148
+ see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
149
+ Note: this is a single-head self-attention operation
150
+ """
151
+
152
+ #
153
+ def __init__(self, in_channels):
154
+ super().__init__()
155
+ self.in_channels = in_channels
156
+
157
+ self.norm = Normalize(in_channels)
158
+ self.q = torch.nn.Conv2d(
159
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
160
+ )
161
+ self.k = torch.nn.Conv2d(
162
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
163
+ )
164
+ self.v = torch.nn.Conv2d(
165
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
166
+ )
167
+ self.proj_out = torch.nn.Conv2d(
168
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
169
+ )
170
+ self.attention_op: Optional[Any] = None
171
+
172
+ def attention(self, h_: torch.Tensor) -> torch.Tensor:
173
+ h_ = self.norm(h_)
174
+ q = self.q(h_)
175
+ k = self.k(h_)
176
+ v = self.v(h_)
177
+
178
+ # compute attention
179
+ B, C, H, W = q.shape
180
+ q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
181
+
182
+ q, k, v = map(
183
+ lambda t: t.unsqueeze(3)
184
+ .reshape(B, t.shape[1], 1, C)
185
+ .permute(0, 2, 1, 3)
186
+ .reshape(B * 1, t.shape[1], C)
187
+ .contiguous(),
188
+ (q, k, v),
189
+ )
190
+ out = xformers.ops.memory_efficient_attention(
191
+ q, k, v, attn_bias=None, op=self.attention_op
192
+ )
193
+
194
+ out = (
195
+ out.unsqueeze(0)
196
+ .reshape(B, 1, out.shape[1], C)
197
+ .permute(0, 2, 1, 3)
198
+ .reshape(B, out.shape[1], C)
199
+ )
200
+ return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
201
+
202
+ def forward(self, x, **kwargs):
203
+ h_ = x
204
+ h_ = self.attention(h_)
205
+ h_ = self.proj_out(h_)
206
+ return x + h_
207
+
208
+
209
+ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
210
+ def forward(self, x, context=None, mask=None, **unused_kwargs):
211
+ b, c, h, w = x.shape
212
+ x = rearrange(x, "b c h w -> b (h w) c")
213
+ out = super().forward(x, context=context, mask=mask)
214
+ out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
215
+ return x + out
216
+
217
+
218
+ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
219
+ assert attn_type in [
220
+ "vanilla",
221
+ "vanilla-xformers",
222
+ "memory-efficient-cross-attn",
223
+ "linear",
224
+ "none",
225
+ "memory-efficient-cross-attn-fusion",
226
+ ], f"attn_type {attn_type} unknown"
227
+ if (
228
+ version.parse(torch.__version__) < version.parse("2.0.0")
229
+ and attn_type != "none"
230
+ ):
231
+ assert XFORMERS_IS_AVAILABLE, (
232
+ f"We do not support vanilla attention in {torch.__version__} anymore, "
233
+ f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
234
+ )
235
+ # attn_type = "vanilla-xformers"
236
+ logpy.info(f"making attention of type '{attn_type}' with {in_channels} in_channels")
237
+ if attn_type == "vanilla":
238
+ assert attn_kwargs is None
239
+ return AttnBlock(in_channels)
240
+ elif attn_type == "vanilla-xformers":
241
+ logpy.info(
242
+ f"building MemoryEfficientAttnBlock with {in_channels} in_channels..."
243
+ )
244
+ return MemoryEfficientAttnBlock(in_channels)
245
+ elif attn_type == "memory-efficient-cross-attn":
246
+ attn_kwargs["query_dim"] = in_channels
247
+ return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
248
+ elif attn_type == "memory-efficient-cross-attn-fusion":
249
+ attn_kwargs["query_dim"] = in_channels
250
+ return MemoryEfficientCrossAttentionWrapperFusion(**attn_kwargs)
251
+ elif attn_type == "none":
252
+ return nn.Identity(in_channels)
253
+ else:
254
+ return LinAttnBlock(in_channels)
255
+
256
+ class MemoryEfficientCrossAttentionWrapperFusion(MemoryEfficientCrossAttention):
257
+ # print('x.shape: ',x.shape, 'context.shape: ',context.shape) ##torch.Size([8, 128, 256, 256]) torch.Size([1, 128, 2, 256, 256])
258
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0, **kwargs):
259
+ super().__init__(query_dim, context_dim, heads, dim_head, dropout, **kwargs)
260
+ self.norm = Normalize(query_dim)
261
+ nn.init.zeros_(self.to_out[0].weight)
262
+ nn.init.zeros_(self.to_out[0].bias)
263
+
264
+ def forward(self, x, context=None, mask=None):
265
+ if self.training:
266
+ return checkpoint(self._forward, x, context, mask, use_reentrant=False)
267
+ else:
268
+ return self._forward(x, context, mask)
269
+
270
+ def _forward(
271
+ self,
272
+ x,
273
+ context=None,
274
+ mask=None,
275
+ ):
276
+ bt, c, h, w = x.shape
277
+ h_ = self.norm(x)
278
+ h_ = rearrange(h_, "b c h w -> b (h w) c")
279
+ q = self.to_q(h_)
280
+
281
+
282
+ b, c, l, h, w = context.shape
283
+ context = rearrange(context, "b c l h w -> (b l) (h w) c")
284
+ k = self.to_k(context)
285
+ v = self.to_v(context)
286
+ k = rearrange(k, "(b l) d c -> b l d c", l=l)
287
+ k = torch.cat([k[:, [0] * (bt//b)], k[:, [1]*(bt//b)]], dim=2)
288
+ k = rearrange(k, "b l d c -> (b l) d c")
289
+
290
+ v = rearrange(v, "(b l) d c -> b l d c", l=l)
291
+ v = torch.cat([v[:, [0] * (bt//b)], v[:, [1]*(bt//b)]], dim=2)
292
+ v = rearrange(v, "b l d c -> (b l) d c")
293
+
294
+
295
+ b, _, _ = q.shape ##actually bt
296
+ q, k, v = map(
297
+ lambda t: t.unsqueeze(3)
298
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
299
+ .permute(0, 2, 1, 3)
300
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
301
+ .contiguous(),
302
+ (q, k, v),
303
+ )
304
+
305
+ # actually compute the attention, what we cannot get enough of
306
+ if version.parse(xformers.__version__) >= version.parse("0.0.21"):
307
+ # NOTE: workaround for
308
+ # https://github.com/facebookresearch/xformers/issues/845
309
+ max_bs = 32768
310
+ N = q.shape[0]
311
+ n_batches = math.ceil(N / max_bs)
312
+ out = list()
313
+ for i_batch in range(n_batches):
314
+ batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
315
+ out.append(
316
+ xformers.ops.memory_efficient_attention(
317
+ q[batch],
318
+ k[batch],
319
+ v[batch],
320
+ attn_bias=None,
321
+ op=self.attention_op,
322
+ )
323
+ )
324
+ out = torch.cat(out, 0)
325
+ else:
326
+ out = xformers.ops.memory_efficient_attention(
327
+ q, k, v, attn_bias=None, op=self.attention_op
328
+ )
329
+
330
+ # TODO: Use this directly in the attention operation, as a bias
331
+ if exists(mask):
332
+ raise NotImplementedError
333
+ out = (
334
+ out.unsqueeze(0)
335
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
336
+ .permute(0, 2, 1, 3)
337
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
338
+ )
339
+ out = self.to_out(out)
340
+ out = rearrange(out, "bt (h w) c -> bt c h w", h=h, w=w, c=c)
341
+ return x + out
342
+
343
+ class Combiner(nn.Module):
344
+ def __init__(self, ch) -> None:
345
+ super().__init__()
346
+ self.conv = nn.Conv2d(ch,ch,1,padding=0)
347
+
348
+ nn.init.zeros_(self.conv.weight)
349
+ nn.init.zeros_(self.conv.bias)
350
+
351
+ def forward(self, x, context):
352
+ if self.training:
353
+ return checkpoint(self._forward, x, context, use_reentrant=False)
354
+ else:
355
+ return self._forward(x, context)
356
+
357
+ def _forward(self, x, context):
358
+ ## x: b c h w, context: b c 2 h w
359
+ b, c, l, h, w = context.shape
360
+ bt, c, h, w = x.shape
361
+ context = rearrange(context, "b c l h w -> (b l) c h w")
362
+ context = self.conv(context)
363
+ context = rearrange(context, "(b l) c h w -> b c l h w", l=l)
364
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=bt//b)
365
+ x[:,:,0] = x[:,:,0] + context[:,:,0]
366
+ x[:,:,-1] = x[:,:,-1] + context[:,:,1]
367
+ x = rearrange(x, "b c t h w -> (b t) c h w")
368
+ return x
369
+
370
+
371
+ class Decoder(nn.Module):
372
+ def __init__(
373
+ self,
374
+ *,
375
+ ch,
376
+ out_ch,
377
+ ch_mult=(1, 2, 4, 8),
378
+ num_res_blocks,
379
+ attn_resolutions,
380
+ dropout=0.0,
381
+ resamp_with_conv=True,
382
+ in_channels,
383
+ resolution,
384
+ z_channels,
385
+ give_pre_end=False,
386
+ tanh_out=False,
387
+ use_linear_attn=False,
388
+ attn_type="vanilla-xformers",
389
+ attn_level=[2,3],
390
+ **ignorekwargs,
391
+ ):
392
+ super().__init__()
393
+ if use_linear_attn:
394
+ attn_type = "linear"
395
+ self.ch = ch
396
+ self.temb_ch = 0
397
+ self.num_resolutions = len(ch_mult)
398
+ self.num_res_blocks = num_res_blocks
399
+ self.resolution = resolution
400
+ self.in_channels = in_channels
401
+ self.give_pre_end = give_pre_end
402
+ self.tanh_out = tanh_out
403
+ self.attn_level = attn_level
404
+ # compute in_ch_mult, block_in and curr_res at lowest res
405
+ in_ch_mult = (1,) + tuple(ch_mult)
406
+ block_in = ch * ch_mult[self.num_resolutions - 1]
407
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
408
+ self.z_shape = (1, z_channels, curr_res, curr_res)
409
+ logpy.info(
410
+ "Working with z of shape {} = {} dimensions.".format(
411
+ self.z_shape, np.prod(self.z_shape)
412
+ )
413
+ )
414
+
415
+ make_attn_cls = self._make_attn()
416
+ make_resblock_cls = self._make_resblock()
417
+ make_conv_cls = self._make_conv()
418
+ # z to block_in
419
+ self.conv_in = torch.nn.Conv2d(
420
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
421
+ )
422
+
423
+ # middle
424
+ self.mid = nn.Module()
425
+ self.mid.block_1 = make_resblock_cls(
426
+ in_channels=block_in,
427
+ out_channels=block_in,
428
+ temb_channels=self.temb_ch,
429
+ dropout=dropout,
430
+ )
431
+ self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
432
+ self.mid.block_2 = make_resblock_cls(
433
+ in_channels=block_in,
434
+ out_channels=block_in,
435
+ temb_channels=self.temb_ch,
436
+ dropout=dropout,
437
+ )
438
+
439
+ # upsampling
440
+ self.up = nn.ModuleList()
441
+ self.attn_refinement = nn.ModuleList()
442
+ for i_level in reversed(range(self.num_resolutions)):
443
+ block = nn.ModuleList()
444
+ attn = nn.ModuleList()
445
+ block_out = ch * ch_mult[i_level]
446
+ for i_block in range(self.num_res_blocks + 1):
447
+ block.append(
448
+ make_resblock_cls(
449
+ in_channels=block_in,
450
+ out_channels=block_out,
451
+ temb_channels=self.temb_ch,
452
+ dropout=dropout,
453
+ )
454
+ )
455
+ block_in = block_out
456
+ if curr_res in attn_resolutions:
457
+ attn.append(make_attn_cls(block_in, attn_type=attn_type))
458
+ up = nn.Module()
459
+ up.block = block
460
+ up.attn = attn
461
+ if i_level != 0:
462
+ up.upsample = Upsample(block_in, resamp_with_conv)
463
+ curr_res = curr_res * 2
464
+ self.up.insert(0, up) # prepend to get consistent order
465
+
466
+ if i_level in self.attn_level:
467
+ self.attn_refinement.insert(0, make_attn_cls(block_in, attn_type='memory-efficient-cross-attn-fusion', attn_kwargs={}))
468
+ else:
469
+ self.attn_refinement.insert(0, Combiner(block_in))
470
+ # end
471
+ self.norm_out = Normalize(block_in)
472
+ self.attn_refinement.append(Combiner(block_in))
473
+ self.conv_out = make_conv_cls(
474
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
475
+ )
476
+
477
+ def _make_attn(self) -> Callable:
478
+ return make_attn
479
+
480
+ def _make_resblock(self) -> Callable:
481
+ return ResnetBlock
482
+
483
+ def _make_conv(self) -> Callable:
484
+ return torch.nn.Conv2d
485
+
486
+ def get_last_layer(self, **kwargs):
487
+ return self.conv_out.weight
488
+
489
+ def forward(self, z, ref_context=None, **kwargs):
490
+ ## ref_context: b c 2 h w, 2 means starting and ending frame
491
+ # assert z.shape[1:] == self.z_shape[1:]
492
+ ref_context = None
493
+ self.last_z_shape = z.shape
494
+ # timestep embedding
495
+ temb = None
496
+
497
+ # z to block_in
498
+ h = self.conv_in(z)
499
+
500
+ # middle
501
+ h = self.mid.block_1(h, temb, **kwargs)
502
+ h = self.mid.attn_1(h, **kwargs)
503
+ h = self.mid.block_2(h, temb, **kwargs)
504
+
505
+ # upsampling
506
+ for i_level in reversed(range(self.num_resolutions)):
507
+ for i_block in range(self.num_res_blocks + 1):
508
+ h = self.up[i_level].block[i_block](h, temb, **kwargs)
509
+ if len(self.up[i_level].attn) > 0:
510
+ h = self.up[i_level].attn[i_block](h, **kwargs)
511
+ if ref_context:
512
+ h = self.attn_refinement[i_level](x=h, context=ref_context[i_level])
513
+ if i_level != 0:
514
+ h = self.up[i_level].upsample(h)
515
+
516
+ # end
517
+ if self.give_pre_end:
518
+ return h
519
+
520
+ h = self.norm_out(h)
521
+ h = nonlinearity(h)
522
+ if ref_context:
523
+ # print(h.shape, ref_context[i_level].shape) #torch.Size([8, 128, 256, 256]) torch.Size([1, 128, 2, 256, 256])
524
+ h = self.attn_refinement[-1](x=h, context=ref_context[-1])
525
+ h = self.conv_out(h, **kwargs)
526
+ if self.tanh_out:
527
+ h = torch.tanh(h)
528
+ return h
529
+
530
+ #####
531
+
532
+
533
+ from abc import abstractmethod
534
+ from lvdm.models.utils_diffusion import timestep_embedding
535
+
536
+ from torch.utils.checkpoint import checkpoint
537
+ from lvdm.basics import (
538
+ zero_module,
539
+ conv_nd,
540
+ linear,
541
+ normalization,
542
+ )
543
+ from lvdm.modules.networks.openaimodel3d import Upsample, Downsample
544
+ class TimestepBlock(nn.Module):
545
+ """
546
+ Any module where forward() takes timestep embeddings as a second argument.
547
+ """
548
+
549
+ @abstractmethod
550
+ def forward(self, x: torch.Tensor, emb: torch.Tensor):
551
+ """
552
+ Apply the module to `x` given `emb` timestep embeddings.
553
+ """
554
+
555
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels: int,
        emb_channels: int,
        dropout: float,
        out_channels: Optional[int] = None,
        use_conv: bool = False,
        use_scale_shift_norm: bool = False,
        dims: int = 2,
        use_checkpoint: bool = False,
        up: bool = False,
        down: bool = False,
        kernel_size: int = 3,
        exchange_temb_dims: bool = False,
        skip_t_emb: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm
        self.exchange_temb_dims = exchange_temb_dims

        if isinstance(kernel_size, Iterable):
            padding = [k // 2 for k in kernel_size]
        else:
            padding = kernel_size // 2

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.skip_t_emb = skip_t_emb
        self.emb_out_channels = (
            2 * self.out_channels if use_scale_shift_norm else self.out_channels
        )
        if self.skip_t_emb:
            # print(f"Skipping timestep embedding in {self.__class__.__name__}")
            assert not self.use_scale_shift_norm
            self.emb_layers = None
            self.exchange_temb_dims = False
        else:
            self.emb_layers = nn.Sequential(
                nn.SiLU(),
                linear(
                    emb_channels,
                    self.emb_out_channels,
                ),
            )

        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(
                    dims,
                    self.out_channels,
                    self.out_channels,
                    kernel_size,
                    padding=padding,
                )
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, kernel_size, padding=padding
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        if self.use_checkpoint:
            return checkpoint(self._forward, x, emb, use_reentrant=False)
        else:
            return self._forward(x, emb)

    def _forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)

        if self.skip_t_emb:
            emb_out = torch.zeros_like(h)
        else:
            emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            if self.exchange_temb_dims:
                emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h


#####

#####
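# When `use_scale_shift_norm` is enabled, the embedding is split into (scale, shift) and
# applied as h = GroupNorm(h) * (1 + scale) + shift (a FiLM-style modulation) instead of
# being added to the features; with `skip_t_emb` the embedding branch is bypassed and a
# zero tensor takes its place, so the block reduces to a plain residual block.
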
from lvdm.modules.attention_svd import *


class VideoTransformerBlock(nn.Module):
    ATTENTION_MODES = {
        "softmax": CrossAttention,
        "softmax-xformers": MemoryEfficientCrossAttention,
    }

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        timesteps=None,
        ff_in=False,
        inner_dim=None,
        attn_mode="softmax",
        disable_self_attn=False,
        disable_temporal_crossattention=False,
        switch_temporal_ca_to_sa=False,
    ):
        super().__init__()

        attn_cls = self.ATTENTION_MODES[attn_mode]

        self.ff_in = ff_in or inner_dim is not None
        if inner_dim is None:
            inner_dim = dim

        assert int(n_heads * d_head) == inner_dim

        self.is_res = inner_dim == dim

        if self.ff_in:
            self.norm_in = nn.LayerNorm(dim)
            self.ff_in = FeedForward(
                dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff
            )

        self.timesteps = timesteps
        self.disable_self_attn = disable_self_attn
        if self.disable_self_attn:
            self.attn1 = attn_cls(
                query_dim=inner_dim,
                heads=n_heads,
                dim_head=d_head,
                context_dim=context_dim,
                dropout=dropout,
            )  # is a cross-attention
        else:
            self.attn1 = attn_cls(
                query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
            )  # is a self-attention

        self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff)

        if disable_temporal_crossattention:
            if switch_temporal_ca_to_sa:
                raise ValueError
            else:
                self.attn2 = None
        else:
            self.norm2 = nn.LayerNorm(inner_dim)
            if switch_temporal_ca_to_sa:
                self.attn2 = attn_cls(
                    query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
                )  # is a self-attention
            else:
                self.attn2 = attn_cls(
                    query_dim=inner_dim,
                    context_dim=context_dim,
                    heads=n_heads,
                    dim_head=d_head,
                    dropout=dropout,
                )  # is self-attn if context is none

        self.norm1 = nn.LayerNorm(inner_dim)
        self.norm3 = nn.LayerNorm(inner_dim)
        self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa

        self.checkpoint = checkpoint
        if self.checkpoint:
            print(f"====>{self.__class__.__name__} is using checkpointing")
        else:
            print(f"====>{self.__class__.__name__} is NOT using checkpointing")

    def forward(
        self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None
    ) -> torch.Tensor:
        if self.checkpoint:
            return checkpoint(self._forward, x, context, timesteps, use_reentrant=False)
        else:
            return self._forward(x, context, timesteps=timesteps)

    def _forward(self, x, context=None, timesteps=None):
        assert self.timesteps or timesteps
        assert not (self.timesteps and timesteps) or self.timesteps == timesteps
        timesteps = self.timesteps or timesteps
        B, S, C = x.shape
        x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps)

        if self.ff_in:
            x_skip = x
            x = self.ff_in(self.norm_in(x))
            if self.is_res:
                x += x_skip

        if self.disable_self_attn:
            x = self.attn1(self.norm1(x), context=context) + x
        else:
            x = self.attn1(self.norm1(x)) + x

        if self.attn2 is not None:
            if self.switch_temporal_ca_to_sa:
                x = self.attn2(self.norm2(x)) + x
            else:
                x = self.attn2(self.norm2(x), context=context) + x
        x_skip = x
        x = self.ff(self.norm3(x))
        if self.is_res:
            x += x_skip

        x = rearrange(
            x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
        )
        return x

    def get_last_layer(self):
        return self.ff.net[-1].weight


#####

#####
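# The temporal trick above: tokens arrive spatially flattened as ((b*t), s, c); the block
# regroups them to ((b*s), t, c) so that self-attention, cross-attention and the
# feed-forward run along the frame axis, then restores the original layout on the way out.
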
import functools


def partialclass(cls, *args, **kwargs):
    class NewCls(cls):
        __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)

    return NewCls


######

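# Usage sketch for `partialclass` (the argument values below are only illustrative):
#
#   FixedDropoutResnetBlock = partialclass(ResnetBlock, dropout=0.1)
#   block = FixedDropoutResnetBlock(in_channels=64, out_channels=64, temb_channels=0)
#
# `NewCls` behaves like `cls` with the given arguments pre-bound to `__init__`; the
# factory methods in VideoDecoder below use it to hand the base Decoder pre-configured
# block constructors.
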
class VideoResBlock(ResnetBlock):
    def __init__(
        self,
        out_channels,
        *args,
        dropout=0.0,
        video_kernel_size=3,
        alpha=0.0,
        merge_strategy="learned",
        **kwargs,
    ):
        super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
        if video_kernel_size is None:
            video_kernel_size = [3, 1, 1]
        self.time_stack = ResBlock(
            channels=out_channels,
            emb_channels=0,
            dropout=dropout,
            dims=3,
            use_scale_shift_norm=False,
            use_conv=False,
            up=False,
            down=False,
            kernel_size=video_kernel_size,
            use_checkpoint=True,
            skip_t_emb=True,
        )

        self.merge_strategy = merge_strategy
        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned":
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def get_alpha(self, bs):
        if self.merge_strategy == "fixed":
            return self.mix_factor
        elif self.merge_strategy == "learned":
            return torch.sigmoid(self.mix_factor)
        else:
            raise NotImplementedError()

    def forward(self, x, temb, skip_video=False, timesteps=None):
        if timesteps is None:
            timesteps = self.timesteps

        b, c, h, w = x.shape

        x = super().forward(x, temb)

        if not skip_video:
            x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)

            x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)

            x = self.time_stack(x, temb)

            alpha = self.get_alpha(bs=b // timesteps)
            x = alpha * x + (1.0 - alpha) * x_mix

            x = rearrange(x, "b c t h w -> (b t) c h w")
        return x

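# VideoResBlock first runs the inherited 2D ResnetBlock on every frame independently,
# then blends that spatial result with a pass through the 3D temporal ResBlock:
#   x = alpha * time_stack(x) + (1 - alpha) * x_spatial,
# where alpha is a fixed buffer or sigmoid(mix_factor) for the "learned" merge strategy.
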
class AE3DConv(torch.nn.Conv2d):
    def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        if isinstance(video_kernel_size, Iterable):
            padding = [int(k // 2) for k in video_kernel_size]
        else:
            padding = int(video_kernel_size // 2)

        self.time_mix_conv = torch.nn.Conv3d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=video_kernel_size,
            padding=padding,
        )

    def forward(self, input, timesteps, skip_video=False):
        x = super().forward(input)
        if skip_video:
            return x
        x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
        x = self.time_mix_conv(x)
        return rearrange(x, "b c t h w -> (b t) c h w")

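# AE3DConv keeps the parent Conv2d as the per-frame spatial convolution and then mixes
# information across frames with a Conv3d: ((b*t), c, h, w) is reshaped to
# (b, c, t, h, w), passed through time_mix_conv, and reshaped back. With skip_video=True
# it behaves exactly like a plain Conv2d.
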
class VideoBlock(AttnBlock):
    def __init__(
        self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
    ):
        super().__init__(in_channels)
        # no context, single headed, as in base class
        self.time_mix_block = VideoTransformerBlock(
            dim=in_channels,
            n_heads=1,
            d_head=in_channels,
            checkpoint=True,
            ff_in=True,
            attn_mode="softmax",
        )

        time_embed_dim = self.in_channels * 4
        self.video_time_embed = torch.nn.Sequential(
            torch.nn.Linear(self.in_channels, time_embed_dim),
            torch.nn.SiLU(),
            torch.nn.Linear(time_embed_dim, self.in_channels),
        )

        self.merge_strategy = merge_strategy
        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned":
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

+
971
+ def forward(self, x, timesteps, skip_video=False):
972
+ if skip_video:
973
+ return super().forward(x)
974
+
975
+ x_in = x
976
+ x = self.attention(x)
977
+ h, w = x.shape[2:]
978
+ x = rearrange(x, "b c h w -> b (h w) c")
979
+
980
+ x_mix = x
981
+ num_frames = torch.arange(timesteps, device=x.device)
982
+ num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
983
+ num_frames = rearrange(num_frames, "b t -> (b t)")
984
+ t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
985
+ emb = self.video_time_embed(t_emb) # b, n_channels
986
+ emb = emb[:, None, :]
987
+ x_mix = x_mix + emb
988
+
989
+ alpha = self.get_alpha()
990
+ x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
991
+ x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
992
+
993
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
994
+ x = self.proj_out(x)
995
+
996
+ return x_in + x
997
+
998
+ def get_alpha(
999
+ self,
1000
+ ):
1001
+ if self.merge_strategy == "fixed":
1002
+ return self.mix_factor
1003
+ elif self.merge_strategy == "learned":
1004
+ return torch.sigmoid(self.mix_factor)
1005
+ else:
1006
+ raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
1007
+
1008
+
1009
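# Before the temporal mixing block runs, each frame receives a sinusoidal index embedding
# (timestep_embedding over frame indices 0..timesteps-1, projected by video_time_embed)
# added to its tokens, so the temporal attention can distinguish frames; the result is
# then alpha-blended with the purely spatial attention output, as in VideoResBlock.
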
class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock):
    def __init__(
        self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
    ):
        super().__init__(in_channels)
        # no context, single headed, as in base class
        self.time_mix_block = VideoTransformerBlock(
            dim=in_channels,
            n_heads=1,
            d_head=in_channels,
            checkpoint=True,
            ff_in=True,
            attn_mode="softmax-xformers",
        )

        time_embed_dim = self.in_channels * 4
        self.video_time_embed = torch.nn.Sequential(
            torch.nn.Linear(self.in_channels, time_embed_dim),
            torch.nn.SiLU(),
            torch.nn.Linear(time_embed_dim, self.in_channels),
        )

        self.merge_strategy = merge_strategy
        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned":
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def forward(self, x, timesteps, skip_time_block=False):
        if skip_time_block:
            return super().forward(x)

        x_in = x
        x = self.attention(x)
        h, w = x.shape[2:]
        x = rearrange(x, "b c h w -> b (h w) c")

        x_mix = x
        num_frames = torch.arange(timesteps, device=x.device)
        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
        num_frames = rearrange(num_frames, "b t -> (b t)")
        t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
        emb = self.video_time_embed(t_emb)  # b, n_channels
        emb = emb[:, None, :]
        x_mix = x_mix + emb

        alpha = self.get_alpha()
        x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
        x = alpha * x + (1.0 - alpha) * x_mix  # alpha merge

        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        x = self.proj_out(x)

        return x_in + x

    def get_alpha(
        self,
    ):
        if self.merge_strategy == "fixed":
            return self.mix_factor
        elif self.merge_strategy == "learned":
            return torch.sigmoid(self.mix_factor)
        else:
            raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")

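# MemoryEfficientVideoBlock mirrors VideoBlock above; the only differences are the
# MemoryEfficientAttnBlock base class and the "softmax-xformers" attention mode of the
# temporal mixing block, making it the drop-in choice when xformers is available.
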
def make_time_attn(
    in_channels,
    attn_type="vanilla",
    attn_kwargs=None,
    alpha: float = 0,
    merge_strategy: str = "learned",
):
    assert attn_type in [
        "vanilla",
        "vanilla-xformers",
    ], f"attn_type {attn_type} not supported for spatio-temporal attention"
    print(
        f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels"
    )
    if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
        print(
            f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
            f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
        )
        attn_type = "vanilla"

    if attn_type == "vanilla":
        assert attn_kwargs is None
        return partialclass(
            VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
        )
    elif attn_type == "vanilla-xformers":
        print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
        return partialclass(
            MemoryEfficientVideoBlock,
            in_channels,
            alpha=alpha,
            merge_strategy=merge_strategy,
        )
    else:
        raise NotImplementedError()

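# make_time_attn returns a class built with partialclass rather than an instance:
# instantiating the returned class (no further arguments needed) yields a VideoBlock or
# MemoryEfficientVideoBlock with in_channels, alpha and merge_strategy already bound.
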

class Conv2DWrapper(torch.nn.Conv2d):
    def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
        return super().forward(input)

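# Conv2DWrapper simply ignores the extra keyword arguments (e.g. timesteps, skip_video)
# that the video-aware decoder forwards to conv_out, so it can stand in for AE3DConv when
# no temporal convolution is wanted (time_mode="attn-only").
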
class VideoDecoder(Decoder):
    available_time_modes = ["all", "conv-only", "attn-only"]

    def __init__(
        self,
        *args,
        video_kernel_size: Union[int, list] = [3, 1, 1],
        alpha: float = 0.0,
        merge_strategy: str = "learned",
        time_mode: str = "conv-only",
        **kwargs,
    ):
        self.video_kernel_size = video_kernel_size
        self.alpha = alpha
        self.merge_strategy = merge_strategy
        self.time_mode = time_mode
        assert (
            self.time_mode in self.available_time_modes
        ), f"time_mode parameter has to be in {self.available_time_modes}"
        super().__init__(*args, **kwargs)

    def get_last_layer(self, skip_time_mix=False, **kwargs):
        if self.time_mode == "attn-only":
            raise NotImplementedError("TODO")
        else:
            return (
                self.conv_out.time_mix_conv.weight
                if not skip_time_mix
                else self.conv_out.weight
            )

    def _make_attn(self) -> Callable:
        if self.time_mode not in ["conv-only", "only-last-conv"]:
            return partialclass(
                make_time_attn,
                alpha=self.alpha,
                merge_strategy=self.merge_strategy,
            )
        else:
            return super()._make_attn()

    def _make_conv(self) -> Callable:
        if self.time_mode != "attn-only":
            return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
        else:
            return Conv2DWrapper

    def _make_resblock(self) -> Callable:
        if self.time_mode not in ["attn-only", "only-last-conv"]:
            return partialclass(
                VideoResBlock,
                video_kernel_size=self.video_kernel_size,
                alpha=self.alpha,
                merge_strategy=self.merge_strategy,
            )
        else:
            return super()._make_resblock()
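

# Rough usage sketch (illustrative only; the constructor arguments assume the standard
# Decoder kwargs defined earlier in this file, and the exact values used by the repo's
# configs may differ):
#
#   decoder = VideoDecoder(
#       ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
#       attn_resolutions=[], in_channels=3, resolution=256, z_channels=4,
#       video_kernel_size=[3, 1, 1], time_mode="conv-only",
#   )
#   # z: latents flattened over time, shape ((b * t), z_channels, h', w')
#   frames = decoder(z, timesteps=t)
#
# With time_mode="conv-only" the temporal modelling lives in AE3DConv / VideoResBlock,
# and the `timesteps` keyword is forwarded through the decoder's **kwargs to those blocks.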