File size: 5,268 Bytes
2252f3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import torch.nn as nn
from .net_utils import single_conv, double_conv, double_conv_down, double_conv_up, PosEnSine
from .transformer_basics import OurMultiheadAttention


class TransformerDecoderUnit(nn.Module):
    def __init__(self, feat_dim, n_head=8, pos_en_flag=True, attn_type='softmax', P=None):
        super(TransformerDecoderUnit, self).__init__()
        self.feat_dim = feat_dim
        self.attn_type = attn_type
        self.pos_en_flag = pos_en_flag
        self.P = P

        self.pos_en = PosEnSine(self.feat_dim // 2)
        self.attn = OurMultiheadAttention(feat_dim, n_head)    # cross-attention

        self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1)
        self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1)
        self.activation = nn.ReLU(inplace=True)

        self.norm = nn.BatchNorm2d(self.feat_dim)

    def forward(self, q, k, v):
        if self.pos_en_flag:
            q_pos_embed = self.pos_en(q)
            k_pos_embed = self.pos_en(k)
        else:
            q_pos_embed = 0
            k_pos_embed = 0

        # cross-multi-head attention
        out = self.attn(
            q=q + q_pos_embed, k=k + k_pos_embed, v=v, attn_type=self.attn_type, P=self.P
        )[0]

        # feed forward
        out2 = self.linear2(self.activation(self.linear1(out)))
        out = out + out2
        out = self.norm(out)

        return out


class Unet(nn.Module):
    def __init__(self, in_ch, feat_ch, out_ch):
        super().__init__()
        self.conv_in = single_conv(in_ch, feat_ch)

        self.conv1 = double_conv_down(feat_ch, feat_ch)
        self.conv2 = double_conv_down(feat_ch, feat_ch)
        self.conv3 = double_conv(feat_ch, feat_ch)
        self.conv4 = double_conv_up(feat_ch, feat_ch)
        self.conv5 = double_conv_up(feat_ch, feat_ch)
        self.conv6 = double_conv(feat_ch, out_ch)

    def forward(self, x):
        feat0 = self.conv_in(x)    # H
        feat1 = self.conv1(feat0)    # H/2
        feat2 = self.conv2(feat1)    # H/4
        feat3 = self.conv3(feat2)    # H/4
        feat3 = feat3 + feat2    # H/4
        feat4 = self.conv4(feat3)    # H/2
        feat4 = feat4 + feat1    # H/2
        feat5 = self.conv5(feat4)    # H
        feat5 = feat5 + feat0    # H
        feat6 = self.conv6(feat5)

        return feat0, feat1, feat2, feat3, feat4, feat6


class Texformer(nn.Module):
    def __init__(self, opts):
        super().__init__()
        self.feat_dim = opts.feat_dim
        src_ch = opts.src_ch
        tgt_ch = opts.tgt_ch
        out_ch = opts.out_ch
        self.mask_fusion = opts.mask_fusion

        if not self.mask_fusion:
            v_ch = out_ch
        else:
            v_ch = 2 + 3

        self.unet_q = Unet(tgt_ch, self.feat_dim, self.feat_dim)
        self.unet_k = Unet(src_ch, self.feat_dim, self.feat_dim)
        self.unet_v = Unet(v_ch, self.feat_dim, self.feat_dim)

        self.trans_dec = nn.ModuleList(
            [
                None, None, None,
                TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'softmax'),
                TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'dotproduct'),
                TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'dotproduct')
            ]
        )

        self.conv0 = double_conv(self.feat_dim, self.feat_dim)
        self.conv1 = double_conv_down(self.feat_dim, self.feat_dim)
        self.conv2 = double_conv_down(self.feat_dim, self.feat_dim)
        self.conv3 = double_conv(self.feat_dim, self.feat_dim)
        self.conv4 = double_conv_up(self.feat_dim, self.feat_dim)
        self.conv5 = double_conv_up(self.feat_dim, self.feat_dim)

        if not self.mask_fusion:
            self.conv6 = nn.Sequential(
                single_conv(self.feat_dim, self.feat_dim),
                nn.Conv2d(self.feat_dim, out_ch, 3, 1, 1)
            )
        else:
            self.conv6 = nn.Sequential(
                single_conv(self.feat_dim, self.feat_dim),
                nn.Conv2d(self.feat_dim, 2 + 3 + 1, 3, 1, 1)
            )    # mask*flow-sampling + (1-mask)*rgb
            self.sigmoid = nn.Sigmoid()

        self.tanh = nn.Tanh()

    def forward(self, q, k, v):
        print('qkv', q.shape, k.shape, v.shape)
        q_feat = self.unet_q(q)
        k_feat = self.unet_k(k)
        v_feat = self.unet_v(v)

        print('q_feat', len(q_feat))
        outputs = []
        for i in range(3, len(q_feat)):
            print(i, q_feat[i].shape, k_feat[i].shape, v_feat[i].shape)
            outputs.append(self.trans_dec[i](q_feat[i], k_feat[i], v_feat[i]))
            print('outputs', outputs[-1].shape)

        f0 = self.conv0(outputs[2])    # H
        f1 = self.conv1(f0)    # H/2
        f1 = f1 + outputs[1]
        f2 = self.conv2(f1)    # H/4
        f2 = f2 + outputs[0]
        f3 = self.conv3(f2)    # H/4
        f3 = f3 + outputs[0] + f2
        f4 = self.conv4(f3)    # H/2
        f4 = f4 + outputs[1] + f1
        f5 = self.conv5(f4)    # H
        f5 = f5 + outputs[2] + f0
        if not self.mask_fusion:
            out = self.tanh(self.conv6(f5))
        else:
            out_ = self.conv6(f5)
            out = [self.tanh(out_[:, :2]), self.tanh(out_[:, 2:5]), self.sigmoid(out_[:, 5:])]
        return out