from importlib import import_module | |
from torch.utils.data import dataloader | |
class Data:
    """Build the test-time data loaders requested in the option namespace.

    Each name in ``args.data_test`` must be one of ``SUPPORTED_TEST_SETS``;
    every supported name is loaded through ``data.benchmark.Benchmark``.

    Attributes:
        loader_test: list with one ``DataLoader`` per entry of
            ``args.data_test``, in the same order.
    """

    # Names accepted in ``args.data_test``. All of them are served by the
    # same ``data.benchmark.Benchmark`` dataset class.
    SUPPORTED_TEST_SETS = ('Set5', 'Set14', 'B100', 'Urban100')

    def __init__(self, args):
        """Create a ``DataLoader`` for every test set named in ``args.data_test``.

        Args:
            args: option namespace. This reads ``args.data_test`` (an iterable
                of test-set names) and ``args.n_threads`` (DataLoader worker
                count), and passes the whole object on to ``Benchmark``.

        Raises:
            NotImplementedError: if any name in ``args.data_test`` is not in
                ``SUPPORTED_TEST_SETS``. All names are checked before any
                dataset is constructed, and every offending name is reported.
        """
        names = list(args.data_test)

        # Fail fast: reject unknown names before building (possibly
        # expensive) datasets for the valid ones.
        unsupported = [d for d in names if d not in self.SUPPORTED_TEST_SETS]
        if unsupported:
            raise NotImplementedError(
                'unsupported test dataset(s): {}; expected one of {}'.format(
                    unsupported, list(self.SUPPORTED_TEST_SETS)
                )
            )

        self.loader_test = []
        for d in names:
            # Imported lazily, so data.benchmark is only loaded when a test
            # set is actually requested. import_module returns the cached
            # module after the first call.
            benchmark = import_module('data.benchmark')
            testset = benchmark.Benchmark(args, name=d)
            self.loader_test.append(
                dataloader.DataLoader(
                    testset,
                    batch_size=1,  # one sample per batch (presumably because test images differ in size — confirm)
                    shuffle=False,  # keep evaluation order deterministic
                    pin_memory=False,
                    num_workers=args.n_threads,
                )
            )