|
from typing import Callable, Iterable, Union

import torch
from einops import rearrange, repeat

from sgm.modules.diffusionmodules.model import (
    XFORMERS_IS_AVAILABLE,
    AttnBlock,
    Decoder,
    MemoryEfficientAttnBlock,
    ResnetBlock,
)
from sgm.modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding
from sgm.modules.video_attention import VideoTransformerBlock
from sgm.util import partialclass
|
|
class VideoResBlock(ResnetBlock):
    def __init__(
        self,
        out_channels,
        *args,
        dropout=0.0,
        video_kernel_size=3,
        alpha=0.0,
        merge_strategy="learned",
        **kwargs,
    ):
        super().__init__(*args, out_channels=out_channels, dropout=dropout, **kwargs)
        if video_kernel_size is None:
            video_kernel_size = [3, 1, 1]
        # 3D residual block that mixes information across the time axis.
        self.time_mix_blocks = ResBlock(
            channels=out_channels,
            emb_channels=0,
            dropout=dropout,
            dims=3,
            use_scale_shift_norm=False,
            use_conv=False,
            up=False,
            down=False,
            kernel_size=video_kernel_size,
            use_checkpoint=False,
            skip_t_emb=True,
        )

        self.merge_strategy = merge_strategy
        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned":
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
    def get_alpha(self, bs):
        if self.merge_strategy == "fixed":
            return self.mix_factor
        elif self.merge_strategy == "learned":
            return torch.sigmoid(self.mix_factor)
        else:
            raise NotImplementedError()
|
    def forward(self, x, temb, skip_video=False, timesteps=None):
        if timesteps is None:
            # Falls back to a `timesteps` attribute that must have been
            # assigned externally before calling forward.
            timesteps = self.timesteps

        b, c, h, w = x.shape

        # Spatial residual block on the flattened (b*t) frame batch.
        x = super().forward(x, temb)

        if not skip_video:
            x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)

            x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)

            # Temporal 3D residual block.
            x = self.time_mix_blocks(x, temb)

            # Blend the temporal output with the purely spatial features.
            alpha = self.get_alpha(bs=b // timesteps)
            x = alpha * x + (1.0 - alpha) * x_mix

            x = rearrange(x, "b c t h w -> (b t) c h w")
        return x
|
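
# --- Hedged usage sketch (editor's addition, not part of the original module).
# Illustrates the flattened (b*t, c, h, w) layout VideoResBlock expects; the
# channel sizes and frame count below are illustrative assumptions only.
def _demo_video_res_block():
    block = VideoResBlock(
        out_channels=64, in_channels=64, temb_channels=0, dropout=0.0
    )
    x = torch.randn(2 * 4, 64, 32, 32)  # 2 clips x 4 frames, flattened
    out = block(x, temb=None, timesteps=4)
    assert out.shape == x.shape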
|
|
class PostHocConv2WithTime(torch.nn.Conv2d):
    def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        # "Same" padding for the temporal 3D convolution.
        if isinstance(video_kernel_size, Iterable):
            padding = [int(k // 2) for k in video_kernel_size]
        else:
            padding = int(video_kernel_size // 2)

        self.time_mix_conv = torch.nn.Conv3d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=video_kernel_size,
            padding=padding,
        )
|
    def forward(self, input, timesteps, skip_video=False):
        x = super().forward(input)
        if skip_video:
            return x
        x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
        x = self.time_mix_conv(x)
        return rearrange(x, "b c t h w -> (b t) c h w")
|
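
# --- Hedged usage sketch (editor's addition): the 2D convolution runs on each
# frame independently, then a 3D convolution mixes across time. All sizes here
# are illustrative assumptions.
def _demo_post_hoc_conv():
    conv = PostHocConv2WithTime(4, 8, video_kernel_size=3, kernel_size=3, padding=1)
    z = torch.randn(2 * 4, 4, 16, 16)  # 2 clips x 4 frames of 16x16 features
    y = conv(z, timesteps=4)
    assert y.shape == (8, 8, 16, 16)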
|
|
class VideoBlock(AttnBlock):
    def __init__(
        self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
    ):
        super().__init__(in_channels)

        # Temporal transformer block applied on the (h*w) token sequence.
        self.time_mix_block = VideoTransformerBlock(
            dim=in_channels,
            n_heads=1,
            d_head=in_channels,
            checkpoint=False,
            ff_in=True,
            attn_mode="softmax",
        )

        # Sinusoidal frame-position embedding, projected to the channel dim.
        time_embed_dim = self.in_channels * 4
        self.video_time_embed = torch.nn.Sequential(
            torch.nn.Linear(self.in_channels, time_embed_dim),
            torch.nn.SiLU(),
            torch.nn.Linear(time_embed_dim, self.in_channels),
        )

        self.merge_strategy = merge_strategy
        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned":
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
    def forward(self, x, timesteps, skip_video=False):
        if skip_video:
            return super().forward(x)

        x_in = x
        x = self.attention(x)
        h, w = x.shape[2:]
        x = rearrange(x, "b c h w -> b (h w) c")

        # Add a per-frame position embedding before the temporal block.
        x_mix = x
        num_frames = torch.arange(timesteps, device=x.device)
        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
        num_frames = rearrange(num_frames, "b t -> (b t)")
        t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
        emb = self.video_time_embed(t_emb)
        emb = emb[:, None, :]
        x_mix = x_mix + emb

        # Blend spatial attention output with the temporal attention output.
        alpha = self.get_alpha()
        x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
        x = alpha * x + (1.0 - alpha) * x_mix

        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        x = self.proj_out(x)

        return x_in + x
|
    def get_alpha(self):
        if self.merge_strategy == "fixed":
            return self.mix_factor
        elif self.merge_strategy == "learned":
            return torch.sigmoid(self.mix_factor)
        else:
            raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
|
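
# --- Hedged usage sketch (editor's addition): VideoBlock keeps the AttnBlock
# call signature but additionally needs the frame count per clip. Assumes the
# sgm VideoTransformerBlock is shape-preserving; sizes are illustrative.
def _demo_video_block():
    blk = VideoBlock(in_channels=64)
    x = torch.randn(2 * 4, 64, 8, 8)  # 2 clips x 4 frames
    out = blk(x, timesteps=4)
    assert out.shape == x.shape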
|
|
class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock):
    def __init__(
        self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
    ):
        super().__init__(in_channels)

        # Same structure as VideoBlock, but with xformers attention.
        self.time_mix_block = VideoTransformerBlock(
            dim=in_channels,
            n_heads=1,
            d_head=in_channels,
            checkpoint=False,
            ff_in=True,
            attn_mode="softmax-xformers",
        )

        # Sinusoidal frame-position embedding, projected to the channel dim.
        time_embed_dim = self.in_channels * 4
        self.video_time_embed = torch.nn.Sequential(
            torch.nn.Linear(self.in_channels, time_embed_dim),
            torch.nn.SiLU(),
            torch.nn.Linear(time_embed_dim, self.in_channels),
        )

        self.merge_strategy = merge_strategy
        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned":
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
    def forward(self, x, timesteps, skip_time_block=False):
        if skip_time_block:
            return super().forward(x)

        x_in = x
        x = self.attention(x)
        h, w = x.shape[2:]
        x = rearrange(x, "b c h w -> b (h w) c")

        # Add a per-frame position embedding before the temporal block.
        x_mix = x
        num_frames = torch.arange(timesteps, device=x.device)
        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
        num_frames = rearrange(num_frames, "b t -> (b t)")
        t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
        emb = self.video_time_embed(t_emb)
        emb = emb[:, None, :]
        x_mix = x_mix + emb

        # Blend spatial attention output with the temporal attention output.
        alpha = self.get_alpha()
        x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
        x = alpha * x + (1.0 - alpha) * x_mix

        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        x = self.proj_out(x)

        return x_in + x
|
    def get_alpha(self):
        if self.merge_strategy == "fixed":
            return self.mix_factor
        elif self.merge_strategy == "learned":
            return torch.sigmoid(self.mix_factor)
        else:
            raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
|
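
# --- Hedged note (editor's addition): with merge_strategy="fixed" the blend
# weight is the raw alpha buffer; with "learned" it is sigmoid(mix_factor).
def _demo_merge_strategies():
    fixed = VideoBlock(16, alpha=0.3, merge_strategy="fixed")
    learned = VideoBlock(16, alpha=0.0, merge_strategy="learned")
    assert abs(float(fixed.get_alpha()) - 0.3) < 1e-6
    assert float(learned.get_alpha()) == 0.5  # sigmoid(0.0)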
|
|
def make_time_attn(
    in_channels,
    attn_type="vanilla",
    attn_kwargs=None,
    alpha: float = 0,
    merge_strategy: str = "learned",
):
    assert attn_type in [
        "vanilla",
        "vanilla-xformers",
    ], f"attn_type {attn_type} not supported for spatio-temporal attention"
    print(
        f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels"
    )
    if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
        print(
            f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
            f"This is not a problem in PyTorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
        )
        attn_type = "vanilla"

    if attn_type == "vanilla":
        assert attn_kwargs is None
        return partialclass(
            VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
        )
    elif attn_type == "vanilla-xformers":
        print(f"building MemoryEfficientVideoBlock with {in_channels} in_channels...")
        return partialclass(
            MemoryEfficientVideoBlock,
            in_channels,
            alpha=alpha,
            merge_strategy=merge_strategy,
        )
    else:
        raise NotImplementedError()
|
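
# --- Hedged usage sketch (editor's addition): make_time_attn returns a class
# (via partialclass) with in_channels and the merge settings already bound.
def _demo_make_time_attn():
    AttnCls = make_time_attn(64, attn_type="vanilla")
    blk = AttnCls()  # in_channels=64 is pre-bound
    x = torch.randn(2 * 4, 64, 8, 8)
    out = blk(x, timesteps=4)
    assert out.shape == x.shape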
|
|
class Conv2DWrapper(torch.nn.Conv2d):
    def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
        # Plain 2D convolution; extra kwargs (e.g. timesteps) are ignored.
        return super().forward(input)
|
|
class VideoDecoder(Decoder):
    available_time_modes = ["all", "conv-only", "attn-only"]

    def __init__(
        self,
        *args,
        video_kernel_size: Union[int, list] = 3,
        alpha: float = 0.0,
        merge_strategy: str = "learned",
        time_mode: str = "conv-only",
        **kwargs,
    ):
        # These must be set before super().__init__(), which invokes the
        # _make_* factories below.
        self.video_kernel_size = video_kernel_size
        self.alpha = alpha
        self.merge_strategy = merge_strategy
        self.time_mode = time_mode
        assert (
            self.time_mode in self.available_time_modes
        ), f"time_mode parameter has to be in {self.available_time_modes}"
        super().__init__(*args, **kwargs)
|
    def get_last_layer(self, skip_time_mix=False, **kwargs):
        if self.time_mode == "attn-only":
            raise NotImplementedError("TODO")
        else:
            return (
                self.conv_out.time_mix_conv.weight
                if not skip_time_mix
                else self.conv_out.weight
            )
|
    def _make_attn(self) -> Callable:
        if self.time_mode not in ["conv-only", "only-last-conv"]:
            return partialclass(
                make_time_attn,
                alpha=self.alpha,
                merge_strategy=self.merge_strategy,
            )
        else:
            return super()._make_attn()
|
    def _make_conv(self) -> Callable:
        if self.time_mode != "attn-only":
            return partialclass(
                PostHocConv2WithTime, video_kernel_size=self.video_kernel_size
            )
        else:
            return Conv2DWrapper
|
    def _make_resblock(self) -> Callable:
        if self.time_mode not in ["attn-only", "only-last-conv"]:
            return partialclass(
                VideoResBlock,
                video_kernel_size=self.video_kernel_size,
                alpha=self.alpha,
                merge_strategy=self.merge_strategy,
            )
        else:
            return super()._make_resblock()
|
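
# --- Hedged usage sketch (editor's addition): a minimal VideoDecoder built
# with assumed kwargs from the parent sgm Decoder (ch, ch_mult, resolution,
# z_channels, ...); every size below is an illustrative assumption.
def _demo_video_decoder():
    dec = VideoDecoder(
        ch=32,
        out_ch=3,
        ch_mult=(1, 2),
        num_res_blocks=1,
        attn_resolutions=(),
        in_channels=3,
        resolution=32,
        z_channels=4,
        video_kernel_size=3,
        time_mode="conv-only",
    )
    z = torch.randn(2 * 4, 4, 16, 16)  # 2 clips x 4 frames of latents
    frames = dec(z, timesteps=4)  # expected (8, 3, 32, 32) decoded frames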
|