import glob
import io
import os
import pickle
import random

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

from lib.data.augmentation import Augmenter3D
from lib.utils.tools import read_pkl
from lib.utils.utils_data import flip_data


class MotionDataset(Dataset):
    """Base dataset that indexes per-clip files on disk.

    Files are collected from ``<args.data_root>/<subset>/<data_split>/`` for
    every subset in ``subset_list``. Within a subset, file names are sorted;
    subsets are concatenated in the order given. Subclasses implement
    ``__getitem__`` to load and transform a single file.
    """

    def __init__(self, args, subset_list, data_split):  # data_split: train/test
        # NOTE: this seeds NumPy's *global* RNG, so constructing a dataset
        # resets np.random state for the whole process. Python's `random`
        # module (used for flipping below) is not seeded here.
        np.random.seed(0)
        self.data_root = args.data_root
        self.subset_list = subset_list
        self.data_split = data_split
        file_list_all = []
        for subset in self.subset_list:
            data_path = os.path.join(self.data_root, subset, self.data_split)
            motion_list = sorted(os.listdir(data_path))
            file_list_all.extend(os.path.join(data_path, name) for name in motion_list)
        self.file_list = file_list_all

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.file_list)

    def __getitem__(self, index):
        raise NotImplementedError


class MotionDataset3D(MotionDataset):
    """Dataset yielding ``(motion_2d, motion_3d)`` float tensor pairs.

    Each file is loaded with ``read_pkl`` and must provide a ``"data_label"``
    array (the 3D target) and a ``"data_input"`` entry (2D input, which may be
    ``None``). The 2D input is either taken from ``data_input`` or derived from
    the 3D label (x/y copied, third channel set to 1), depending on
    ``args.synthetic`` / ``args.gt_2d`` and the split.
    """

    def __init__(self, args, subset_list, data_split):
        super().__init__(args, subset_list, data_split)
        self.flip = args.flip            # randomly apply flip_data when training
        self.synthetic = args.synthetic  # train on 2D derived from the 3D label
        self.aug = Augmenter3D(args)
        self.gt_2d = args.gt_2d          # derive 2D input from the 3D label

    @staticmethod
    def _gt_2d_from_3d(motion_3d):
        """Return a float32 array shaped like ``motion_3d`` with its x/y copied
        into channels 0-1 and channel 2 set to 1 (used when no 2D detection is used).
        """
        motion_2d = np.zeros(motion_3d.shape, dtype=np.float32)
        motion_2d[:, :, :2] = motion_3d[:, :, :2]
        motion_2d[:, :, 2] = 1
        return motion_2d

    def __getitem__(self, index):
        'Generates one sample of data'
        motion_file = read_pkl(self.file_list[index])
        motion_3d = motion_file["data_label"]
        if self.data_split == "train":
            if self.synthetic or self.gt_2d:
                # Augment the 3D label, then derive the 2D input from it
                # (no 2D detection: GT xy and c=1). Note that the random
                # flip below is not applied on this path.
                motion_3d = self.aug.augment3D(motion_3d)
                motion_2d = self._gt_2d_from_3d(motion_3d)
            elif motion_file["data_input"] is not None:
                # Have 2D detection
                motion_2d = motion_file["data_input"]
                if self.flip and random.random() > 0.5:
                    # Training augmentation - random flipping. Input and
                    # target are flipped together so they stay consistent.
                    motion_2d = flip_data(motion_2d)
                    motion_3d = flip_data(motion_3d)
            else:
                raise ValueError('Training illegal.')
        elif self.data_split == "test":
            if self.gt_2d:
                # Build a fresh array rather than overwriting data_input in
                # place; this also works when the file has no 2D detections.
                motion_2d = self._gt_2d_from_3d(motion_3d)
            else:
                motion_2d = motion_file["data_input"]
                if motion_2d is None:
                    raise ValueError('Test sample has no 2D input (data_input is None).')
        else:
            raise ValueError('Data split unknown.')
        return torch.FloatTensor(motion_2d), torch.FloatTensor(motion_3d)