jadechoghari commited on
Commit
f427d24
·
verified ·
1 Parent(s): 5f5ff73

Update mar.py

Browse files
Files changed (1) hide show
  1. mar.py +1 -1
mar.py CHANGED
@@ -12,7 +12,7 @@ from timm.models.vision_transformer import Block
12
 
13
  from .diffloss import DiffLoss
14
 
15
-
16
  def mask_by_order(mask_len, order, bsz, seq_len):
17
  masking = torch.zeros(bsz, seq_len).to(device)
18
  masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()], src=torch.ones(bsz, seq_len).to(device)).bool()
 
12
 
13
  from .diffloss import DiffLoss
14
 
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  def mask_by_order(mask_len, order, bsz, seq_len):
17
  masking = torch.zeros(bsz, seq_len).to(device)
18
  masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()], src=torch.ones(bsz, seq_len).to(device)).bool()