"""Build datasets, samplers, and torch DataLoaders from the global config."""
import imp  # deprecated since Python 3.4 and removed in 3.12; kept to match the original loader
import time

import numpy as np
import torch
import torch.utils.data

from lib.config.config import cfg

from . import samplers
from .collate_batch import make_collator
from .transforms import make_transforms

def _dataset_factory(is_train):
    """Instantiate the Dataset class named by the config for this split."""
    if is_train:
        module = cfg.train_dataset_module
        path = cfg.train_dataset_path
        args = cfg.train_dataset
    else:
        module = cfg.test_dataset_module
        path = cfg.test_dataset_path
        args = cfg.test_dataset
    # imp.load_source imports the dataset module directly from its file path,
    # so datasets can be swapped per experiment without editing this file.
    dataset = imp.load_source(module, path).Dataset(**args)
    return dataset

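# A minimal sketch of the config fields _dataset_factory reads. The field
# names come from the code above; the example values are hypothetical and
# would normally live in the experiment's config:
#
#   train_dataset_module: dotted module name passed to imp.load_source,
#                         e.g. 'lib.datasets.my_dataset'
#   train_dataset_path:   file path of that module,
#                         e.g. 'lib/datasets/my_dataset.py'
#   train_dataset:        kwargs dict passed to the module's Dataset class
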
def make_dataset(cfg, dataset_name, transforms, is_train=True):
    # dataset_name and transforms are currently unused; the dataset is built
    # entirely from the config by _dataset_factory.
    dataset = _dataset_factory(is_train)
    return dataset

def make_data_sampler(dataset, shuffle, is_distributed, is_train):
    if not is_train and cfg.test.sampler == 'FrameSampler':
        sampler = samplers.FrameSampler(dataset)
        return sampler
    if is_distributed:
        return samplers.DistributedSampler(dataset, shuffle=shuffle)
    if shuffle:
        sampler = torch.utils.data.sampler.RandomSampler(dataset)
    else:
        sampler = torch.utils.data.sampler.SequentialSampler(dataset)
    return sampler

def make_batch_data_sampler(cfg, sampler, batch_size, drop_last, max_iter,
                            is_train):
    if is_train:
        batch_sampler = cfg.train.batch_sampler
        sampler_meta = cfg.train.sampler_meta
    else:
        batch_sampler = cfg.test.batch_sampler
        sampler_meta = cfg.test.sampler_meta

    if batch_sampler == 'default':
        batch_sampler = torch.utils.data.sampler.BatchSampler(
            sampler, batch_size, drop_last)
    elif batch_sampler == 'image_size':
        batch_sampler = samplers.ImageSizeBatchSampler(sampler, batch_size,
                                                       drop_last, sampler_meta)

    # Wrap the batch sampler so iteration stops after max_iter batches.
    if max_iter != -1:
        batch_sampler = samplers.IterationBasedBatchSampler(
            batch_sampler, max_iter)
    return batch_sampler

def worker_init_fn(worker_id):
    # Reseed numpy in each worker so workers do not share the identical
    # random state inherited from the parent process.
    np.random.seed(worker_id + (int(round(time.time() * 1000) % (2 ** 16))))

def make_data_loader(cfg, is_train=True, is_distributed=False, max_iter=-1):
    if is_train:
        batch_size = cfg.train.batch_size
        shuffle = cfg.train.shuffle
        drop_last = False
    else:
        batch_size = cfg.test.batch_size
        shuffle = is_distributed
        drop_last = False

    dataset_name = cfg.train.dataset if is_train else cfg.test.dataset

    transforms = make_transforms(cfg, is_train)
    dataset = make_dataset(cfg, dataset_name, transforms, is_train)
    sampler = make_data_sampler(dataset, shuffle, is_distributed, is_train)
    batch_sampler = make_batch_data_sampler(cfg, sampler, batch_size,
                                            drop_last, max_iter, is_train)
    num_workers = cfg.train.num_workers
    collator = make_collator(cfg, is_train)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_sampler=batch_sampler,
                                              num_workers=num_workers,
                                              collate_fn=collator,
                                              worker_init_fn=worker_init_fn)
    return data_loader

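# Usage sketch (assumes a populated cfg from lib.config.config and a dataset
# module on disk at cfg.train_dataset_path; the import path below is
# illustrative and depends on where this file sits in the package):
#
#     from lib.datasets.make_dataset import make_data_loader
#     train_loader = make_data_loader(cfg, is_train=True, max_iter=-1)
#     for batch in train_loader:
#         ...  # each batch has whatever structure make_collator produces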