# --------------------------------------------------------
# SimMIM
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Zhenda Xie
# --------------------------------------------------------
import math
import random

import numpy as np
import torch
import torch.distributed as dist
import torchvision.transforms as T
from torch.utils.data import DataLoader, DistributedSampler, Dataset, BatchSampler
from torch.utils.data._utils.collate import default_collate
from torchvision.datasets import ImageFolder
from torchvision.io import read_image
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

from .cached_image_folder import CachedImageFolder


class MultiTaskDataset(Dataset):
    """
    Wraps several datasets so that a sample is addressed by a (task_id, sample_id) pair.

    Usage example:
        train_datasets = [SemData_Single(), SemData_Single()]
        multi_task_train_dataset = MultiTaskDataset(train_datasets)
        multi_task_batch_sampler = MultiTaskBatchSampler(train_datasets, batch_size=4, mix_opt=0, extra_task_ratio=0, drop_last=True)
        multi_task_train_data = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler)
        for i, (task_id, input, target) in enumerate(multi_task_train_data):
            pre = model(input)
    """

    def __init__(self, datasets):
        self._datasets = datasets
        task_id_2_data_set_dic = {}
        for task_id, dataset in enumerate(datasets):
            assert task_id not in task_id_2_data_set_dic, "Duplicate task_id %s" % task_id
            task_id_2_data_set_dic[task_id] = dataset
        self._task_id_2_data_set_dic = task_id_2_data_set_dic

    def __len__(self):
        return sum(len(dataset) for dataset in self._datasets)

    def __getitem__(self, idx):
        # idx is a (task_id, sample_id) pair produced by the batch sampler
        task_id, sample_id = idx
        return self._task_id_2_data_set_dic[task_id][sample_id]


class DistrubutedMultiTaskBatchSampler(BatchSampler):
    """
    Distributed batch sampler that draws each batch from a single task (dataset) at a time.

    Args:
        datasets: list of Dataset objects, one per task.
        batch_size (int): number of samples per batch.
        mix_opt (int): 0 to shuffle batches across all tasks; 1 to shuffle only the extra tasks.
        extra_task_ratio (float, optional): sampling ratio between the first task and the extra tasks.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch of a dataset
            if its size is not divisible by the batch size. If ``False``, the last batch of a
            dataset may be smaller than ``batch_size``. (default: ``True``)
    """

    def __init__(self, datasets, batch_size, num_replicas=None, rank=None, mix_opt=0, extra_task_ratio=0, drop_last=True, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))
        self.num_replicas = num_replicas
        self.rank = rank
        assert mix_opt in [0, 1], 'mix_opt must be 0 or 1'
        assert extra_task_ratio >= 0, 'extra_task_ratio must be non-negative'
        self._datasets = datasets
        self._batch_size = batch_size
        self._mix_opt = mix_opt
        self._extra_task_ratio = extra_task_ratio
        self._drop_last = drop_last
        self.shuffle = shuffle
        train_data_list = []
        for dataset in datasets:
            print(len(dataset))
            train_data_list.append(self._get_index_batches(len(dataset), batch_size, self._drop_last))
        # One entry per dataset; each entry is a list of batches, and a batch is a list of sample indices.
        self._train_data_list = train_data_list
        self.total_len = sum(len(train_data) for train_data in self._train_data_list)
        ######### DDP ######################
        if self._drop_last and self.total_len % self.num_replicas != 0:  # type: ignore[arg-type]
            self.num_samples = math.ceil(
                (self.total_len - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(self.total_len / self.num_replicas)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas
        self.epoch = 0
        self.seed = 0

    def set_epoch(self, epoch):
        r"""
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch

    @staticmethod
    def _get_index_batches(dataset_len, batch_size, drop_last):
        # Split [0, dataset_len) into consecutive batches of sample indices.
        index = list(range(dataset_len))
        if drop_last and dataset_len % batch_size:
            del index[-(dataset_len % batch_size):]
        index_batches = [index[i:i + batch_size] for i in range(0, len(index), batch_size)]
        return index_batches
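
    # Illustrative: _get_index_batches(10, 4, drop_last=True) -> [[0, 1, 2, 3], [4, 5, 6, 7]]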

    def __len__(self):
        # Number of batches yielded per replica, not the total across replicas.
        return self.num_samples

    def __iter__(self):
        # Build the task order first; the shuffle is seeded with (seed + epoch) so that every
        # replica sees the same order and the shards below stay consistent across ranks.
        all_indices = self._gen_task_indices(self._train_data_list, self._mix_opt, self._extra_task_ratio)
        if self.shuffle:
            random.Random(self.seed + self.epoch).shuffle(all_indices)
        # Pair each task occurrence with its next batch *before* sharding, so that different
        # replicas consume disjoint batches of every dataset.
        all_iters = [iter(item) for item in self._train_data_list]
        all_batches = [(task_idx, next(all_iters[task_idx])) for task_idx in all_indices]
        ######### DDP ######################
        all_batches = all_batches[self.rank:self.total_size:self.num_replicas]
        assert len(all_batches) == self.num_samples
        for task_idx, batch in all_batches:
            # Yield one batch drawn from a single task: a list of (task_id, sample_id) pairs.
            yield [(task_idx, sample_id) for sample_id in batch]

    @staticmethod
    def _gen_task_indices(train_data_list, mix_opt, extra_task_ratio):
        # One task index per batch of each dataset (i.e. per modality/model); mix_opt and
        # extra_task_ratio are accepted for API compatibility but are not used here.
        all_indices = []
        for i in range(len(train_data_list)):
            all_indices += [i] * len(train_data_list[i])
        return all_indices
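
# A minimal usage sketch (illustrative; assumes a DDP process group is already initialised
# and that `datasets`, `num_epochs`, and `collate_fn` are defined as elsewhere in this file):
#
#   sampler = DistrubutedMultiTaskBatchSampler(
#       datasets, batch_size=4,
#       num_replicas=dist.get_world_size(), rank=dist.get_rank(),
#       mix_opt=0, extra_task_ratio=0, drop_last=True, shuffle=True)
#   loader = DataLoader(MultiTaskDataset(datasets), batch_sampler=sampler, collate_fn=collate_fn)
#   for epoch in range(num_epochs):
#       sampler.set_epoch(epoch)  # keeps the per-epoch shuffle identical across replicas
#       for batch in loader:
#           ...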


class MaskGenerator:
    """Generates a random patch mask for SimMIM: 1 marks a masked patch, 0 a visible one."""

    def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6):
        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        self.model_patch_size = model_patch_size
        self.mask_ratio = mask_ratio

        assert self.input_size % self.mask_patch_size == 0
        assert self.mask_patch_size % self.model_patch_size == 0

        self.rand_size = self.input_size // self.mask_patch_size
        self.scale = self.mask_patch_size // self.model_patch_size

        self.token_count = self.rand_size ** 2
        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))

    def __call__(self):
        mask_idx = np.random.permutation(self.token_count)[:self.mask_count]
        mask = np.zeros(self.token_count, dtype=int)
        mask[mask_idx] = 1

        mask = mask.reshape((self.rand_size, self.rand_size))
        mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)

        return mask
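
# Sketch of the expected output (illustrative): with the defaults above, the mask covers a
# 192/32 = 6x6 grid of mask patches, each expanded by scale = 32/4 = 8, giving a 48x48 array
# aligned with the model's 4-pixel patches; roughly 60% of the entries are 1 (masked).
#
#   mask = MaskGenerator(input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6)()
#   # mask.shape == (48, 48)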


class ZeroOneNormalize(object):
    """Converts a uint8 image tensor (e.g. from torchvision.io.read_image) to float in [0, 1]."""

    def __call__(self, img):
        return img.float().div(255)


class SimMIMTransform:
    def __init__(self, config, NORM, SCALE):
        self.transform_img = T.Compose([
            # Original PIL-based pipeline, kept for reference:
            # T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            # T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=(0.67, 1.), ratio=(3. / 4., 4. / 3.)),
            # T.RandomHorizontalFlip(),
            # T.ToTensor(),
            # T.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN), std=torch.tensor(IMAGENET_DEFAULT_STD)),
            # Tensor-based pipeline with per-modality crop scale and normalization:
            T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=SCALE, ratio=(3. / 4., 4. / 3.)),
            T.RandomHorizontalFlip(),
            ZeroOneNormalize(),
            T.Normalize(mean=torch.tensor(NORM[0]), std=torch.tensor(NORM[1])),
        ])

        if config.MODEL.TYPE in ['swin', 'swinv2']:
            model_patch_size = config.MODEL.SWIN.PATCH_SIZE
        else:
            raise NotImplementedError

        self.mask_generator = MaskGenerator(
            input_size=config.DATA.IMG_SIZE,
            mask_patch_size=config.DATA.MASK_PATCH_SIZE,
            model_patch_size=model_patch_size,
            mask_ratio=config.DATA.MASK_RATIO,
        )

    def __call__(self, img):
        img = self.transform_img(img)
        mask = self.mask_generator()
        return img, mask
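
# Illustrative: each call returns an (augmented image tensor, patch mask) pair, e.g.
#
#   transform = SimMIMTransform(config, NORM=(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), SCALE=(0.67, 1.0))
#   img, mask = transform(read_image('some_image.png'))  # path is a placeholder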


def collate_fn(batch):
    # Each sample is an (inputs, label) pair; when `inputs` is itself a tuple, its items are
    # collated independently and None entries are preserved rather than passed to default_collate.
    if not isinstance(batch[0][0], tuple):
        return default_collate(batch)
    else:
        batch_num = len(batch)
        ret = []
        for item_idx in range(len(batch[0][0])):
            if batch[0][0][item_idx] is None:
                ret.append(None)
            else:
                ret.append(default_collate([batch[i][0][item_idx] for i in range(batch_num)]))
        ret.append(default_collate([batch[i][1] for i in range(batch_num)]))
        return ret


def build_loader_simmim(config):
    ############ single model #####################
    # transform = SimMIMTransform(config)
    # dataset = ImageFolder(config.DATA.DATA_PATH, transform)
    # sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True)
    # dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True, collate_fn=collate_fn)

    ############## multi model ####################
    datasets = []
    ###### data augmentation ######
    model_paths = config.DATA.TYPE_PATH[0]
    for i in model_paths.keys():
        # SCALE entries are stored as strings like "(0.67, 1.0)"; parse them into a float tuple.
        a = config.DATA.SCALE[0][i].split(',')
        scale_model = (float(a[0].split('(')[1]), float(a[1].split(')')[0]))
        transform = SimMIMTransform(config, config.DATA.NORM[0][i], scale_model)
        dataset = CachedImageFolder(model_paths[i], transform=transform, model=i)
        datasets.append(dataset)

    multi_task_train_dataset = MultiTaskDataset(datasets)
    print(len(datasets))
    multi_task_batch_sampler = DistrubutedMultiTaskBatchSampler(datasets, batch_size=config.DATA.BATCH_SIZE, num_replicas=dist.get_world_size(), rank=dist.get_rank(), mix_opt=0, extra_task_ratio=0, drop_last=True, shuffle=True)
    dataloader = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, collate_fn=collate_fn)
    # dataloader = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler, pin_memory=True, collate_fn=collate_fn)
    print(len(dataloader))
    return dataloader
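
# Rough sketch of the config fields this loader reads (field names are taken from the code
# above; the concrete keys and values below are only an illustration, not project defaults):
#
#   config.DATA.TYPE_PATH = [{'rgb': '/path/to/rgb', 'depth': '/path/to/depth'}]   # modality -> image root
#   config.DATA.SCALE     = [{'rgb': '(0.67, 1.0)', 'depth': '(0.67, 1.0)'}]       # crop scale, stored as a string
#   config.DATA.NORM      = [{'rgb': (mean, std), 'depth': (mean, std)}]           # per-modality normalization
#   config.DATA.IMG_SIZE, config.DATA.MASK_PATCH_SIZE, config.DATA.MASK_RATIO
#   config.DATA.BATCH_SIZE, config.DATA.NUM_WORKERS
#   config.MODEL.TYPE in ('swin', 'swinv2'), config.MODEL.SWIN.PATCH_SIZE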