import json
import os
import random
from glob import glob

import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

from opensora.utils.dataset_utils import DecordInit


class T2V_Feature_dataset(Dataset):
    """Loads precomputed CausalVideoVAE latents paired with T5 text features."""

    def __init__(self, args, temporal_sample):
        self.video_folder = args.video_folder
        self.num_frames = args.video_length
        self.temporal_sample = temporal_sample

        print('Building dataset...')
        # Cache the file index so the expensive directory walk runs only once.
        if os.path.exists('samples_430k.json'):
            with open('samples_430k.json', 'r') as f:
                self.samples = json.load(f)
        else:
            self.samples = self._make_dataset()
            with open('samples_430k.json', 'w') as f:
                json.dump(self.samples, f, indent=2)

        self.use_image_num = args.use_image_num
        self.use_img_from_vid = args.use_img_from_vid
        if self.use_image_num != 0 and not self.use_img_from_vid:
            self.img_cap_list = self.get_img_cap_list()

    def _make_dataset(self):
        # Index every .mp4 under video_folder, keeping only videos whose VAE
        # latents and at least one set of T5 caption features exist on disk.
        all_mp4 = list(glob(os.path.join(self.video_folder, '**', '*.mp4'), recursive=True))

        samples = []
        for i in tqdm(all_mp4):
            video_id = os.path.basename(i).split('.')[0]
            ae = os.path.split(i)[0].replace('data_split_tt', 'lb_causalvideovae444_feature')
            ae = os.path.join(ae, f'{video_id}_causalvideovae444.npy')
            if not os.path.exists(ae):
                continue

            t5 = os.path.split(i)[0].replace('data_split_tt', 'lb_t5_feature')
            cond_list = []
            cond_llava = os.path.join(t5, f'{video_id}_t5_llava_fea.npy')
            mask_llava = os.path.join(t5, f'{video_id}_t5_llava_mask.npy')
            if os.path.exists(cond_llava) and os.path.exists(mask_llava):
                llava = dict(cond=cond_llava, mask=mask_llava)
                cond_list.append(llava)
            cond_sharegpt4v = os.path.join(t5, f'{video_id}_t5_sharegpt4v_fea.npy')
            mask_sharegpt4v = os.path.join(t5, f'{video_id}_t5_sharegpt4v_mask.npy')
            if os.path.exists(cond_sharegpt4v) and os.path.exists(mask_sharegpt4v):
                sharegpt4v = dict(cond=cond_sharegpt4v, mask=mask_sharegpt4v)
                cond_list.append(sharegpt4v)
            if len(cond_list) > 0:
                sample = dict(ae=ae, t5=cond_list)
                samples.append(sample)
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        ae, t5 = sample['ae'], sample['t5']
        # Pick one caption source (LLaVA or ShareGPT4V) at random.
        t5 = random.choice(t5)
        video_origin = np.load(ae)[0]  # (C, T, H, W) latent; strip the batch dim
        _, total_frames, _, _ = video_origin.shape

        # Sample a temporal window, then take num_frames evenly spaced frames.
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        assert end_frame_ind - start_frame_ind >= self.num_frames
        select_video_idx = np.linspace(start_frame_ind, end_frame_ind - 1, num=self.num_frames, dtype=int)

        video = video_origin[:, select_video_idx]
        video = torch.from_numpy(video)

        cond = torch.from_numpy(np.load(t5['cond']))[0]
        cond_mask = torch.from_numpy(np.load(t5['mask']))[0]

        if self.use_image_num != 0 and self.use_img_from_vid:
            # Append use_image_num random frames from the same video as extra
            # single-image samples, replicating the caption for each of them.
            select_image_idx = np.random.randint(0, total_frames, self.use_image_num)
            images = torch.from_numpy(video_origin[:, select_image_idx])
            video = torch.cat([video, images], dim=1)
            cond = torch.stack([cond] * (1 + self.use_image_num))
            cond_mask = torch.stack([cond_mask] * (1 + self.use_image_num))
        elif self.use_image_num != 0 and not self.use_img_from_vid:
            # Pairing with an external image-caption list is not supported yet.
            images, captions = self.img_cap_list[idx]
            raise NotImplementedError

        return video, cond, cond_mask

    def get_img_cap_list(self):
        raise NotImplementedError
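
    # A hedged sketch of what `get_img_cap_list` is expected to supply: the
    # __getitem__ above unpacks each entry as `(images, captions)`, so a
    # minimal implementation (manifest name and format are purely
    # hypothetical) might look like:
    #
    #     def get_img_cap_list(self):
    #         with open('image_captions.json') as f:  # hypothetical manifest
    #             pairs = json.load(f)
    #         return [(p['images'], p['captions']) for p in pairs]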


class T2V_T5_Feature_dataset(Dataset):
    """Decodes raw videos on the fly and pairs them with precomputed T5 features."""

    def __init__(self, args, transform, temporal_sample):
        self.video_folder = args.video_folder
        self.num_frames = args.num_frames
        self.transform = transform
        self.temporal_sample = temporal_sample
        self.v_decoder = DecordInit()

        print('Building dataset...')
        if os.path.exists('samples_430k.json'):
            with open('samples_430k.json', 'r') as f:
                self.samples = json.load(f)
            # The cached index stores VAE-latent paths; map them back to the
            # corresponding raw .mp4 paths, since this dataset decodes video.
            self.samples = [
                dict(
                    ae=i['ae']
                    .replace('lb_causalvideovae444_feature', 'data_split_1024')
                    .replace('_causalvideovae444.npy', '.mp4'),
                    t5=i['t5'],
                )
                for i in self.samples
            ]
        else:
            self.samples = self._make_dataset()
            with open('samples_430k.json', 'w') as f:
                json.dump(self.samples, f, indent=2)

        self.use_image_num = args.use_image_num
        self.use_img_from_vid = args.use_img_from_vid
        if self.use_image_num != 0 and not self.use_img_from_vid:
            self.img_cap_list = self.get_img_cap_list()

    def _make_dataset(self):
        all_mp4 = list(glob(os.path.join(self.video_folder, '**', '*.mp4'), recursive=True))

        samples = []
        for i in tqdm(all_mp4):
            video_id = os.path.basename(i).split('.')[0]
            # Unlike T2V_Feature_dataset, 'ae' here is the raw .mp4 path itself.
            ae = i
            if not os.path.exists(ae):
                continue

            t5 = os.path.split(i)[0].replace('data_split_1024', 'lb_t5_feature')
            cond_list = []
            cond_llava = os.path.join(t5, f'{video_id}_t5_llava_fea.npy')
            mask_llava = os.path.join(t5, f'{video_id}_t5_llava_mask.npy')
            if os.path.exists(cond_llava) and os.path.exists(mask_llava):
                llava = dict(cond=cond_llava, mask=mask_llava)
                cond_list.append(llava)
            cond_sharegpt4v = os.path.join(t5, f'{video_id}_t5_sharegpt4v_fea.npy')
            mask_sharegpt4v = os.path.join(t5, f'{video_id}_t5_sharegpt4v_mask.npy')
            if os.path.exists(cond_sharegpt4v) and os.path.exists(mask_sharegpt4v):
                sharegpt4v = dict(cond=cond_sharegpt4v, mask=mask_sharegpt4v)
                cond_list.append(sharegpt4v)
            if len(cond_list) > 0:
                sample = dict(ae=ae, t5=cond_list)
                samples.append(sample)
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        try:
            sample = self.samples[idx]
            ae, t5 = sample['ae'], sample['t5']
            # Pick one caption source (LLaVA or ShareGPT4V) at random.
            t5 = random.choice(t5)

            video = self.decord_read(ae)
            video = self.transform(video)  # expects (T, C, H, W) frames
            video = video.transpose(0, 1)  # (T, C, H, W) -> (C, T, H, W)
            total_frames = video.shape[1]
            cond = torch.from_numpy(np.load(t5['cond']))[0]
            cond_mask = torch.from_numpy(np.load(t5['mask']))[0]

            if self.use_image_num != 0 and self.use_img_from_vid:
                # Append use_image_num random frames of the clip as extra
                # single-image samples, replicating the caption for each.
                select_image_idx = np.random.randint(0, total_frames, self.use_image_num)
                images = video[:, select_image_idx]
                video = torch.cat([video, images], dim=1)
                cond = torch.stack([cond] * (1 + self.use_image_num))
                cond_mask = torch.stack([cond_mask] * (1 + self.use_image_num))
            elif self.use_image_num != 0 and not self.use_img_from_vid:
                # Pairing with an external image-caption list is not supported yet.
                images, captions = self.img_cap_list[idx]
                raise NotImplementedError

            return video, cond, cond_mask
        except Exception as e:
            # Corrupt or unreadable files are common at this scale: log the
            # failure and retry with a random index instead of crashing.
            print(f'Error with {e}, {self.samples[idx]}')
            return self.__getitem__(random.randint(0, self.__len__() - 1))

    def decord_read(self, path):
        decord_vr = self.v_decoder(path)
        total_frames = len(decord_vr)

        # Sample a temporal window, then take num_frames evenly spaced frames.
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)

        video_data = decord_vr.get_batch(frame_indice).asnumpy()
        video_data = torch.from_numpy(video_data)
        video_data = video_data.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)
        return video_data

    def get_img_cap_list(self):
        raise NotImplementedError
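

# A minimal smoke-test sketch, not part of the module. The Namespace fields
# mirror what the classes above read from `args`; the folder path is an
# assumption, and the closure below stands in for the project's temporal
# sampler (anything returning a (start, end) frame window works).
if __name__ == '__main__':
    import argparse

    from torch.utils.data import DataLoader

    def temporal_sample(total_frames, size=17):
        # Choose a random window of (up to) `size` consecutive frames.
        start = random.randint(0, max(total_frames - size, 0))
        return start, min(start + size, total_frames)

    args = argparse.Namespace(
        video_folder='/path/to/videos',  # assumption: point at your data root
        video_length=17,                 # frames per latent clip
        use_image_num=0,
        use_img_from_vid=False,
    )
    dataset = T2V_Feature_dataset(args, temporal_sample)
    loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
    video, cond, cond_mask = next(iter(loader))
    print(video.shape, cond.shape, cond_mask.shape)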