import torch.nn as nn
from .net_utils import single_conv, double_conv, double_conv_down, double_conv_up, PosEnSine
from .transformer_basics import OurMultiheadAttention
class TransformerDecoderUnit(nn.Module):
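    """Cross-attention transformer decoder block: multi-head attention over (q, k, v),
    followed by a 1x1-convolution feed-forward layer, a residual connection and BatchNorm."""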
def __init__(self, feat_dim, n_head=8, pos_en_flag=True, attn_type='softmax', P=None):
super(TransformerDecoderUnit, self).__init__()
self.feat_dim = feat_dim
self.attn_type = attn_type
self.pos_en_flag = pos_en_flag
self.P = P
self.pos_en = PosEnSine(self.feat_dim // 2)
self.attn = OurMultiheadAttention(feat_dim, n_head) # cross-attention
self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1)
self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1)
self.activation = nn.ReLU(inplace=True)
self.norm = nn.BatchNorm2d(self.feat_dim)
def forward(self, q, k, v):
if self.pos_en_flag:
q_pos_embed = self.pos_en(q)
k_pos_embed = self.pos_en(k)
else:
q_pos_embed = 0
k_pos_embed = 0
        # multi-head cross-attention
out = self.attn(
q=q + q_pos_embed, k=k + k_pos_embed, v=v, attn_type=self.attn_type, P=self.P
)[0]
# feed forward
out2 = self.linear2(self.activation(self.linear1(out)))
out = out + out2
out = self.norm(out)
return out
class Unet(nn.Module):
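    """Small U-Net with additive skip connections; forward returns the intermediate
    features at resolutions H, H/2, H/4, H/4, H/2 plus the final H-resolution output."""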
def __init__(self, in_ch, feat_ch, out_ch):
super().__init__()
self.conv_in = single_conv(in_ch, feat_ch)
self.conv1 = double_conv_down(feat_ch, feat_ch)
self.conv2 = double_conv_down(feat_ch, feat_ch)
self.conv3 = double_conv(feat_ch, feat_ch)
self.conv4 = double_conv_up(feat_ch, feat_ch)
self.conv5 = double_conv_up(feat_ch, feat_ch)
self.conv6 = double_conv(feat_ch, out_ch)
def forward(self, x):
feat0 = self.conv_in(x) # H
feat1 = self.conv1(feat0) # H/2
feat2 = self.conv2(feat1) # H/4
feat3 = self.conv3(feat2) # H/4
feat3 = feat3 + feat2 # H/4
feat4 = self.conv4(feat3) # H/2
feat4 = feat4 + feat1 # H/2
feat5 = self.conv5(feat4) # H
feat5 = feat5 + feat0 # H
feat6 = self.conv6(feat5)
return feat0, feat1, feat2, feat3, feat4, feat6
class Texformer(nn.Module):
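    """Texture transformer: separate U-Nets encode the query, key and value inputs,
    cross-attention (trans_dec) fuses them at three scales, and a convolutional decoder
    with skip connections produces either a direct out_ch-channel output or, with
    mask_fusion, a (flow, RGB, mask) triple for mask-weighted fusion."""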
def __init__(self, opts):
super().__init__()
self.feat_dim = opts.feat_dim
src_ch = opts.src_ch
tgt_ch = opts.tgt_ch
out_ch = opts.out_ch
self.mask_fusion = opts.mask_fusion
if not self.mask_fusion:
v_ch = out_ch
else:
            v_ch = 2 + 3  # 2-channel flow + 3-channel RGB for mask fusion
self.unet_q = Unet(tgt_ch, self.feat_dim, self.feat_dim)
self.unet_k = Unet(src_ch, self.feat_dim, self.feat_dim)
self.unet_v = Unet(v_ch, self.feat_dim, self.feat_dim)
self.trans_dec = nn.ModuleList(
[
None, None, None,
TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'softmax'),
TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'dotproduct'),
TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'dotproduct')
]
)
self.conv0 = double_conv(self.feat_dim, self.feat_dim)
self.conv1 = double_conv_down(self.feat_dim, self.feat_dim)
self.conv2 = double_conv_down(self.feat_dim, self.feat_dim)
self.conv3 = double_conv(self.feat_dim, self.feat_dim)
self.conv4 = double_conv_up(self.feat_dim, self.feat_dim)
self.conv5 = double_conv_up(self.feat_dim, self.feat_dim)
if not self.mask_fusion:
self.conv6 = nn.Sequential(
single_conv(self.feat_dim, self.feat_dim),
nn.Conv2d(self.feat_dim, out_ch, 3, 1, 1)
)
else:
self.conv6 = nn.Sequential(
single_conv(self.feat_dim, self.feat_dim),
nn.Conv2d(self.feat_dim, 2 + 3 + 1, 3, 1, 1)
            )  # 2-ch flow + 3-ch RGB + 1-ch mask: out = mask * flow-sampled + (1 - mask) * rgb
self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()
def forward(self, q, k, v):
print('qkv', q.shape, k.shape, v.shape)
q_feat = self.unet_q(q)
k_feat = self.unet_k(k)
v_feat = self.unet_v(v)
print('q_feat', len(q_feat))
outputs = []
for i in range(3, len(q_feat)):
print(i, q_feat[i].shape, k_feat[i].shape, v_feat[i].shape)
outputs.append(self.trans_dec[i](q_feat[i], k_feat[i], v_feat[i]))
print('outputs', outputs[-1].shape)
f0 = self.conv0(outputs[2]) # H
f1 = self.conv1(f0) # H/2
f1 = f1 + outputs[1]
f2 = self.conv2(f1) # H/4
f2 = f2 + outputs[0]
f3 = self.conv3(f2) # H/4
f3 = f3 + outputs[0] + f2
f4 = self.conv4(f3) # H/2
f4 = f4 + outputs[1] + f1
f5 = self.conv5(f4) # H
f5 = f5 + outputs[2] + f0
if not self.mask_fusion:
out = self.tanh(self.conv6(f5))
else:
out_ = self.conv6(f5)
out = [self.tanh(out_[:, :2]), self.tanh(out_[:, 2:5]), self.sigmoid(out_[:, 5:])]
return out
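

# ---------------------------------------------------------------------------
# Minimal usage sketch. The `opts` fields below (feat_dim, src_ch, tgt_ch,
# out_ch, mask_fusion, nhead) mirror what Texformer reads above; the channel
# counts and the 64x64 spatial size are illustrative assumptions, not the
# values used in training. Because of the relative imports, run this as a
# module inside its package (python -m <package>.<module>), not as a script.
if __name__ == '__main__':
    from types import SimpleNamespace
    import torch

    opts = SimpleNamespace(feat_dim=64, src_ch=3, tgt_ch=3, out_ch=3,
                           mask_fusion=True, nhead=8)
    model = Texformer(opts)

    q = torch.randn(1, opts.tgt_ch, 64, 64)  # query: target-space input
    k = torch.randn(1, opts.src_ch, 64, 64)  # key: source image
    v = torch.randn(1, 2 + 3, 64, 64)        # value: 2-ch flow + 3-ch RGB

    with torch.no_grad():
        flow, rgb, mask = model(q, k, v)     # mask_fusion -> [flow, rgb, mask]
    print(flow.shape, rgb.shape, mask.shape)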