Spaces:
Running
Running
def build_loader_simmim(config):
    """Build the multi-task (multi-modality) training dataloader for SimMIM.

    For each modality listed in ``config.DATA.TYPE_PATH[0]`` this creates a
    ``CachedImageFolder`` dataset with a modality-specific ``SimMIMTransform``
    (per-modality normalization stats and random-crop scale range), then wraps
    all datasets in a ``MultiTaskDataset`` driven by a distributed multi-task
    batch sampler, so every batch is drawn from a single modality.

    Args:
        config: experiment config; reads ``DATA.TYPE_PATH``, ``DATA.SCALE``,
            ``DATA.NORM``, ``DATA.BATCH_SIZE`` and ``DATA.NUM_WORKERS``.
            ``DATA.SCALE[0][modality]`` is assumed to be a string of the form
            ``"(min, max)"`` — TODO confirm against the config schema.

    Returns:
        ``torch.utils.data.DataLoader`` yielding multi-task training batches
        (uses ``batch_sampler``, so batch size is controlled by the sampler).
    """
    from ast import literal_eval  # local import: safe parsing of "(min, max)" strings

    datasets = []
    model_paths = config.DATA.TYPE_PATH[0]
    for modality in model_paths.keys():
        # Scale entries are stored as strings like "(0.2, 1.0)"; parse them
        # into a (min, max) float tuple. literal_eval is robust to spacing,
        # unlike the previous chained split('(')/split(',')/split(')') calls.
        scale_range = tuple(float(v) for v in literal_eval(config.DATA.SCALE[0][modality]))
        transform = SimMIMTransform(config, config.DATA.NORM[0][modality], scale_range)
        datasets.append(CachedImageFolder(model_paths[modality], transform=transform, model=modality))

    multi_task_train_dataset = MultiTaskDataset(datasets)
    print(len(datasets))
    # NOTE(review): the sampler class name is spelled "Distrubuted..." in the
    # project; keep the misspelling — it is the actual exported identifier.
    multi_task_batch_sampler = DistrubutedMultiTaskBatchSampler(
        datasets,
        batch_size=config.DATA.BATCH_SIZE,
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank(),
        mix_opt=0,
        extra_task_ratio=0,
        drop_last=True,
        shuffle=True,
    )
    dataloader = DataLoader(
        multi_task_train_dataset,
        batch_sampler=multi_task_batch_sampler,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    print(len(dataloader))
    return dataloader