jadechoghari
commited on
Update mar.py
Browse files
mar.py
CHANGED
@@ -12,7 +12,7 @@ from timm.models.vision_transformer import Block
|
|
12 |
|
13 |
from .diffloss import DiffLoss
|
14 |
|
15 |
-
|
16 |
def mask_by_order(mask_len, order, bsz, seq_len):
|
17 |
masking = torch.zeros(bsz, seq_len).to(device)
|
18 |
masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()], src=torch.ones(bsz, seq_len).to(device)).bool()
|
|
|
12 |
|
13 |
from .diffloss import DiffLoss
|
14 |
|
15 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
16 |
def mask_by_order(mask_len, order, bsz, seq_len):
|
17 |
masking = torch.zeros(bsz, seq_len).to(device)
|
18 |
masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()], src=torch.ones(bsz, seq_len).to(device)).bool()
|