import torch
import torch.nn as nn
import random


class MaskedDrop(nn.Module):
    """Randomly drops a subset of image tokens during training (used as a "masked_drop" resampler)."""

    def __init__(self, model_args):
        super().__init__()

        self.mode = model_args.mm_mask_drop_mode
        self.skip_percentage = model_args.mm_mask_drop_skip_percentage
        self.ratio = model_args.mm_mask_drop_ratio
        self.ratio_upper = model_args.mm_mask_drop_ratio_upper
        self.ratio_lower = model_args.mm_mask_drop_ratio_lower
    def forward(self, image_features, *args, **kwargs):
        # Token dropping is a training-time regularizer; pass features through unchanged at eval time.
        if not self.training:
            return image_features

        # Optionally skip masking for this batch with probability `skip_percentage`.
        if self.skip_percentage > random.random():
            return image_features

        masked_features = []

        for image_feature in image_features:
            num_tokens = image_feature.shape[0]
            if self.mode == "fixed":
                # Keep a fixed fraction of tokens per image.
                num_keep = int(num_tokens * self.ratio)
                masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0])
            elif self.mode == "range":
                # Keep a fraction sampled uniformly from [ratio_lower, ratio_upper].
                num_keep = int(num_tokens * random.uniform(self.ratio_lower, self.ratio_upper))
                masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0])
            elif self.mode == "cls_only":
                # Keep only the first (CLS) token.
                masked_features.append(image_feature[0:1])
            else:
                raise ValueError(f"Unexpected masked drop mode: {self.mode}")

        # In "range" mode the kept lengths differ per image, so the results cannot be stacked.
        if self.mode not in ["range"] and (type(image_features) is not list or self.mode in ["cls_only"]):
            masked_features = torch.stack(masked_features, dim=0)

        return masked_features
    @property
    def config(self):
        return {
            "mm_resampler_type": "masked_drop",
            "mm_mask_drop_mode": self.mode,
            "mm_mask_drop_skip_percentage": self.skip_percentage,
            "mm_mask_drop_ratio": self.ratio,
            "mm_mask_drop_ratio_upper": self.ratio_upper,
            "mm_mask_drop_ratio_lower": self.ratio_lower,
        }
    def random_masking(self, x, len_keep):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore
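

# --- Usage sketch (not part of the original file) ---
# A minimal, illustrative example of how MaskedDrop might be driven. `SimpleNamespace`
# stands in for the real `model_args` object; the mm_mask_drop_* field names match what
# __init__ reads above, but the concrete values and tensor sizes here are assumptions.
if __name__ == "__main__":
    from types import SimpleNamespace

    model_args = SimpleNamespace(
        mm_mask_drop_mode="fixed",         # "fixed", "range", or "cls_only"
        mm_mask_drop_skip_percentage=0.0,  # chance of skipping masking for a batch
        mm_mask_drop_ratio=0.25,           # fraction of tokens to keep in "fixed" mode
        mm_mask_drop_ratio_upper=0.5,      # upper keep ratio for "range" mode
        mm_mask_drop_ratio_lower=0.1,      # lower keep ratio for "range" mode
    )

    resampler = MaskedDrop(model_args)
    resampler.train()  # masking is only applied in training mode

    # Two "images" of 576 tokens with 1024-dim features each (illustrative sizes).
    image_features = torch.randn(2, 576, 1024)
    masked = resampler(image_features)
    print(masked.shape)  # torch.Size([2, 144, 1024]) with ratio=0.25

    # Calling random_masking directly: keep 4 of 8 tokens per sample.
    x_masked, mask, ids_restore = resampler.random_masking(torch.randn(1, 8, 16), len_keep=4)
    print(x_masked.shape, int(mask.sum()))  # torch.Size([1, 4, 16]) 4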