import diffusers
from diffusers.models.transformer_temporal import TransformerTemporalModel, TransformerTemporalModelOutput
import torch.nn as nn
from einops import rearrange
from diffusers.models.attention_processor import Attention
# from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention_processor import Attention
from t2v_enhanced.model.diffusers_conditional.models.controlnet.transformer_temporal_crossattention import TransformerTemporalModel as TransformerTemporalModelCrossAttn
import torch


class CrossAttention(nn.Module):
    """Temporal cross-attention block: every spatial location attends over the
    frame axis, conditioned on `encoder_hidden_states`."""

    def __init__(self, input_channels, attention_head_dim, norm_num_groups=32):
        super().__init__()
        self.attention = Attention(
            query_dim=input_channels,
            cross_attention_dim=input_channels,
            heads=input_channels // attention_head_dim,
            dim_head=attention_head_dim,
            bias=False,
            upcast_attention=False)
        self.norm = torch.nn.GroupNorm(
            num_groups=norm_num_groups, num_channels=input_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(input_channels, input_channels)
        self.proj_out = nn.Linear(input_channels, input_channels)

    def forward(self, hidden_state, encoder_hidden_states, num_frames):
        h, w = hidden_state.shape[2], hidden_state.shape[3]
        # Group-normalize over channels, then fold the spatial grid into the
        # batch axis so attention runs along the frame axis.
        hidden_state_norm = rearrange(
            hidden_state, "(B F) C H W -> B C F H W", F=num_frames)
        hidden_state_norm = self.norm(hidden_state_norm)
        hidden_state_norm = rearrange(
            hidden_state_norm, "B C F H W -> (B H W) F C")
        hidden_state_norm = self.proj_in(hidden_state_norm)
        attn = self.attention(hidden_state_norm,
                              encoder_hidden_states=encoder_hidden_states,
                              attention_mask=None)
        # Project back and add as a residual on the (unnormalized) input.
        residual = self.proj_out(attn)
        residual = rearrange(
            residual, "(B H W) F C -> (B F) C H W", H=h, W=w)
        output = hidden_state + residual
        return TransformerTemporalModelOutput(sample=output)


class ConditionalModel(nn.Module):
    """Wraps one of several temporal conditioning blocks, selected by name.
    A suffix of the form `<name>_layers_<n>` sets the layer count, e.g.
    `cross_transformer_layers_2`."""

    def __init__(self, input_channels, conditional_model: str, attention_head_dim=64):
        super().__init__()
        num_layers = 1
        if "_layers_" in conditional_model:
            config = conditional_model.split("_layers_")
            conditional_model = config[0]
            num_layers = int(config[1])
        if conditional_model == "self_cross_transformer":
            self.temporal_transformer = TransformerTemporalModel(
                num_attention_heads=input_channels // attention_head_dim,
                attention_head_dim=attention_head_dim,
                in_channels=input_channels,
                double_self_attention=False,
                cross_attention_dim=input_channels)
        elif conditional_model == "cross_transformer":
            self.temporal_transformer = TransformerTemporalModelCrossAttn(
                num_attention_heads=input_channels // attention_head_dim,
                attention_head_dim=attention_head_dim,
                in_channels=input_channels,
                double_self_attention=False,
                cross_attention_dim=input_channels,
                num_layers=num_layers)
        elif conditional_model == "cross_attention":
            self.temporal_transformer = CrossAttention(
                input_channels=input_channels,
                attention_head_dim=attention_head_dim)
        elif conditional_model == "test_conv":
            self.temporal_transformer = nn.Conv2d(
                input_channels, input_channels, kernel_size=1)
        else:
            raise NotImplementedError(
                f"mode {conditional_model} not implemented")
        # Zero-init the output projection so the block starts as an identity
        # mapping and the conditioning signal is blended in during training.
        if conditional_model != "test_conv":
            nn.init.zeros_(self.temporal_transformer.proj_out.weight)
            nn.init.zeros_(self.temporal_transformer.proj_out.bias)
        else:
            nn.init.zeros_(self.temporal_transformer.weight)
            nn.init.zeros_(self.temporal_transformer.bias)
        self.conditional_model = conditional_model

    def forward(self, sample, conditioning, num_frames=None):
        assert conditioning.ndim == 5
        assert sample.ndim == 5
        if self.conditional_model != "test_conv":
            # Fold the spatial grid of the conditioning into the batch axis so
            # it lines up with the queries produced by the attention blocks.
            conditioning = rearrange(conditioning, "B F C H W -> (B H W) F C")
            num_frames = sample.shape[1]
            sample = rearrange(sample, "B F C H W -> (B F) C H W")
            sample = self.temporal_transformer(
                sample, encoder_hidden_states=conditioning, num_frames=num_frames).sample
            sample = rearrange(
                sample, "(B F) C H W -> B F C H W", F=num_frames)
        else:
            # Debug path: a 1x1 conv on the conditioning, added residually.
            conditioning = rearrange(conditioning, "B F C H W -> (B F) C H W")
            f = sample.shape[1]
            sample = rearrange(sample, "B F C H W -> (B F) C H W")
            sample = sample + self.temporal_transformer(conditioning)
            sample = rearrange(sample, "(B F) C H W -> B F C H W", F=f)
        return sample
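

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The shapes are
# assumptions inferred from the rearrange patterns above: `sample` and
# `conditioning` are both (B, F, C, H, W) feature maps whose channel count
# matches `input_channels`.
if __name__ == "__main__":
    batch, frames, channels, height, width = 1, 4, 128, 8, 8
    model = ConditionalModel(input_channels=channels,
                             conditional_model="cross_attention")
    sample = torch.randn(batch, frames, channels, height, width)
    conditioning = torch.randn(batch, frames, channels, height, width)
    out = model(sample, conditioning)
    assert out.shape == sample.shape
    # Because proj_out is zero-initialized, the block is an identity mapping
    # over `sample` at initialization.
    assert torch.allclose(out, sample)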