# AniDoc/models_diffusers/camera/motion_module.py
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional

import torch
from torch import nn
from einops import rearrange

from diffusers.utils import BaseOutput
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import Attention

# from cameractrl.models.attention_processor import PoseAdaptorAttnProcessor
from models_diffusers.camera.attention_processor import PoseAdaptorAttnProcessor
class InflatedGroupNorm(nn.GroupNorm):
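    """GroupNorm applied frame-wise: the 5D video tensor (b, c, f, h, w) is folded to
    (b*f, c, h, w) so normalization statistics are computed per frame."""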
def forward(self, x):
# return super().forward(x)
video_length = x.shape[2]
x = rearrange(x, "b c f h w -> (b f) c h w")
x = super().forward(x)
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
return x
def zero_module(module):
# Zero out the parameters of a module and return it.
for p in module.parameters():
p.detach().zero_()
return module
@dataclass
class TemporalTransformer3DModelOutput(BaseOutput):
sample: torch.FloatTensor
def get_motion_module(
in_channels,
motion_module_type: str,
motion_module_kwargs: dict
):
if motion_module_type == "Vanilla":
return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs)
else:
        raise ValueError(f"Unsupported motion_module_type: {motion_module_type}")
class VanillaTemporalModule(nn.Module):
def __init__(
self,
in_channels,
num_attention_heads=8,
num_transformer_block=2,
attention_block_types=("Temporal_Self",),
temporal_position_encoding=True,
temporal_position_encoding_max_len=32,
temporal_attention_dim_div=1,
cross_attention_dim=320,
zero_initialize=True,
encoder_hidden_states_query=(False, False),
attention_activation_scale=1.0,
attention_processor_kwargs: Dict = {},
causal_temporal_attention=False,
causal_temporal_attention_mask_type="",
rescale_output_factor=1.0
):
super().__init__()
self.temporal_transformer = TemporalTransformer3DModel(
in_channels=in_channels,
num_attention_heads=num_attention_heads,
attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
num_layers=num_transformer_block,
attention_block_types=attention_block_types,
cross_attention_dim=cross_attention_dim,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
encoder_hidden_states_query=encoder_hidden_states_query,
attention_activation_scale=attention_activation_scale,
attention_processor_kwargs=attention_processor_kwargs,
causal_temporal_attention=causal_temporal_attention,
causal_temporal_attention_mask_type=causal_temporal_attention_mask_type,
rescale_output_factor=rescale_output_factor
)
if zero_initialize:
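            # Zero-init the output projection so the temporal transformer initially acts as an
            # identity mapping through the residual connection in TemporalTransformer3DModel.forward.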
self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None,
cross_attention_kwargs: Dict[str, Any] = {}):
hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask, cross_attention_kwargs=cross_attention_kwargs)
output = hidden_states
return output
class TemporalTransformer3DModel(nn.Module):
def __init__(
self,
in_channels,
num_attention_heads,
attention_head_dim,
num_layers,
attention_block_types=("Temporal_Self", "Temporal_Self",),
dropout=0.0,
norm_num_groups=32,
cross_attention_dim=320,
activation_fn="geglu",
attention_bias=False,
upcast_attention=False,
temporal_position_encoding=False,
temporal_position_encoding_max_len=32,
encoder_hidden_states_query=(False, False),
attention_activation_scale=1.0,
attention_processor_kwargs: Dict = {},
causal_temporal_attention=None,
causal_temporal_attention_mask_type="",
rescale_output_factor=1.0
):
super().__init__()
assert causal_temporal_attention is not None
self.causal_temporal_attention = causal_temporal_attention
assert (not causal_temporal_attention) or (causal_temporal_attention_mask_type != "")
self.causal_temporal_attention_mask_type = causal_temporal_attention_mask_type
self.causal_temporal_attention_mask = None
inner_dim = num_attention_heads * attention_head_dim
self.norm = InflatedGroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList(
[
TemporalTransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
attention_block_types=attention_block_types,
dropout=dropout,
norm_num_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
attention_bias=attention_bias,
upcast_attention=upcast_attention,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
encoder_hidden_states_query=encoder_hidden_states_query,
attention_activation_scale=attention_activation_scale,
attention_processor_kwargs=attention_processor_kwargs,
rescale_output_factor=rescale_output_factor,
)
for d in range(num_layers)
]
)
self.proj_out = nn.Linear(inner_dim, in_channels)
def get_causal_temporal_attention_mask(self, hidden_states):
batch_size, sequence_length, dim = hidden_states.shape
if self.causal_temporal_attention_mask is None or self.causal_temporal_attention_mask.shape != (
batch_size, sequence_length, sequence_length):
if self.causal_temporal_attention_mask_type == "causal":
# 1. vanilla causal mask
mask = torch.tril(torch.ones(sequence_length, sequence_length))
elif self.causal_temporal_attention_mask_type == "2-seq":
# 2. 2-seq
mask = torch.zeros(sequence_length, sequence_length)
mask[:sequence_length // 2, :sequence_length // 2] = 1
mask[-sequence_length // 2:, -sequence_length // 2:] = 1
elif self.causal_temporal_attention_mask_type == "0-prev":
# attn to the previous frame
indices = torch.arange(sequence_length)
indices_prev = indices - 1
indices_prev[0] = 0
mask = torch.zeros(sequence_length, sequence_length)
mask[:, 0] = 1.
mask[indices, indices_prev] = 1.
elif self.causal_temporal_attention_mask_type == "0":
# only attn to first frame
mask = torch.zeros(sequence_length, sequence_length)
mask[:, 0] = 1
elif self.causal_temporal_attention_mask_type == "wo-self":
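                    # attend to every frame except the frame itself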
indices = torch.arange(sequence_length)
mask = torch.ones(sequence_length, sequence_length)
mask[indices, indices] = 0
elif self.causal_temporal_attention_mask_type == "circle":
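                    # attend to self and the previous frame; the first frame also wraps around to the last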
indices = torch.arange(sequence_length)
indices_prev = indices - 1
indices_prev[0] = 0
mask = torch.eye(sequence_length)
mask[indices, indices_prev] = 1
mask[0, -1] = 1
            else:
                raise ValueError(f"Unknown causal_temporal_attention_mask_type: {self.causal_temporal_attention_mask_type}")
            # generate the attention mask from the binary values
mask = mask.masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
mask = mask.unsqueeze(0)
mask = mask.repeat(batch_size, 1, 1)
self.causal_temporal_attention_mask = mask.to(hidden_states.device)
return self.causal_temporal_attention_mask
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None,
cross_attention_kwargs: Dict[str, Any] = {},):
residual = hidden_states
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
height, width = hidden_states.shape[-2:]
hidden_states = self.norm(hidden_states)
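        # Fold the spatial locations into the batch dimension so attention runs along the frame axis.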
hidden_states = rearrange(hidden_states, "b c f h w -> (b h w) f c")
hidden_states = self.proj_in(hidden_states)
attention_mask = self.get_causal_temporal_attention_mask(
hidden_states) if self.causal_temporal_attention else attention_mask
# Transformer Blocks
for block in self.transformer_blocks:
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs)
hidden_states = self.proj_out(hidden_states)
hidden_states = rearrange(hidden_states, "(b h w) f c -> b c f h w", h=height, w=width)
output = hidden_states + residual
return output
class TemporalTransformerBlock(nn.Module):
def __init__(
self,
dim,
num_attention_heads,
attention_head_dim,
attention_block_types=("Temporal_Self", "Temporal_Self",),
dropout=0.0,
norm_num_groups=32,
cross_attention_dim=768,
activation_fn="geglu",
attention_bias=False,
upcast_attention=False,
temporal_position_encoding=False,
temporal_position_encoding_max_len=32,
encoder_hidden_states_query=(False, False),
attention_activation_scale=1.0,
attention_processor_kwargs: Dict = {},
rescale_output_factor=1.0
):
super().__init__()
attention_blocks = []
norms = []
self.attention_block_types = attention_block_types
for block_idx, block_name in enumerate(attention_block_types):
attention_blocks.append(
TemporalSelfAttention(
attention_mode=block_name,
cross_attention_dim=cross_attention_dim if block_name in ['Temporal_Cross', 'Temporal_Pose_Adaptor'] else None,
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
rescale_output_factor=rescale_output_factor,
)
)
norms.append(nn.LayerNorm(dim))
self.attention_blocks = nn.ModuleList(attention_blocks)
self.norms = nn.ModuleList(norms)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.ff_norm = nn.LayerNorm(dim)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs: Dict[str, Any] = {}):
for attention_block, norm, attention_block_type in zip(self.attention_blocks, self.norms, self.attention_block_types):
norm_hidden_states = norm(hidden_states)
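            # "Temporal_Self" blocks attend over the normed hidden states themselves;
            # other block types attend over the provided encoder_hidden_states.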
hidden_states = attention_block(
norm_hidden_states,
encoder_hidden_states=norm_hidden_states if attention_block_type == 'Temporal_Self' else encoder_hidden_states,
attention_mask=attention_mask,
**cross_attention_kwargs
) + hidden_states
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
output = hidden_states
return output
class PositionalEncoding(nn.Module):
def __init__(
self,
d_model,
dropout=0.,
max_len=32,
):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
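        # Standard sinusoidal positional encoding (as in "Attention Is All You Need"),
        # precomputed for up to max_len frames.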
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(1, max_len, d_model)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :x.size(1)]
return self.dropout(x)
class TemporalSelfAttention(Attention):
def __init__(
self,
attention_mode=None,
temporal_position_encoding=False,
temporal_position_encoding_max_len=32,
rescale_output_factor=1.0,
*args, **kwargs
):
super().__init__(*args, **kwargs)
assert attention_mode == "Temporal_Self"
self.pos_encoder = PositionalEncoding(
kwargs["query_dim"],
max_len=temporal_position_encoding_max_len
) if temporal_position_encoding else None
self.rescale_output_factor = rescale_output_factor
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
):
        # Disable memory-efficient xformers attention for the motion module; enabling it
        # degrades results for reasons that are not yet understood.
# TODO: fix this bug
pass
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
# The `Attention` class can call different attention processors / attention functions
# here we simply pass along all tensors to the selected processor class
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
# add position encoding
if self.pos_encoder is not None:
hidden_states = self.pos_encoder(hidden_states)
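        # If a 5D pose feature map is provided, flatten it to the (b*h*w, f, c) layout of hidden_states.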
if "pose_feature" in cross_attention_kwargs:
pose_feature = cross_attention_kwargs["pose_feature"]
if pose_feature.ndim == 5:
pose_feature = rearrange(pose_feature, "b c f h w -> (b h w) f c")
else:
assert pose_feature.ndim == 3
cross_attention_kwargs["pose_feature"] = pose_feature
if isinstance(self.processor, PoseAdaptorAttnProcessor):
return self.processor(
self,
hidden_states,
cross_attention_kwargs.pop('pose_feature'),
encoder_hidden_states=None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
        # All other processors are plain callables; invoke them directly.
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
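

# Minimal smoke-test sketch (not part of the original module): build a vanilla motion
# module and run it on a random 5D feature map. The channel count, frame count, and
# spatial size below are illustrative assumptions, not values taken from the AniDoc config.
if __name__ == "__main__":
    motion_module = get_motion_module(
        in_channels=320,
        motion_module_type="Vanilla",
        motion_module_kwargs={"causal_temporal_attention": False},
    )
    feats = torch.randn(1, 320, 16, 32, 32)  # (batch, channels, frames, height, width)
    out = motion_module(feats)
    # With the zero-initialized proj_out the module starts as an identity mapping,
    # so the output shape (and initially the values) match the input.
    assert out.shape == feats.shape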