|
import json |
|
import os, io, csv, math, random |
|
import numpy as np |
|
import torchvision |
|
from einops import rearrange |
|
from decord import VideoReader |
|
|
|
import torch |
|
import torchvision.transforms as transforms |
|
from torch.utils.data.dataset import Dataset |
|
from tqdm import tqdm |
|
|
|
from opensora.utils.dataset_utils import DecordInit |
|
from opensora.utils.utils import text_preprocessing |
|
|
|
|
|
|
|
class T2V_dataset(Dataset): |
|
def __init__(self, args, transform, temporal_sample, tokenizer): |
|
|
|
|
|
|
|
self.video_folder = args.video_folder |
|
self.num_frames = args.num_frames |
|
self.transform = transform |
|
self.temporal_sample = temporal_sample |
|
self.tokenizer = tokenizer |
|
self.model_max_length = args.model_max_length |
|
self.v_decoder = DecordInit() |
|
|
|
with open(args.data_path, 'r') as f: |
|
self.samples = json.load(f) |
|
self.use_image_num = args.use_image_num |
|
self.use_img_from_vid = args.use_img_from_vid |
|
if self.use_image_num != 0 and not self.use_img_from_vid: |
|
self.img_cap_list = self.get_img_cap_list() |
|
|
|
def __len__(self): |
|
return len(self.samples) |
|
|
|
def __getitem__(self, idx): |
|
try: |
|
|
|
|
|
|
|
|
|
video_path = self.samples[idx]['path'] |
|
video = self.decord_read(video_path) |
|
video = self.transform(video) |
|
video = video.transpose(0, 1) |
|
text = self.samples[idx]['cap'][0] |
|
|
|
text = text_preprocessing(text) |
|
text_tokens_and_mask = self.tokenizer( |
|
text, |
|
max_length=self.model_max_length, |
|
padding='max_length', |
|
truncation=True, |
|
return_attention_mask=True, |
|
add_special_tokens=True, |
|
return_tensors='pt' |
|
) |
|
input_ids = text_tokens_and_mask['input_ids'].squeeze(0) |
|
cond_mask = text_tokens_and_mask['attention_mask'].squeeze(0) |
|
|
|
if self.use_image_num != 0 and self.use_img_from_vid: |
|
select_image_idx = np.linspace(0, self.num_frames-1, self.use_image_num, dtype=int) |
|
assert self.num_frames >= self.use_image_num |
|
images = video[:, select_image_idx] |
|
video = torch.cat([video, images], dim=1) |
|
input_ids = torch.stack([input_ids] * (1+self.use_image_num)) |
|
cond_mask = torch.stack([cond_mask] * (1+self.use_image_num)) |
|
elif self.use_image_num != 0 and not self.use_img_from_vid: |
|
images, captions = self.img_cap_list[idx] |
|
raise NotImplementedError |
|
else: |
|
pass |
|
|
|
return video, input_ids, cond_mask |
|
except Exception as e: |
|
print(f'Error with {e}, {self.samples[idx]}') |
|
return self.__getitem__(random.randint(0, self.__len__() - 1)) |
|
|
|
def tv_read(self, path): |
|
vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW') |
|
total_frames = len(vframes) |
|
|
|
|
|
start_frame_ind, end_frame_ind = self.temporal_sample(total_frames) |
|
|
|
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int) |
|
|
|
video = vframes[frame_indice] |
|
|
|
return video |
|
|
|
def decord_read(self, path): |
|
decord_vr = self.v_decoder(path) |
|
total_frames = len(decord_vr) |
|
|
|
start_frame_ind, end_frame_ind = self.temporal_sample(total_frames) |
|
|
|
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int) |
|
|
|
video_data = decord_vr.get_batch(frame_indice).asnumpy() |
|
video_data = torch.from_numpy(video_data) |
|
video_data = video_data.permute(0, 3, 1, 2) |
|
return video_data |
|
|
|
def get_img_cap_list(self): |
|
raise NotImplementedError |