HarborYuan committed
Commit c7fd587
1 parent: 16d167b
app/models/heads/mask2former_vid.py CHANGED
```diff
@@ -183,14 +183,15 @@ class Mask2FormerVideoHead(AnchorFreeHead):
         _dim = cls_embed.size(2)
         _prototypes = cls_embed.size(1)
 
-        if rank == 0:
-            back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cuda')
-            # back_token = back_token / back_token.norm(p=2, dim=-1, keepdim=True)
-        else:
-            back_token = torch.empty(1, _dim, dtype=torch.float32, device='cuda')
-        if world_size > 1:
-            dist.broadcast(back_token, src=0)
-        back_token = back_token.to(device='cpu')
+        # if rank == 0:
+        #     back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cuda')
+        #     # back_token = back_token / back_token.norm(p=2, dim=-1, keepdim=True)
+        # else:
+        #     back_token = torch.empty(1, _dim, dtype=torch.float32, device='cuda')
+        # if world_size > 1:
+        #     dist.broadcast(back_token, src=0)
+        back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cuda')
+        # back_token = back_token.to(device='cpu')
         cls_embed = torch.cat([
             cls_embed, back_token.repeat(_prototypes, 1)[None]
         ], dim=0)
```
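Both files get the same change: the rank-0-creates / everyone-receives broadcast idiom for the background class token is commented out in favor of constructing the token locally. For context, a minimal sketch of the removed pattern (the function name and signature are ours, not the repository's; it assumes `torch.distributed` is initialized and CUDA is available):

```python
import torch
import torch.distributed as dist

def make_back_token(_dim: int, rank: int, world_size: int) -> torch.Tensor:
    """The broadcast idiom this commit removes (sketch, not repo code)."""
    if rank == 0:
        # Only rank 0 materializes the real background token.
        back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cuda')
    else:
        # Other ranks allocate an uninitialized buffer of the same shape.
        back_token = torch.empty(1, _dim, dtype=torch.float32, device='cuda')
    if world_size > 1:
        # Overwrite every rank's buffer with rank 0's tensor.
        dist.broadcast(back_token, src=0)
    return back_token.to(device='cpu')
```

Since the token is all zeros, every rank computes an identical value anyway, so the collective (and the final CUDA-to-CPU copy) adds nothing here; it would only matter if the token were learned or randomly initialized.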
app/models/heads/yoso_head.py CHANGED
```diff
@@ -369,13 +369,15 @@ class CrossAttenHead(nn.Module):
         # background class
         _dim = cls_embed.size(2)
         _prototypes = cls_embed.size(1)
-        if rank == 0:
-            back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cuda')
-        else:
-            back_token = torch.empty(1, _dim, dtype=torch.float32, device='cuda')
-        if world_size > 1:
-            dist.broadcast(back_token, src=0)
-        back_token = back_token.to(device='cpu')
+        # if rank == 0:
+        #     back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cuda')
+        #     # back_token = back_token / back_token.norm(p=2, dim=-1, keepdim=True)
+        # else:
+        #     back_token = torch.empty(1, _dim, dtype=torch.float32, device='cuda')
+        # if world_size > 1:
+        #     dist.broadcast(back_token, src=0)
+        back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cuda')
+        # back_token = back_token.to(device='cpu')
         cls_embed = torch.cat([
             cls_embed, back_token.repeat(_prototypes, 1)[None]
         ], dim=0)
```
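After the commit, both heads build the token locally, so the code no longer needs an initialized process group, which is presumably the point for running the head in a single-process demo app. A sketch of the resulting logic under assumptions inferred from the diff (`append_background_class` is a hypothetical helper name; `cls_embed` is taken to be `[num_classes, num_prototypes, dim]`; the diff hard-codes `device='cuda'`, while this sketch follows `cls_embed.device` so it also runs on CPU):

```python
import torch

def append_background_class(cls_embed: torch.Tensor) -> torch.Tensor:
    """Append a zero background token as one extra class entry (sketch)."""
    _dim = cls_embed.size(2)
    _prototypes = cls_embed.size(1)
    # A zero vector is deterministic, so each process can build it locally;
    # no rank-0 broadcast is required.
    back_token = torch.zeros(1, _dim, dtype=torch.float32, device=cls_embed.device)
    # (1, dim) -> (prototypes, dim) -> (1, prototypes, dim), then append on dim 0.
    return torch.cat([cls_embed, back_token.repeat(_prototypes, 1)[None]], dim=0)

cls_embed = torch.randn(80, 4, 256)  # e.g. 80 classes, 4 prototypes, 256-dim
out = append_background_class(cls_embed)
assert out.shape == (81, 4, 256)     # background appended as the last class
```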