Update MT.py
MT.py CHANGED
@@ -131,7 +131,7 @@ class PositionEmbeddingSine(nn.Module):
         return pos
 
 
-def feature_add_position(feature0, feature_channels, scale=
+def feature_add_position(feature0, feature_channels, scale=0.5):
     temp = torch.mean(abs(feature0))
     pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)
     # position = PositionalEncodingPermute2D(feature_channels)(feature0)
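Only the default of `scale` changes in this hunk (now 0.5); the previous default is truncated in the rendered diff and is not recoverable from this page. For orientation, below is a minimal sketch of a DETR-style sine position embedding added to a [B, C, H, W] feature map; how `scale` and the mean absolute activation `temp` combine in MT.py is an assumption, and the helper names `sine_position_embedding` and `feature_add_position_sketch` are illustrative, not the module's actual API.

# Illustrative sketch only; the real PositionEmbeddingSine may differ in
# frequency scaling and normalization. Assumes feature_channels divisible by 4.
import torch


def sine_position_embedding(b, c, h, w, temperature=10000):
    num_pos_feats = c // 2                                   # half the channels encode y, half encode x
    y_embed = torch.arange(h, dtype=torch.float32).view(1, h, 1).expand(b, h, w)
    x_embed = torch.arange(w, dtype=torch.float32).view(1, 1, w).expand(b, h, w)
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
    pos_x = x_embed[..., None] / dim_t                       # [B, H, W, num_pos_feats]
    pos_y = y_embed[..., None] / dim_t
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=4).flatten(3)
    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=4).flatten(3)
    return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)   # [B, C, H, W]


def feature_add_position_sketch(feature0, feature_channels, scale=0.5):
    # Weight the embedding by the mean |activation| so it stays on the feature's scale
    # (assumption about how temp is used downstream of the lines shown above).
    b, _, h, w = feature0.shape
    temp = torch.mean(abs(feature0))
    position = sine_position_embedding(b, feature_channels, h, w)
    return feature0 + scale * temp * position

For example, `feature_add_position_sketch(torch.rand(1, 16, 8, 8), 16)` returns a [1, 16, 8, 8] tensor with the position signal mixed in.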
@@ -223,8 +223,6 @@ class TransformerLayer(nn.Module):
         att = feature_add_position(att.transpose(-1, -2).view(
             B, C, shape[0], shape[1]), C).reshape(B, C, -1).transpose(-1, -2)
 
-        # att = feature_add_position(att.transpose(-1, -2).view(
-        #     B, C, shape[0], shape[1]), C).reshape(B, C, -1).transpose(-1, -2)
         val_proj = self.v_proj(value)
         att_proj = self.att_proj(att)  # [B, L, C]
         norm_fac = torch.sum(att_proj ** 2, dim=-1, keepdim=True) ** 0.5
@@ -237,7 +235,6 @@ class TransformerLayer(nn.Module):
         D = 1 / (torch.sqrt(D) + 1e-6)  # normalized node degrees
         A = D * A * D.transpose(-1, -2)
 
-        # A = torch.softmax(A , dim=2)  # [B, L, L]
         message = torch.matmul(A, val_proj)  # [B, L, C]
 
         message = self.merge(message)  # [B, L, C]
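This hunk keeps the symmetric degree normalization of the affinity matrix and drops the commented-out softmax alternative, which would have row-normalized A instead. A self-contained sketch of the kept path, assuming A is a non-negative [B, L, L] affinity matrix and D its row-sum degree (the computation of D sits above this hunk and is an assumption here):

# Sketch of the retained normalization path; shapes follow the comments in the hunk.
import torch

B, L, C = 2, 6, 8
A = torch.rand(B, L, L)                    # pairwise affinities between the L = H*W tokens
val_proj = torch.rand(B, L, C)             # projected values, [B, L, C]

D = A.sum(dim=-1, keepdim=True)            # node degrees, [B, L, 1] (assumed row sum)
D = 1 / (torch.sqrt(D) + 1e-6)             # D^{-1/2}; the epsilon keeps near-isolated nodes finite
A_norm = D * A * D.transpose(-1, -2)       # D^{-1/2} A D^{-1/2}, GCN-style symmetric normalization
message = torch.matmul(A_norm, val_proj)   # aggregate values over the graph, [B, L, C]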
@@ -246,9 +243,6 @@ class TransformerLayer(nn.Module):
         message = self.mlp(torch.cat([value, message], dim=-1))
         message = self.norm2(message)
 
-        # if iteration > 2:
-        #     message = self.drop(message)
-
         att = self.attn_updater(att, message, shape)
         value = self.gru(value, message, shape)
         return value, att, A
@@ -290,14 +284,11 @@ class FeatureTransformer(nn.Module):
         att = att.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
         for i in range(self.num_layers):
             value, att, attn_viz = self.layers(att=att, value=value, shape=[h, w], iteration=i)
-
-
-
-
-
-            attn_list.append(att.view(b, h, w, c).permute(0, 3, 1, 2).contiguous())
-            feature_list.append(value_decode.view(b, h, w, c).permute(0, 3, 1, 2).contiguous())
-        # reshape back
+            value_decode = self.normalize(torch.square(self.re_proj(value)))  # map to motion energy, do use normalization here
+
+            attn_viz_list.append(attn_viz.reshape(b, h, w, h, w))
+            attn_list.append(att.view(b, h, w, c).permute(0, 3, 1, 2).contiguous())
+            feature_list.append(value_decode.view(b, h, w, c).permute(0, 3, 1, 2).contiguous())
         return feature_list, attn_list, attn_viz_list
 
     def forward_save_mem(self, feature0, add_position_embedding=True):
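The added lines decode a feature map on every layer iteration rather than only once after the loop, and also collect the per-layer attention for visualization (attn_viz is presumably [B, H*W, H*W], reshaped to [B, H, W, H, W]). Below is a sketch of the shape bookkeeping only, with nn.Linear and a simple sum-normalization standing in for the module's re_proj and normalize; both stand-ins are assumptions about their types, not the repository's definitions.

# Shape bookkeeping for the per-iteration decode added above.
import torch
import torch.nn as nn

b, h, w, c = 1, 4, 5, 16
value = torch.rand(b, h * w, c)            # [B, H*W, C] token features from the transformer layer

re_proj = nn.Linear(c, c)                  # stand-in for self.re_proj


def normalize(x):                          # stand-in for self.normalize
    return x / (x.sum(dim=-1, keepdim=True) + 1e-6)


value_decode = normalize(torch.square(re_proj(value)))                        # squaring gives an energy-like, non-negative response
feature_map = value_decode.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()  # back to [B, C, H, W]
print(feature_map.shape)                   # torch.Size([1, 16, 4, 5])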