# --------------------------------------------------------
# SimMIM
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Zhenda Xie
# --------------------------------------------------------
import math
import random
import numpy as np
import torch
import torch.distributed as dist
import torchvision.transforms as T
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data._utils.collate import default_collate
from torchvision.datasets import ImageFolder
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torch.utils.data import Dataset, BatchSampler
from torchvision.io import read_image
from .cached_image_folder import CachedImageFolder
class MultiTaskDataset(Dataset):
"""
useage example:
train_datasets = [SemData_Single(), SemData_Single()]
multi_task_train_dataset = MultiTaskDataset(train_datasets)
multi_task_batch_sampler = MultiTaskBatchSampler(train_datasets, batch_size=4, mix_opt=0, extra_task_ratio=0, drop_last=True)
multi_task_train_data = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler)
for i, (task_id, input, target) in enumerate(multi_task_train_data):
pre = model(input)
"""
def __init__(self, datasets):
self._datasets = datasets
task_id_2_data_set_dic = {}
for i, dataset in enumerate(datasets):
task_id = i
assert task_id not in task_id_2_data_set_dic, "Duplicate task_id %s" % task_id
task_id_2_data_set_dic[task_id] = dataset
self._task_id_2_data_set_dic = task_id_2_data_set_dic
def __len__(self):
return sum(len(dataset) for dataset in self._datasets)
def __getitem__(self, idx):
task_id, sample_id = idx
return self._task_id_2_data_set_dic[task_id][sample_id]
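# Indexing sketch (illustrative values): multi_task_train_dataset[(0, 17)] returns sample 17
# of the first dataset; the batch sampler below yields lists of such (task_id, sample_id)
# tuples, so each batch is always drawn from a single dataset.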
class DistrubutedMultiTaskBatchSampler(BatchSampler):
"""
datasets: class the class of the Dataset
batch_size: int
mix_opt: int mix_opt ==0 shuffle all_task; mix_opt ==1 shuffle extra_task
extra_task_ratio(float, optional): the rate between task one and extra task
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch size, then the last batch
will be smaller. (default: ``True``)
"""
    def __init__(self, datasets, batch_size, num_replicas, rank, mix_opt=0, extra_task_ratio=0, drop_last=True, shuffle=True):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
if rank >= num_replicas or rank < 0:
raise ValueError(
"Invalid rank {}, rank should be in the interval"
" [0, {}]".format(rank, num_replicas - 1))
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
        assert mix_opt in [0, 1], 'mix_opt must be 0 or 1'
        assert extra_task_ratio >= 0, 'extra_task_ratio must be greater than or equal to 0'
self._datasets = datasets
self._batch_size = batch_size
self._mix_opt = mix_opt
self._extra_task_ratio = extra_task_ratio
self._drop_last = drop_last
train_data_list = []
self.shuffle = shuffle
        for dataset in datasets:
            print(len(dataset))
            train_data_list.append(self._get_index_batches(len(dataset), batch_size, self._drop_last))
        # train_data_list holds one entry per dataset; each entry is a list of batches, and
        # each batch is a list of sample indices into that dataset
        self._train_data_list = train_data_list
        self.total_len = sum(len(train_data) for train_data in self._train_data_list)
######### DDP ######################
if self._drop_last and self.total_len % self.num_replicas != 0: # type: ignore[arg-type]
self.num_samples = math.ceil(
(self.total_len - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(self.total_len / self.num_replicas) # type: ignore[arg-type]
self.total_size = self.num_samples * self.num_replicas
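        # Example (illustrative numbers): with total_len=10 batches, num_replicas=4 and
        # drop_last=True, num_samples = ceil((10 - 4) / 4) = 2 and total_size = 8, so each
        # replica draws 2 batches per epoch and the remaining 2 batches are discarded.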
self.epoch = 0
self.seed = 0
    def set_epoch(self, epoch):
        r"""Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all
        replicas use a different random ordering for each epoch. Otherwise, the next
        iteration of this sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
@staticmethod
def _get_index_batches(dataset_len, batch_size, drop_last):
# index_batches = [list(range(i, min(i+batch_size, dataset_len))) for i in range(0, dataset_len, batch_size)]
index = list(range(dataset_len))
if drop_last and dataset_len % batch_size:
del index[-(dataset_len % batch_size):]
index_batches = [index[i:i+batch_size] for i in range(0, len(index), batch_size)]
return index_batches
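        # Example (illustrative numbers): dataset_len=10, batch_size=4, drop_last=True
        # drops the trailing 2 indices and returns [[0, 1, 2, 3], [4, 5, 6, 7]].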
def __len__(self):
# return sum(len(train_data) for train_data in self._train_data_list)
return self.num_samples
    def __iter__(self):
        all_iters = [iter(item) for item in self._train_data_list]
        all_indices = self._gen_task_indices(self._train_data_list, self._mix_opt, self._extra_task_ratio)
        ######### DDP ######################
        if self.shuffle:
            # Use a (seed, epoch)-derived RNG so that every replica shuffles the task order
            # identically and set_epoch() changes the order between epochs.
            random.Random(self.seed + self.epoch).shuffle(all_indices)
        all_indices = all_indices[self.rank:self.total_size:self.num_replicas]
        assert len(all_indices) == self.num_samples
        for local_task_idx in all_indices:
            # each yielded batch is a list of (task_id, sample_id) tuples drawn from a single task
            batch = next(all_iters[local_task_idx])
            yield [(local_task_idx, sample_id) for sample_id in batch]
    @staticmethod
    def _gen_task_indices(train_data_list, mix_opt, extra_task_ratio):
        ########## one task index per batch, according to the number of datasets ###########
        all_indices = []
        for i in range(len(train_data_list)):
            all_indices += [i] * len(train_data_list[i])
        return all_indices
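# Usage sketch for the sampler above (assumes torch.distributed is already initialized;
# the dataset objects and batch size are illustrative):
#   datasets = [dataset_a, dataset_b]
#   sampler = DistrubutedMultiTaskBatchSampler(datasets, batch_size=4,
#                                              num_replicas=dist.get_world_size(), rank=dist.get_rank(),
#                                              mix_opt=0, extra_task_ratio=0, drop_last=True, shuffle=True)
#   loader = DataLoader(MultiTaskDataset(datasets), batch_sampler=sampler, collate_fn=collate_fn)
#   for epoch in range(num_epochs):
#       sampler.set_epoch(epoch)  # re-seeds the cross-replica shuffle each epoch
#       for batch in loader:
#           ...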
class MaskGenerator:
    """Generates a random SimMIM patch mask: 1 marks masked model patches, 0 marks visible ones."""
def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6):
self.input_size = input_size
self.mask_patch_size = mask_patch_size
self.model_patch_size = model_patch_size
self.mask_ratio = mask_ratio
assert self.input_size % self.mask_patch_size == 0
assert self.mask_patch_size % self.model_patch_size == 0
self.rand_size = self.input_size // self.mask_patch_size
self.scale = self.mask_patch_size // self.model_patch_size
self.token_count = self.rand_size ** 2
self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))
def __call__(self):
mask_idx = np.random.permutation(self.token_count)[:self.mask_count]
mask = np.zeros(self.token_count, dtype=int)
mask[mask_idx] = 1
mask = mask.reshape((self.rand_size, self.rand_size))
mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)
return mask
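# Example (default values above, shown for illustration): input_size=192 and mask_patch_size=32
# give rand_size = 6 and token_count = 36; mask_ratio=0.6 gives mask_count = ceil(36 * 0.6) = 22;
# scale = 32 // 4 = 8, so the returned mask is a (48, 48) 0/1 array with one entry per model
# patch and roughly 60% of the entries masked.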
class ZeroOneNormalize(object):
    """Converts an integer image tensor to float in [0, 1] (used in place of T.ToTensor for inputs that are already tensors)."""
def __call__(self, img):
return img.float().div(255)
class SimMIMTransform:
    """SimMIM pre-training transform: random crop / flip / normalize plus a random patch mask.

    NORM is a (mean, std) pair and SCALE is the RandomResizedCrop scale range for the current modality.
    """
def __init__(self, config, NORM, SCALE):
self.transform_img = T.Compose([
# T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
# T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=(0.67, 1.), ratio=(3. / 4., 4. / 3.)),
# T.RandomHorizontalFlip(),
# T.ToTensor(),
# T.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN),std=torch.tensor(IMAGENET_DEFAULT_STD)),
T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=SCALE, ratio=(3. / 4., 4. / 3.)),
T.RandomHorizontalFlip(),
ZeroOneNormalize(),
            T.Normalize(mean=torch.tensor(NORM[0]), std=torch.tensor(NORM[1])),
])
if config.MODEL.TYPE in ['swin', 'swinv2']:
model_patch_size=config.MODEL.SWIN.PATCH_SIZE
else:
raise NotImplementedError
self.mask_generator = MaskGenerator(
input_size=config.DATA.IMG_SIZE,
mask_patch_size=config.DATA.MASK_PATCH_SIZE,
model_patch_size=model_patch_size,
mask_ratio=config.DATA.MASK_RATIO,
)
def __call__(self, img):
img = self.transform_img(img)
mask = self.mask_generator()
return img, mask
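# Usage sketch (the NORM / SCALE values below are illustrative, not taken from the config):
#   transform = SimMIMTransform(config,
#                               NORM=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
#                               SCALE=(0.67, 1.0))
#   img, mask = transform(read_image('sample.png'))
#   # img: normalized float tensor of shape (3, IMG_SIZE, IMG_SIZE)
#   # mask: 0/1 array with one entry per model patch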
def collate_fn(batch):
    """Collates items whose sample part is a tuple (e.g. (img, mask) from SimMIMTransform):
    each tuple element is collated separately and the targets are appended last; plain
    (sample, target) items fall back to default_collate."""
if not isinstance(batch[0][0], tuple):
return default_collate(batch)
else:
batch_num = len(batch)
ret = []
for item_idx in range(len(batch[0][0])):
if batch[0][0][item_idx] is None:
ret.append(None)
else:
ret.append(default_collate([batch[i][0][item_idx] for i in range(batch_num)]))
ret.append(default_collate([batch[i][1] for i in range(batch_num)]))
return ret
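# Illustration (shapes are an assumption based on SimMIMTransform above): for a batch of two
# ((img, mask), target) items with 192x192 inputs and model patch size 4, collate_fn returns
# [imgs of shape (2, 3, 192, 192), masks of shape (2, 48, 48), collated targets].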
def build_loader_simmim(config):
    """Builds the multi-modality SimMIM pre-training dataloader: one CachedImageFolder per
    modality, combined via MultiTaskDataset and sharded by the distributed multi-task batch sampler."""
    ############ single model #####################
# transform = SimMIMTransform(config)
# dataset = ImageFolder(config.DATA.DATA_PATH, transform)
# sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True)
# dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True, collate_fn=collate_fn)
############## multi model ####################
    datasets = []
    ### data augmentation settings per modality ######
    model_paths = config.DATA.TYPE_PATH[0]
    for i in model_paths.keys():
        # config.DATA.SCALE[0][i] is a string such as "(0.67, 1.0)"; parse it into a float tuple
        a = config.DATA.SCALE[0][i].split(',')
        scale_model = (float(a[0].split('(')[1]), float(a[1].split(')')[0]))
        transform = SimMIMTransform(config, config.DATA.NORM[0][i], scale_model)
        dataset = CachedImageFolder(model_paths[i], transform=transform, model=i)
        datasets.append(dataset)
multi_task_train_dataset = MultiTaskDataset(datasets)
print(len(datasets))
    multi_task_batch_sampler = DistrubutedMultiTaskBatchSampler(datasets, batch_size=config.DATA.BATCH_SIZE, num_replicas=dist.get_world_size(), rank=dist.get_rank(), mix_opt=0, extra_task_ratio=0, drop_last=True, shuffle=True)
dataloader = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, collate_fn=collate_fn)
# dataloader = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler, pin_memory=True, collate_fn=collate_fn)
print(len(dataloader))
return dataloader
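# Config layout sketch (inferred from the parsing above; the keys and paths are illustrative
# assumptions, not the actual RingMo-SAM configuration):
#   config.DATA.TYPE_PATH = [{'optical': '/data/optical', 'sar': '/data/sar'}]
#   config.DATA.SCALE     = [{'optical': '(0.67, 1.0)', 'sar': '(0.67, 1.0)'}]
#   config.DATA.NORM      = [{'optical': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
#                             'sar': ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))}]
#   dataloader = build_loader_simmim(config)  # requires torch.distributed to be initialized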