HarborYuan
commited on
Commit
•
c7fd587
1
Parent(s):
16d167b
bugfix
Browse files
app/models/heads/mask2former_vid.py
CHANGED
@@ -183,14 +183,15 @@ class Mask2FormerVideoHead(AnchorFreeHead):
|
|
183 |
_dim = cls_embed.size(2)
|
184 |
_prototypes = cls_embed.size(1)
|
185 |
|
186 |
-
if rank == 0:
|
187 |
-
|
188 |
-
|
189 |
-
else:
|
190 |
-
|
191 |
-
if world_size > 1:
|
192 |
-
|
193 |
-
back_token =
|
|
|
194 |
cls_embed = torch.cat([
|
195 |
cls_embed, back_token.repeat(_prototypes, 1)[None]
|
196 |
], dim=0)
|
|
|
183 |
_dim = cls_embed.size(2)
|
184 |
_prototypes = cls_embed.size(1)
|
185 |
|
186 |
+
# if rank == 0:
|
187 |
+
# back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cuda')
|
188 |
+
# # back_token = back_token / back_token.norm(p=2, dim=-1, keepdim=True)
|
189 |
+
# else:
|
190 |
+
# back_token = torch.empty(1, _dim, dtype=torch.float32, device='cuda')
|
191 |
+
# if world_size > 1:
|
192 |
+
# dist.broadcast(back_token, src=0)
|
193 |
+
back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cuda')
|
194 |
+
# back_token = back_token.to(device='cpu')
|
195 |
cls_embed = torch.cat([
|
196 |
cls_embed, back_token.repeat(_prototypes, 1)[None]
|
197 |
], dim=0)
|
app/models/heads/yoso_head.py
CHANGED
@@ -369,13 +369,15 @@ class CrossAttenHead(nn.Module):
|
|
369 |
# background class
|
370 |
_dim = cls_embed.size(2)
|
371 |
_prototypes = cls_embed.size(1)
|
372 |
-
if rank == 0:
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
back_token =
|
|
|
|
|
379 |
cls_embed = torch.cat([
|
380 |
cls_embed, back_token.repeat(_prototypes, 1)[None]
|
381 |
], dim=0)
|
|
|
369 |
# background class
|
370 |
_dim = cls_embed.size(2)
|
371 |
_prototypes = cls_embed.size(1)
|
372 |
+
# if rank == 0:
|
373 |
+
# back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cuda')
|
374 |
+
# # back_token = back_token / back_token.norm(p=2, dim=-1, keepdim=True)
|
375 |
+
# else:
|
376 |
+
# back_token = torch.empty(1, _dim, dtype=torch.float32, device='cuda')
|
377 |
+
# if world_size > 1:
|
378 |
+
# dist.broadcast(back_token, src=0)
|
379 |
+
back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cuda')
|
380 |
+
# back_token = back_token.to(device='cpu')
|
381 |
cls_embed = torch.cat([
|
382 |
cls_embed, back_token.repeat(_prototypes, 1)[None]
|
383 |
], dim=0)
|