|
from torchvision import transforms
|
|
from datasets import video_transforms
|
|
from .ucf101_datasets import UCF101
|
|
from .dummy_datasets import DummyDataset
|
|
from .webvid_datasets import WebVid10M
|
|
from .videoswap_datasets import VideoSwapDataset
|
|
from .dl3dv_datasets import DL3DVDataset
|
|
from .pair_datasets import PairDataset
|
|
from .metric_datasets import MetricDataset
|
|
from .sakuga_ref_datasets import SakugaRefDataset
|
|
|
|
def _temporal_sample(args, extra_frames=0):
    """Build the temporal crop used by the video datasets.

    The crop length is ``args.num_frames * args.frame_interval`` plus
    ``extra_frames``.
    """
    return video_transforms.TemporalRandomCrop(
        args.num_frames * args.frame_interval + extra_frames
    )


def _normalized_video_transform(size=None):
    """Compose the shared video preprocessing pipeline.

    The pipeline runs ``ToTensorVideo``, then an optional
    ``UCFCenterCropVideo(size=size)`` when ``size`` is given, then a per-channel
    ``Normalize`` with mean 0.5 and std 0.5. That normalization maps inputs in
    [0, 1] to [-1, 1], assuming ``ToTensorVideo`` yields [0, 1] values
    (TODO confirm).
    """
    steps = [video_transforms.ToTensorVideo()]
    if size is not None:
        steps.append(video_transforms.UCFCenterCropVideo(size=size))
    steps.append(
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=False)
    )
    return transforms.Compose(steps)


def get_dataset(args):
    """Instantiate the dataset selected by ``args.dataset``.

    Attributes of ``args`` read per dataset name (the dataset classes may read
    more, e.g. ``UCF101`` receives the whole ``args`` object):

    * ``'ucf101'``: ``num_frames``, ``frame_interval``, ``image_size``.
    * ``'dummy'``: ``num_frames``, ``frame_interval``, ``height``, ``width``,
      ``base_folder``, ``seed``, ``file_list``.
    * ``'sakuga_ref'``: same as ``'dummy'`` plus ``ref_jump_frames``.
    * ``'webvid'``: ``num_frames``, ``frame_interval``, ``height``, ``width``,
      ``base_folder``, ``seed``.
    * ``'videoswap'``: same as ``'webvid'``. No spatial crop is applied; the
      dataset itself receives ``width``/``height``.
    * ``'dl3dv'``: ``num_frames``, ``frame_interval``, ``height``, ``width``,
      ``base_folder``, ``file_list``, ``seed``. No transform is passed.
    * ``'pair_dataset'``: ``base_folder``, ``with_pair``.
    * ``'metric_dataset'``: ``base_folder``.

    The temporal crop is built only for the datasets that consume it. This
    way, configs for ``pair_dataset``/``metric_dataset`` (or unknown names) do
    not need ``num_frames``/``frame_interval``.

    Args:
        args: Config object exposing ``dataset`` and the attributes above.

    Returns:
        The constructed dataset instance.

    Raises:
        NotImplementedError: If ``args.dataset`` is not a supported name
            (this includes ``'encdec_images'``). The message is the name.
    """
    name = args.dataset

    if name == 'ucf101':
        transform_ucf101 = transforms.Compose([
            video_transforms.ToTensorVideo(),
            video_transforms.RandomHorizontalFlipVideo(),
            video_transforms.UCFCenterCropVideo(args.image_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=False),
        ])
        return UCF101(args, transform=transform_ucf101, temporal_sample=_temporal_sample(args))

    if name == 'dummy':
        return DummyDataset(
            sample_frames=args.num_frames,
            base_folder=args.base_folder,
            temporal_sample=_temporal_sample(args),
            transform=_normalized_video_transform((args.height, args.width)),
            seed=args.seed,
            file_list=args.file_list,
        )

    if name == 'sakuga_ref':
        # The temporal window is widened by ref_jump_frames, presumably so a
        # reference frame can be drawn beyond the clip (verify in
        # SakugaRefDataset).
        return SakugaRefDataset(
            video_frames=args.num_frames,
            ref_jump_frames=args.ref_jump_frames,
            base_folder=args.base_folder,
            temporal_sample=_temporal_sample(args, args.ref_jump_frames),
            transform=_normalized_video_transform((args.height, args.width)),
            seed=args.seed,
            file_list=args.file_list,
        )

    if name == 'webvid':
        return WebVid10M(
            sample_frames=args.num_frames,
            base_folder=args.base_folder,
            temporal_sample=_temporal_sample(args),
            transform=_normalized_video_transform((args.height, args.width)),
            seed=args.seed,
        )

    if name == 'videoswap':
        # No center crop here: the dataset receives width/height directly.
        return VideoSwapDataset(
            width=args.width,
            height=args.height,
            sample_frames=args.num_frames,
            base_folder=args.base_folder,
            temporal_sample=_temporal_sample(args),
            transform=_normalized_video_transform(),
            seed=args.seed,
        )

    if name == 'dl3dv':
        return DL3DVDataset(
            width=args.width,
            height=args.height,
            sample_frames=args.num_frames,
            base_folder=args.base_folder,
            file_list=args.file_list,
            temporal_sample=_temporal_sample(args),
            seed=args.seed,
        )

    if name == "pair_dataset":
        return PairDataset(
            base_folder=args.base_folder,
            with_pair=args.with_pair,
        )

    if name == "metric_dataset":
        return MetricDataset(
            base_folder=args.base_folder,
        )

    raise NotImplementedError(args.dataset)
|
|
|