RingMo-SAM / datasets /data_loader_multi_tasks.py
AI-Cyber's picture
Upload 123 files
8d7921b
raw
history blame contribute delete
1.65 kB
def build_loader_simmim(config):
    """Build the multi-dataset SimMIM training dataloader.

    For every key in ``config.DATA.TYPE_PATH[0]`` a ``SimMIMTransform`` and a
    ``CachedImageFolder`` are created; the resulting datasets are wrapped in a
    ``MultiTaskDataset`` and batched by a ``DistrubutedMultiTaskBatchSampler``.

    Args:
        config: Configuration object. The fields read here are:
            ``DATA.TYPE_PATH[0]``: mapping from dataset key to the value
            handed to ``CachedImageFolder`` as its first argument
            (presumably a data path -- confirm against the config schema);
            ``DATA.SCALE[0][key]``: string of the form ``"(<low>,<high>)"``
            that is parsed into a pair of floats;
            ``DATA.NORM[0][key]``: forwarded unchanged to ``SimMIMTransform``;
            ``DATA.BATCH_SIZE`` and ``DATA.NUM_WORKERS``.

    Returns:
        The ``DataLoader`` iterating over the combined dataset.

    Note:
        ``dist.get_world_size()`` / ``dist.get_rank()`` are called here;
        presumably ``dist`` is ``torch.distributed`` and the process group
        must already be initialised -- confirm against the file's imports.
    """
    # ---- single-dataset variant (kept for reference) ----
    # transform = SimMIMTransform(config)
    # dataset = ImageFolder(config.DATA.DATA_PATH, transform)
    # sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(),
    #                              rank=dist.get_rank(), shuffle=True)
    # dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler,
    #                         num_workers=config.DATA.NUM_WORKERS,
    #                         pin_memory=True, drop_last=True,
    #                         collate_fn=collate_fn)

    # ---- multi-dataset variant ----
    data_cfg = config.DATA
    path_by_key = data_cfg.TYPE_PATH[0]
    scale_by_key = data_cfg.SCALE[0]
    norm_by_key = data_cfg.NORM[0]

    # Data augmentation: one transform (and thus one dataset) per key.
    datasets = []
    for key, data_path in path_by_key.items():
        # The scale arrives as text like "(<low>,<high>)": split on the comma,
        # then strip the opening / closing parenthesis from each half.
        pieces = scale_by_key[key].split(',')
        low_text = pieces[0].split('(')[1]
        high_text = pieces[1].split(')')[0]
        scale_range = (float(low_text), float(high_text))

        per_key_transform = SimMIMTransform(config, norm_by_key[key], scale_range)
        datasets.append(
            CachedImageFolder(data_path, transform=per_key_transform, model=key)
        )

    combined_dataset = MultiTaskDataset(datasets)
    print(len(datasets))

    batch_sampler = DistrubutedMultiTaskBatchSampler(
        datasets,
        batch_size=data_cfg.BATCH_SIZE,
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank(),
        mix_opt=0,
        extra_task_ratio=0,
        drop_last=True,
        shuffle=True,
    )
    dataloader = DataLoader(
        combined_dataset,
        batch_sampler=batch_sampler,
        num_workers=data_cfg.NUM_WORKERS,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    # Variant without worker processes (kept for reference):
    # dataloader = DataLoader(combined_dataset, batch_sampler=batch_sampler,
    #                         pin_memory=True, collate_fn=collate_fn)
    print(len(dataloader))
    return dataloader