File size: 6,012 Bytes
c705408 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
from torchvision import transforms
from datasets import video_transforms
from .ucf101_datasets import UCF101
from .dummy_datasets import DummyDataset
from .webvid_datasets import WebVid10M
from .videoswap_datasets import VideoSwapDataset
from .dl3dv_datasets import DL3DVDataset
from .pair_datasets import PairDataset
from .metric_datasets import MetricDataset
from .sakuga_ref_datasets import SakugaRefDataset
def _temporal_sampler(args, extra_frames=0):
    """Return a random temporal crop for the clip length configured in ``args``.

    The crop window spans ``args.num_frames * args.frame_interval + extra_frames``
    frames.
    """
    return video_transforms.TemporalRandomCrop(
        args.num_frames * args.frame_interval + extra_frames
    )


def _video_transform(crop_size=None, *, flip=False):
    """Build the video preprocessing pipeline shared by the video datasets.

    Steps, in order: ``ToTensorVideo`` (TCHW), an optional random horizontal
    flip, an optional ``UCFCenterCropVideo`` to ``crop_size``, and a
    per-channel ``Normalize`` with mean = std = 0.5, i.e. ``(x - 0.5) / 0.5``.

    Args:
        crop_size: size passed to ``UCFCenterCropVideo``; ``None`` skips the crop.
        flip: whether to include ``RandomHorizontalFlipVideo``.
    """
    steps = [video_transforms.ToTensorVideo()]  # TCHW
    if flip:
        steps.append(video_transforms.RandomHorizontalFlipVideo())
    if crop_size is not None:
        steps.append(video_transforms.UCFCenterCropVideo(size=crop_size))
    steps.append(
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=False)
    )
    return transforms.Compose(steps)


def get_dataset(args):
    """Build the dataset selected by ``args.dataset``.

    Args:
        args: config namespace. ``dataset`` is always read; depending on the
            selected dataset the following may also be read: ``num_frames``,
            ``frame_interval``, ``ref_jump_frames``, ``image_size``,
            ``height``, ``width``, ``base_folder``, ``file_list``, ``seed``,
            ``with_pair``.

    Returns:
        The constructed dataset instance.

    Raises:
        NotImplementedError: if ``args.dataset`` is not a supported name
            (the exception argument is the offending name).
    """
    name = args.dataset

    # Temporal samplers are built per branch so that datasets which do not use
    # one (pair_dataset, metric_dataset) and unsupported names do not require
    # num_frames / frame_interval to be present on ``args``.
    if name == "ucf101":
        return UCF101(
            args,
            transform=_video_transform(args.image_size, flip=True),
            temporal_sample=_temporal_sampler(args),
        )

    if name == "dummy":
        # Horizontal flipping is only enabled for UCF101.
        return DummyDataset(
            sample_frames=args.num_frames,
            base_folder=args.base_folder,
            temporal_sample=_temporal_sampler(args),
            transform=_video_transform((args.height, args.width)),
            seed=args.seed,
            file_list=args.file_list,
        )

    if name == "sakuga_ref":
        # The crop window is widened by ref_jump_frames on top of the clip length.
        return SakugaRefDataset(
            video_frames=args.num_frames,
            ref_jump_frames=args.ref_jump_frames,
            base_folder=args.base_folder,
            temporal_sample=_temporal_sampler(args, extra_frames=args.ref_jump_frames),
            transform=_video_transform((args.height, args.width)),
            seed=args.seed,
            file_list=args.file_list,
        )

    if name == "webvid":
        return WebVid10M(
            sample_frames=args.num_frames,
            base_folder=args.base_folder,
            temporal_sample=_temporal_sampler(args),
            transform=_video_transform((args.height, args.width)),
            seed=args.seed,
        )

    if name == "videoswap":
        # No center crop here; width/height are handed to the dataset instead.
        return VideoSwapDataset(
            width=args.width,
            height=args.height,
            sample_frames=args.num_frames,
            base_folder=args.base_folder,
            temporal_sample=_temporal_sampler(args),
            transform=_video_transform(),
            seed=args.seed,
        )

    if name == "dl3dv":
        # No ``transform`` argument is passed for this dataset.
        return DL3DVDataset(
            width=args.width,
            height=args.height,
            sample_frames=args.num_frames,
            base_folder=args.base_folder,
            file_list=args.file_list,
            temporal_sample=_temporal_sampler(args),
            seed=args.seed,
        )

    if name == "pair_dataset":
        return PairDataset(
            base_folder=args.base_folder,
            with_pair=args.with_pair,
        )

    if name == "metric_dataset":
        return MetricDataset(
            base_folder=args.base_folder,
        )

    raise NotImplementedError(name)
|