# File size: 4,612 Bytes; revision b3f324b
import json
import os, io, csv, math, random
import numpy as np
import torchvision
from einops import rearrange
from decord import VideoReader
import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
from tqdm import tqdm
from opensora.utils.dataset_utils import DecordInit
from opensora.utils.utils import text_preprocessing
class T2V_dataset(Dataset):
    """Text-to-video dataset yielding ``(video, input_ids, cond_mask)`` triples.

    Sample metadata is loaded from the JSON file at ``args.data_path``. Each
    entry is indexed with ``'path'`` (the video file handed to the decoder)
    and ``'cap'`` (a sequence of captions, of which the first is used).

    When ``args.use_image_num`` is non-zero and ``args.use_img_from_vid`` is
    set, that many frames of the sampled clip are appended to it as extra
    "images" and the caption tokens are repeated once per appended image.
    Loading images from a separate source is not implemented.
    """

    # Upper bound on how many samples ``__getitem__`` tries before giving up,
    # so a dataset in which nothing can be loaded fails with a clear error.
    MAX_LOAD_ATTEMPTS = 1000

    def __init__(self, args, transform, temporal_sample, tokenizer):
        """Store the configuration and load the sample list.

        Args:
            args: Namespace providing ``video_folder``, ``num_frames``,
                ``model_max_length``, ``data_path``, ``use_image_num`` and
                ``use_img_from_vid``.
            transform: Callable applied to the decoded clip.
            temporal_sample: Callable mapping a total frame count to a
                ``(start_frame_ind, end_frame_ind)`` pair.
            tokenizer: Callable tokenizer whose result is indexed with
                ``'input_ids'`` and ``'attention_mask'``.

        Raises:
            NotImplementedError: If images are requested
                (``use_image_num != 0``) but not from the video itself.
        """
        self.video_folder = args.video_folder
        self.num_frames = args.num_frames
        self.transform = transform
        self.temporal_sample = temporal_sample
        self.tokenizer = tokenizer
        self.model_max_length = args.model_max_length
        self.v_decoder = DecordInit()

        with open(args.data_path, 'r') as f:
            self.samples = json.load(f)

        self.use_image_num = args.use_image_num
        self.use_img_from_vid = args.use_img_from_vid
        if self.use_image_num != 0 and not self.use_img_from_vid:
            self.img_cap_list = self.get_img_cap_list()

    def __len__(self):
        """Return the number of entries in the sample list."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(video, input_ids, cond_mask)`` for sample ``idx``.

        If the requested sample cannot be loaded, the error is printed and a
        randomly chosen sample is tried instead, up to ``MAX_LOAD_ATTEMPTS``
        attempts in total.

        Raises:
            NotImplementedError: If images are requested from a source other
                than the video.
            ValueError: If more images are requested than frames are sampled.
            RuntimeError: If every attempt failed; chained to the last error.
        """
        # Configuration problems affect every sample, so they are reported
        # directly instead of being retried like a broken video file.
        self._check_image_config()

        last_error = None
        for _ in range(self.MAX_LOAD_ATTEMPTS):
            try:
                return self._load_sample(idx)
            except Exception as e:
                print(f'Error with {e}, {self.samples[idx]}')
                last_error = e
                idx = random.randint(0, len(self) - 1)

        raise RuntimeError(
            f'Failed to load a sample after {self.MAX_LOAD_ATTEMPTS} attempts'
        ) from last_error

    def _check_image_config(self):
        """Raise if the image-joint settings cannot produce a sample."""
        if self.use_image_num == 0:
            return
        if not self.use_img_from_vid:
            raise NotImplementedError(
                'Loading images from a separate image dataset is not implemented'
            )
        if self.num_frames < self.use_image_num:
            raise ValueError(
                f'num_frames ({self.num_frames}) must be >= '
                f'use_image_num ({self.use_image_num})'
            )

    def _load_sample(self, idx):
        """Decode, transform and tokenize the sample at ``idx``."""
        sample = self.samples[idx]

        video = self.decord_read(sample['path'])
        video = self.transform(video)  # T C H W -> T C H W
        video = video.transpose(0, 1)  # T C H W -> C T H W

        text = text_preprocessing(sample['cap'][0])
        text_tokens_and_mask = self.tokenizer(
            text,
            max_length=self.model_max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt',
        )
        input_ids = text_tokens_and_mask['input_ids'].squeeze(0)
        cond_mask = text_tokens_and_mask['attention_mask'].squeeze(0)

        if self.use_image_num != 0 and self.use_img_from_vid:
            # Evenly spaced frame positions within the sampled clip.
            select_image_idx = np.linspace(
                0, self.num_frames - 1, self.use_image_num, dtype=int
            )
            images = video[:, select_image_idx]  # c, num_img, h, w
            video = torch.cat([video, images], dim=1)  # c, num_frame+num_img, h, w
            # One copy of the caption for the clip plus one per appended image.
            input_ids = torch.stack([input_ids] * (1 + self.use_image_num))  # 1+use_image_num, l
            cond_mask = torch.stack([cond_mask] * (1 + self.use_image_num))  # 1+use_image_num, l

        return video, input_ids, cond_mask

    def tv_read(self, path):
        """Decode ``path`` with torchvision and return ``num_frames`` frames.

        The frames are taken at evenly spaced integer indices inside the
        window returned by ``temporal_sample``; indices repeat when the
        window holds fewer than ``num_frames`` frames.
        """
        vframes, aframes, info = torchvision.io.read_video(
            filename=path, pts_unit='sec', output_format='TCHW'
        )
        total_frames = len(vframes)

        # Sampling video frames
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        frame_indice = np.linspace(
            start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int
        )
        video = vframes[frame_indice]  # (T, C, H, W)
        return video

    def decord_read(self, path):
        """Decode ``path`` with the decord reader and return ``num_frames`` frames.

        The frames are taken at evenly spaced integer indices inside the
        window returned by ``temporal_sample``; indices repeat when the
        window holds fewer than ``num_frames`` frames.
        """
        decord_vr = self.v_decoder(path)
        total_frames = len(decord_vr)

        # Sampling video frames
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        frame_indice = np.linspace(
            start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int
        )
        video_data = decord_vr.get_batch(frame_indice).asnumpy()
        video_data = torch.from_numpy(video_data)
        video_data = video_data.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T C H W)
        return video_data

    def get_img_cap_list(self):
        """Build the image/caption list for image-joint training (not implemented)."""
        raise NotImplementedError