import torch
import torch.nn as nn
from einops import rearrange

from diffusers.models.transformer_temporal import TransformerTemporalModel, TransformerTemporalModelOutput
from diffusers.models.attention_processor import Attention
from t2v_enhanced.model.diffusers_conditional.models.controlnet.transformer_temporal_crossattention import TransformerTemporalModel as TransformerTemporalModelCrossAttn


class CrossAttention(nn.Module):
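    """Per-pixel temporal cross-attention.

    Each spatial location attends over the frame axis of the conditioning
    sequence; the result is added back to the input feature map as a residual.
    """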

    def __init__(self, input_channels, attention_head_dim, norm_num_groups=32):
        super().__init__()
        self.attention = Attention(
            query_dim=input_channels,
            cross_attention_dim=input_channels,
            heads=input_channels // attention_head_dim,
            dim_head=attention_head_dim,
            bias=False,
            upcast_attention=False)
        self.norm = torch.nn.GroupNorm(
            num_groups=norm_num_groups, num_channels=input_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(input_channels, input_channels)
        self.proj_out = nn.Linear(input_channels, input_channels)

    def forward(self, hidden_state, encoder_hidden_states, num_frames):
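        """
        Args:
            hidden_state: (B*F, C, H, W) feature map, frames packed into batch.
            encoder_hidden_states: (B*H*W, F, C) conditioning tokens.
            num_frames: number of frames F packed into hidden_state.
        """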
        h, w = hidden_state.shape[2], hidden_state.shape[3]
        # Unpack frames into their own axis so GroupNorm normalizes over channels.
        hidden_state_norm = rearrange(
            hidden_state, "(B F) C H W -> B C F H W", F=num_frames)
        hidden_state_norm = self.norm(hidden_state_norm)
        # Fold the spatial grid into the batch; attention then runs along frames.
        hidden_state_norm = rearrange(
            hidden_state_norm, "B C F H W -> (B H W) F C")
        hidden_state_norm = self.proj_in(hidden_state_norm)
        attn = self.attention(hidden_state_norm,
                              encoder_hidden_states=encoder_hidden_states,
                              attention_mask=None,
                              )
        # Project back and restore the (B*F, C, H, W) layout for the residual add.
        residual = self.proj_out(attn)
        residual = rearrange(
            residual, "(B H W) F C -> (B F) C H W", H=h, W=w)
        output = hidden_state + residual
        return TransformerTemporalModelOutput(sample=output)


class ConditionalModel(nn.Module):
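    """Fuses a conditioning video feature map into a sample feature map.

    The temporal module is chosen by name ("self_cross_transformer",
    "cross_transformer", "cross_attention", or "test_conv") and its output
    projection is zero-initialized so the branch is a no-op before training.
    """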

    def __init__(self, input_channels, conditional_model: str, attention_head_dim=64):
        super().__init__()
        num_layers = 1
        # A "<mode>_layers_<n>" suffix selects the number of transformer layers
        # (used by the "cross_transformer" mode only).
        if "_layers_" in conditional_model:
            config = conditional_model.split("_layers_")
            conditional_model = config[0]
            num_layers = int(config[1])

        if conditional_model == "self_cross_transformer":
            self.temporal_transformer = TransformerTemporalModel(
                num_attention_heads=input_channels // attention_head_dim,
                attention_head_dim=attention_head_dim,
                in_channels=input_channels,
                double_self_attention=False,
                cross_attention_dim=input_channels)
        elif conditional_model == "cross_transformer":
            self.temporal_transformer = TransformerTemporalModelCrossAttn(
                num_attention_heads=input_channels // attention_head_dim,
                attention_head_dim=attention_head_dim,
                in_channels=input_channels,
                double_self_attention=False,
                cross_attention_dim=input_channels,
                num_layers=num_layers)
        elif conditional_model == "cross_attention":
            self.temporal_transformer = CrossAttention(
                input_channels=input_channels, attention_head_dim=attention_head_dim)
        elif conditional_model == "test_conv":
            self.temporal_transformer = nn.Conv2d(
                input_channels, input_channels, kernel_size=1)
        else:
            raise NotImplementedError(
                f"mode {conditional_model} not implemented")

        # Zero-initialize the output projection so the conditioning branch
        # contributes nothing at the start of training (identity mapping).
        if conditional_model != "test_conv":
            nn.init.zeros_(self.temporal_transformer.proj_out.weight)
            nn.init.zeros_(self.temporal_transformer.proj_out.bias)
        else:
            nn.init.zeros_(self.temporal_transformer.weight)
            nn.init.zeros_(self.temporal_transformer.bias)
        self.conditional_model = conditional_model

    def forward(self, sample, conditioning, num_frames=None):
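        """Fuse `conditioning` into `sample`; both shaped (B, F, C, H, W)."""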

        assert conditioning.ndim == 5
        assert sample.ndim == 5

        if self.conditional_model != "test_conv":
            # Turn the conditioning into per-pixel token sequences over frames.
            conditioning = rearrange(conditioning, "B F C H W -> (B H W) F C")

            # num_frames is always derived from the input; the argument is
            # kept only for interface compatibility.
            num_frames = sample.shape[1]

            sample = rearrange(sample, "B F C H W -> (B F) C H W")
            sample = self.temporal_transformer(
                sample, encoder_hidden_states=conditioning, num_frames=num_frames).sample
            sample = rearrange(
                sample, "(B F) C H W -> B F C H W", F=num_frames)
        else:
            # Debug path: add a 1x1 convolution of the conditioning, per frame.
            conditioning = rearrange(conditioning, "B F C H W -> (B F) C H W")
            f = sample.shape[1]
            sample = rearrange(sample, "B F C H W -> (B F) C H W")
            sample = sample + self.temporal_transformer(conditioning)
            sample = rearrange(sample, "(B F) C H W -> B F C H W", F=f)
        return sample
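

# Minimal usage sketch (an illustrative assumption, not part of the original
# module): exercises the "cross_attention" mode with example sizes
# B=2, F=8, C=320, H=W=16. Because proj_out is zero-initialized, the branch
# contributes nothing yet and the output equals the input.
if __name__ == "__main__":
    model = ConditionalModel(input_channels=320,
                             conditional_model="cross_attention")
    sample = torch.randn(2, 8, 320, 16, 16)        # (B, F, C, H, W)
    conditioning = torch.randn(2, 8, 320, 16, 16)  # same layout as sample
    out = model(sample, conditioning)
    assert out.shape == sample.shape
    assert torch.allclose(out, sample)  # identity at initialization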