import torch
from torch.utils.data import ConcatDataset, DataLoader
from torchvision import transforms
from torchvision.transforms import InterpolationMode

from dataset.it_dataset import ITImgTrainDataset, ITVidTrainDataset


def get_media_type(dataset_config):
    """Infer the media type ("image", "video", or "only_video") of a train-file entry."""
    if len(dataset_config) == 3 and dataset_config[2] == "video":
        return "video"
    elif dataset_config[-1] == "only_video":
        return "only_video"
    else:
        return "image"


def create_dataset(dataset_type, config):
    """Build the instruction-tuning train datasets, one ConcatDataset per media type."""
    if "clip" in config.model.get("vit_model", "vit"):
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
    else:
        vision_enc_name = config.model.vision_encoder.name
        if "swin" in vision_enc_name or "vit" in vision_enc_name:
            mean = (0.485, 0.456, 0.406)
            std = (0.229, 0.224, 0.225)
        elif "beit" in vision_enc_name:
            mean = (0.5, 0.5, 0.5)  # for all BEiT models except IN1K finetuning
            std = (0.5, 0.5, 0.5)
        elif "clip" in vision_enc_name:
            mean = (0.48145466, 0.4578275, 0.40821073)
            std = (0.26862954, 0.26130258, 0.27577711)
        else:
            raise ValueError(f"Unsupported vision encoder: {vision_enc_name}")

    normalize = transforms.Normalize(mean, std)

    # loaded images and videos are torch.Tensor of torch.uint8 format,
    # ordered as (T, 1 or 3, H, W) where T=1 for image
    type_transform = transforms.Lambda(lambda x: x.float().div(255.0))

    if config.inputs.video_input.random_aug:
        aug_transform = transforms.RandAugment()
    else:
        aug_transform = transforms.Lambda(lambda x: x)

    train_transform = transforms.Compose(
        [
            aug_transform,
            transforms.RandomResizedCrop(
                config.inputs.image_res,
                scale=(0.5, 1.0),
                interpolation=InterpolationMode.BICUBIC,
            ),
            transforms.RandomHorizontalFlip(),
            type_transform,
            normalize,
        ]
    )
    test_transform = transforms.Compose(
        [
            transforms.Resize(
                (config.inputs.image_res, config.inputs.image_res),
                interpolation=InterpolationMode.BICUBIC,
            ),
            type_transform,
            normalize,
        ]
    )

    video_reader_type = config.inputs.video_input.get("video_reader_type", "decord")
    video_only_dataset_kwargs_train = dict(
        video_reader_type=video_reader_type,
        sample_type=config.inputs.video_input.sample_type,
        num_frames=config.inputs.video_input.num_frames,
        num_tries=3,  # fault tolerance: retries when a video fails to load
    )

    if dataset_type == "pt_train":
        raise ValueError("NOT PRETRAINING YET")
    elif dataset_type in ["it_train"]:
        # convert to list of lists
        train_files = (
            [config.train_file]
            if isinstance(config.train_file[0], str)
            else config.train_file
        )
        train_media_types = sorted(list({get_media_type(e) for e in train_files}))

        train_datasets = []
        for m in train_media_types:
            dataset_cls = ITImgTrainDataset if m == "image" else ITVidTrainDataset
            # datasets of the same media_type will be mixed in a single Dataset object
            _train_files = [e for e in train_files if get_media_type(e) == m]

            datasets = []
            for train_file in _train_files:
                dataset_kwargs = dict(
                    ann_file=train_file,
                    transform=train_transform,
                    mm_alone=config.preprocess.get("mm_alone", True),
                    add_second_msg=config.preprocess.get("add_second_msg", True),
                    skip_short_sample=config.preprocess.get("skip_short_sample", False),
                    clip_transform=config.preprocess.get("clip_transform", False),
                    random_shuffle=config.preprocess.get("random_shuffle", True),
                    system=config.preprocess.get("system", ""),
                    role=config.preprocess.get("roles", ("Human", "Assistant")),
                    end_signal=config.preprocess.get("end_signal", "###"),
                    begin_signal=config.preprocess.get("begin_signal", ""),
                )
                if m == "video":
                    video_only_dataset_kwargs_train.update({
                        "start_token": config.model.get("start_token", ""),
                    })
                    # select the video reader backend based on the annotation source
                    if "tgif" in train_file[1]:
                        video_only_dataset_kwargs_train.update({
                            "video_reader_type": "gif"
                        })
                    elif "webvid" in train_file[1]:
                        video_only_dataset_kwargs_train.update({
                            "video_reader_type": "hdfs"
                        })
                    else:
                        video_only_dataset_kwargs_train.update({
                            "video_reader_type": "decord"
                        })
                    dataset_kwargs.update(video_only_dataset_kwargs_train)
                datasets.append(dataset_cls(**dataset_kwargs))
            dataset = ConcatDataset(datasets)
            train_datasets.append(dataset)
        return train_datasets
    else:
        raise ValueError(f"Unknown dataset_type: {dataset_type}")


def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Wrap each dataset in a DataLoader with its matching sampler, batch size and collate_fn."""
    loaders = []
    for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(
        datasets, samplers, batch_size, num_workers, is_trains, collate_fns
    ):
        if is_train:
            shuffle = sampler is None
            drop_last = True
        else:
            shuffle = False
            drop_last = False
        loader = DataLoader(
            dataset,
            batch_size=bs,
            num_workers=n_worker,
            pin_memory=False,
            sampler=sampler,
            shuffle=shuffle,
            collate_fn=collate_fn,
            drop_last=drop_last,
            persistent_workers=True if n_worker > 0 else False,
        )
        loaders.append(loader)
    return loaders
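

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): how the two
# helpers above are typically chained in a training script. The config loader,
# config path, batch size, worker count, and sampler/collate choices below are
# assumptions; adapt them to the actual training entry point.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # hypothetical: any attribute-style config that supports .get() would work here
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("configs/instruction_tuning.yaml")  # hypothetical config path
    train_datasets = create_dataset("it_train", cfg)  # one ConcatDataset per media type
    n = len(train_datasets)
    loaders = create_loader(
        train_datasets,
        samplers=[None] * n,      # or one DistributedSampler per dataset for DDP
        batch_size=[8] * n,       # hypothetical per-media-type batch size
        num_workers=[4] * n,
        is_trains=[True] * n,
        collate_fns=[None] * n,   # None -> PyTorch default collate
    )
    for loader in loaders:
        batch = next(iter(loader))  # uint8 frames already transformed to normalized floats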