'''create dataset and dataloader'''
import logging

import torch
import torch.utils.data


def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
    phase = dataset_opt['phase']
    if phase == 'train':
        if opt['dist']:
            # Distributed training: split the global batch size evenly across
            # processes. Shuffling is left to the (distributed) sampler, so
            # shuffle stays False here.
            world_size = torch.distributed.get_world_size()
            num_workers = dataset_opt['n_workers']
            assert dataset_opt['batch_size'] % world_size == 0, \
                'batch_size must be divisible by the number of processes'
            batch_size = dataset_opt['batch_size'] // world_size
            shuffle = False
        else:
            # Single-process training: scale the worker count by the number of GPUs.
            num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
            batch_size = dataset_opt['batch_size']
            shuffle = True
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                           num_workers=num_workers, sampler=sampler,
                                           drop_last=True, pin_memory=False)
    else:
        # Validation/test: one sample at a time, no shuffling, a single worker.
        return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
                                           num_workers=1, pin_memory=True)


def create_dataset(dataset_opt):
    mode = dataset_opt['mode']
    if mode == 'test':
        from data.coco_test_dataset import imageTestDataset as D
    elif mode == 'train':
        from data.coco_dataset import CoCoDataset as D
    elif mode == 'td':
        from data.test_dataset_td import imageTestDataset as D
    else:
        raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
    dataset = D(dataset_opt)
    logger = logging.getLogger('base')
    logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__,
                                                           dataset_opt['name']))
    return dataset
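

# Minimal usage sketch: shows how the option dicts are consumed by
# create_dataloader(). A plain TensorDataset stands in for the COCO datasets,
# whose required option keys (data paths, etc.) live in their own modules;
# the concrete values below are illustrative assumptions, not project defaults.
if __name__ == '__main__':
    dummy = torch.utils.data.TensorDataset(torch.randn(64, 3, 32, 32))
    dataset_opt = {'phase': 'train', 'n_workers': 2, 'batch_size': 16}
    opt = {'dist': False, 'gpu_ids': [0]}  # non-distributed, single GPU
    loader = create_dataloader(dummy, dataset_opt, opt)
    batch, = next(iter(loader))
    print('loaded batch of shape', tuple(batch.shape))  # expected: (16, 3, 32, 32)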