import os
import json
import random
from typing import Dict, List, Tuple

import torch
import decord
import torchvision
import numpy as np
from PIL import Image
from torchvision import transforms

# Lazily populated by get_class_labels().
class_labels_map = None
cls_sample_cnt = None


def temporal_sampling(frames, start_idx, end_idx, num_samples):
    """
    Given the start and end frame index, sample num_samples frames between
    the start and end with equal interval.
    Args:
        frames (tensor): a tensor of video frames, dimension is
            `num video frames` x `channel` x `height` x `width`.
        start_idx (int): the index of the start frame.
        end_idx (int): the index of the end frame.
        num_samples (int): number of frames to sample.
    Returns:
        frames (tensor): a tensor of temporally sampled video frames, dimension is
            `num clip frames` x `channel` x `height` x `width`.
    """
    index = torch.linspace(start_idx, end_idx, num_samples)
    index = torch.clamp(index, 0, frames.shape[0] - 1).long()
    frames = torch.index_select(frames, 0, index)
    return frames


def get_filelist(file_path):
    """Recursively collect the full paths of all files under file_path."""
    Filelist = []
    for home, dirs, files in os.walk(file_path):
        for filename in files:
            Filelist.append(os.path.join(home, filename))
    return Filelist


def load_annotation_data(data_file_path):
    with open(data_file_path, 'r') as data_file:
        return json.load(data_file)


def get_class_labels(num_class, anno_pth='./k400_classmap.json'):
    """Load the class-name -> index map once and cache it in the module globals."""
    global class_labels_map, cls_sample_cnt
    if class_labels_map is not None:
        return class_labels_map, cls_sample_cnt
    cls_sample_cnt = {}
    class_labels_map = load_annotation_data(anno_pth)
    for cls in class_labels_map:
        cls_sample_cnt[cls] = 0
    return class_labels_map, cls_sample_cnt


def load_annotations(ann_file, num_class, num_samples_per_cls):
    """Parse a tab-separated annotation file into a list of {video, label} samples,
    keeping at most num_samples_per_cls samples per class for the first num_class classes."""
    dataset = []
    class_to_idx, cls_sample_cnt = get_class_labels(num_class)
    with open(ann_file, 'r') as fin:
        for line in fin:
            line_split = line.strip().split('\t')
            sample = {}
            idx = 0
            # idx for frame_dir
            frame_dir = line_split[idx]
            sample['video'] = frame_dir
            idx += 1
            # idx for label[s]
            label = [x for x in line_split[idx:]]
            assert label, f'missing label in line: {line}'
            assert len(label) == 1
            class_name = label[0]
            class_index = int(class_to_idx[class_name])
            # choose a class subset of the whole dataset
            if class_index < num_class:
                sample['label'] = class_index
                if cls_sample_cnt[class_name] < num_samples_per_cls:
                    dataset.append(sample)
                    cls_sample_cnt[class_name] += 1
    return dataset


def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    See :class:`DatasetFolder` for details.
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx


class DecordInit(object):
    """Using Decord (https://github.com/dmlc/decord) to initialize the video reader."""

    def __init__(self, num_threads=1):
        self.num_threads = num_threads
        self.ctx = decord.cpu(0)

    def __call__(self, filename):
        """Perform the Decord initialization.
        Args:
            filename (str): path of the video file to open.
        Returns:
            A decord.VideoReader for the given file.
        """
        reader = decord.VideoReader(filename,
                                    ctx=self.ctx,
                                    num_threads=self.num_threads)
        return reader

    def __repr__(self):
        # Only report attributes that are actually set (the original referenced
        # a nonexistent `self.sr`).
        return (f'{self.__class__.__name__}('
                f'num_threads={self.num_threads})')
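
# Illustrative usage sketch (not part of the pipeline): how DecordInit and
# temporal_sampling compose. The video path is a hypothetical placeholder.
#
#   v_decoder = DecordInit()
#   reader = v_decoder('/path/to/video.mp4')                  # decord.VideoReader
#   frames = torch.from_numpy(
#       reader.get_batch(list(range(len(reader)))).asnumpy())  # T H W C, uint8
#   frames = frames.permute(0, 3, 1, 2)                        # THWC -> TCHW
#   clip = temporal_sampling(frames, 0, len(frames) - 1, num_samples=16)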
""" reader = decord.VideoReader(filename, ctx=self.ctx, num_threads=self.num_threads) return reader def __repr__(self): repr_str = (f'{self.__class__.__name__}(' f'sr={self.sr},' f'num_threads={self.num_threads})') return repr_str class UCF101Images(torch.utils.data.Dataset): """Load the UCF101 video files Args: target_video_len (int): the number of video frames will be load. align_transform (callable): Align different videos in a specified size. temporal_sample (callable): Sample the target length of a video. """ def __init__(self, configs, transform=None, temporal_sample=None): self.configs = configs self.data_path = configs.data_path self.video_lists = get_filelist(configs.data_path) self.transform = transform self.temporal_sample = temporal_sample self.target_video_len = self.configs.num_frames self.v_decoder = DecordInit() self.classes, self.class_to_idx = find_classes(self.data_path) self.video_num = len(self.video_lists) # ucf101 video frames self.frame_data_path = configs.frame_data_path # important self.video_frame_txt = configs.frame_data_txt self.video_frame_files = [frame_file.strip() for frame_file in open(self.video_frame_txt)] random.shuffle(self.video_frame_files) self.use_image_num = configs.use_image_num self.image_tranform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) ]) self.video_frame_num = len(self.video_frame_files) def __getitem__(self, index): video_index = index % self.video_num path = self.video_lists[video_index] class_name = path.split('/')[-2] class_index = self.class_to_idx[class_name] vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW') total_frames = len(vframes) # Sampling video frames start_frame_ind, end_frame_ind = self.temporal_sample(total_frames) assert end_frame_ind - start_frame_ind >= self.target_video_len frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, self.target_video_len, dtype=int) video = vframes[frame_indice] # videotransformer data proprecess video = self.transform(video) # T C H W images = [] image_names = [] for i in range(self.use_image_num): while True: try: video_frame_path = self.video_frame_files[index+i] image_class_name = video_frame_path.split('_')[1] image_class_index = self.class_to_idx[image_class_name] video_frame_path = os.path.join(self.frame_data_path, video_frame_path) image = Image.open(video_frame_path).convert('RGB') image = self.image_tranform(image).unsqueeze(0) images.append(image) image_names.append(str(image_class_index)) break except Exception as e: index = random.randint(0, self.video_frame_num - self.use_image_num) images = torch.cat(images, dim=0) assert len(images) == self.use_image_num assert len(image_names) == self.use_image_num image_names = '====='.join(image_names) video_cat = torch.cat([video, images], dim=0) return {'video': video_cat, 'video_name': class_index, 'image_name': image_names} def __len__(self): return self.video_frame_num if __name__ == '__main__': import argparse import video_transforms import torch.utils.data as Data import torchvision.transforms as transforms from PIL import Image parser = argparse.ArgumentParser() parser.add_argument("--num_frames", type=int, default=16) parser.add_argument("--frame_interval", type=int, default=3) parser.add_argument("--use-image-num", type=int, default=5) parser.add_argument("--data-path", type=str, default="/path/to/datasets/UCF101/videos/") parser.add_argument("--frame-data-path", type=str, 
default="/path/to/datasets/preprocessed_ffs/train/images/") parser.add_argument("--frame-data-txt", type=str, default="/path/to/datasets/UCF101/train_256_list.txt") config = parser.parse_args() temporal_sample = video_transforms.TemporalRandomCrop(config.num_frames * config.frame_interval) transform_ucf101 = transforms.Compose([ video_transforms.ToTensorVideo(), # TCHW video_transforms.RandomHorizontalFlipVideo(), video_transforms.UCFCenterCropVideo(256), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) ]) ffs_dataset = UCF101Images(config, transform=transform_ucf101, temporal_sample=temporal_sample) ffs_dataloader = Data.DataLoader(dataset=ffs_dataset, batch_size=6, shuffle=False, num_workers=1) # for i, video_data in enumerate(ffs_dataloader): for video_data in ffs_dataloader: # print(type(video_data)) video = video_data['video'] # video_name = video_data['video_name'] print(video.shape) print(video_data['image_name']) image_name = video_data['image_name'] image_names = [] for caption in image_name: single_caption = [int(item) for item in caption.split('=====')] image_names.append(torch.as_tensor(single_caption)) print(image_names) # print(video_name) # print(video_data[2]) # for i in range(16): # img0 = rearrange(video_data[0][0][i], 'c h w -> h w c') # print('Label: {}'.format(video_data[1])) # print(img0.shape) # img0 = Image.fromarray(np.uint8(img0 * 255)) # img0.save('./img{}.jpg'.format(i))