"""
Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0
Copyright Zhun Zhong & Liang Zheng

Hacked together by / Copyright 2020 Ross Wightman

Modified by Hangbo Bao, for generating the masked position for visual image transformer
"""
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, DINO and DeiT code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0
# Copyright Zhun Zhong & Liang Zheng
#
# Hacked together by / Copyright 2020 Ross Wightman
#
# Modified by Hangbo Bao, for generating the masked position for visual image transformer
# --------------------------------------------------------'
import random
import math
import numpy as np


class MaskingGenerator:
    """Generate block-wise binary masks over a 2-D grid of patches.

    Calling an instance returns an ``int32`` array of shape
    ``(height, width)`` in which exactly ``num_masking_patches`` entries are
    1 and the rest are 0.  The mask is built by stamping random rectangles
    onto the grid; if the rectangles do not reach the requested count,
    individual cells are flipped at random to hit it exactly.

    Randomness comes from both the stdlib ``random`` module (rectangle
    sampling) and ``numpy.random`` (final count adjustment), so both must be
    seeded for reproducible masks.
    """

    def __init__(
            self, input_size, num_masking_patches, min_num_patches=4, max_num_patches=None,
            min_aspect=0.3, max_aspect=None):
        """Store the grid geometry and the rectangle sampling ranges.

        Args:
            input_size: Grid size.  Either a single number (square grid) or a
                ``(height, width)`` tuple/list.
            num_masking_patches: Exact number of cells set to 1 in each mask.
            min_num_patches: Lower bound of the sampled rectangle area.
            max_num_patches: Upper bound of the sampled rectangle area;
                defaults to ``num_masking_patches`` when ``None``.
            min_aspect: Smallest height/width ratio of a sampled rectangle.
            max_aspect: Largest height/width ratio; defaults to
                ``1 / min_aspect`` when falsy.
        """
        # Lists are accepted alongside tuples; a scalar means a square grid.
        if not isinstance(input_size, (tuple, list)):
            input_size = (input_size, ) * 2
        self.height, self.width = input_size

        self.num_patches = self.height * self.width
        self.num_masking_patches = num_masking_patches

        self.min_num_patches = min_num_patches
        self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches

        max_aspect = max_aspect or 1 / min_aspect
        # Kept in log space so that aspect ratios are drawn log-uniformly.
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))

    def __repr__(self):
        """Return a one-line summary of the generator's configuration."""
        repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
            self.height, self.width, self.min_num_patches, self.max_num_patches,
            self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1])
        return repr_str

    def get_shape(self):
        """Return the mask shape as ``(height, width)``."""
        return self.height, self.width

    def _mask(self, mask, max_mask_patches):
        """Try to stamp one random rectangle onto ``mask`` in place.

        Up to 10 rectangles are sampled.  The first one that fits strictly
        inside the grid and adds between 1 and ``max_mask_patches`` new
        cells is applied.

        Args:
            mask: 2-D array of shape ``(height, width)``; modified in place.
            max_mask_patches: Maximum number of new cells this call may mask.

        Returns:
            Number of cells switched from 0 to 1 (0 if every attempt failed).
        """
        for _ in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            # Rectangles spanning a full row or column are rejected.
            if w >= self.width or h >= self.height:
                continue

            top = random.randint(0, self.height - h)
            left = random.randint(0, self.width - w)

            # Basic slicing gives a view, so writes below land in ``mask``.
            region = mask[top: top + h, left: left + w]
            num_masked = region.sum()
            # Accept only if the overlap leaves something new to mask and the
            # newly covered area stays within the remaining budget.
            if 0 < h * w - num_masked <= max_mask_patches:
                unmasked = region == 0
                delta = int(unmasked.sum())
                region[unmasked] = 1
                if delta > 0:
                    return delta
        return 0

    def __call__(self):
        """Generate a new mask with exactly ``num_masking_patches`` ones.

        Returns:
            ``int32`` array of shape ``(height, width)`` containing 0s and 1s.

        Raises:
            ValueError: If ``num_masking_patches`` is negative or exceeds the
                number of cells in the grid.
        """
        if not 0 <= self.num_masking_patches <= self.num_patches:
            # Fail with a clear message instead of an opaque sampling error
            # from numpy further down.
            raise ValueError(
                "num_masking_patches must be between 0 and %d, got %r"
                % (self.num_patches, self.num_masking_patches))

        mask = np.zeros(shape=self.get_shape(), dtype=np.int32)
        mask_count = 0
        while mask_count < self.num_masking_patches:
            max_mask_patches = self.num_masking_patches - mask_count
            max_mask_patches = min(max_mask_patches, self.max_num_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                # No rectangle could be placed; fall through to the
                # cell-level adjustment below.
                break
            mask_count += delta

        # Force the count to exactly num_masking_patches by flipping randomly
        # chosen individual cells.
        if mask_count > self.num_masking_patches:
            delta = mask_count - self.num_masking_patches
            mask_x, mask_y = mask.nonzero()
            to_vis = np.random.choice(mask_x.shape[0], delta, replace=False)
            mask[mask_x[to_vis], mask_y[to_vis]] = 0
        elif mask_count < self.num_masking_patches:
            delta = self.num_masking_patches - mask_count
            mask_x, mask_y = (mask == 0).nonzero()
            to_mask = np.random.choice(mask_x.shape[0], delta, replace=False)
            mask[mask_x[to_mask], mask_y[to_mask]] = 1

        assert mask.sum() == self.num_masking_patches, f"mask: {mask}, mask count {mask.sum()}"

        return mask


if __name__ == '__main__':
    # Stress test: drop into the debugger if a mask ever has the wrong count.
    import pdb
    generator = MaskingGenerator(input_size=14, num_masking_patches=118, min_num_patches=16,)
    for i in range(10000000):
        mask = generator()
        if mask.sum() != 118:
            pdb.set_trace()
            print(mask)
            print(mask.sum())