import json
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from PIL import Image
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset, default_collate
from torchvision.transforms import Compose, Normalize, Resize, ToTensor
from torchvision.transforms.functional import to_tensor

# Hard-coded directory of precomputed latents used by the Objaverse spiral datasets.
_DEFAULT_LATENT_DIR = "/mnt/vepfs/3Ddataset/render_results/latents512"


# ---------------------------------------------------------------------------
# Small I/O helpers
# ---------------------------------------------------------------------------


def _load_json(path):
    """Return the parsed JSON document at ``path``; the file is closed afterwards."""
    with open(path, "r", encoding="utf8") as reader:
        return json.load(reader)


def _keep_existing(root_dir, ids, require_dir=False):
    """Return the ids that have an entry under ``root_dir``, preserving order.

    With ``require_dir=True`` the entry must be a directory, not just exist.
    """
    if require_dir:
        return [uid for uid in ids if (root_dir / uid).is_dir()]
    return [uid for uid in ids if (root_dir / uid).exists()]


def _report_valid_ids(ids):
    """Print the number of usable object ids inside a banner."""
    print("=" * 30)
    print("Number of valid ids: ", len(ids))
    print("=" * 30)


# ---------------------------------------------------------------------------
# Camera helpers
# ---------------------------------------------------------------------------


def read_camera_matrix_single(json_file):
    """Load a (3, 4) camera-to-world matrix from a gobjaverse-style camera JSON.

    Columns 0-2 are the camera axes stored under ``"x"``, ``"y"`` and ``"z"``.
    Column 3 is the camera position stored under ``"origin"``. The y and z axes
    are negated to convert from the OpenCV camera convention to OpenGL.
    """
    json_content = _load_json(json_file)
    camera_matrix = torch.zeros(3, 4)
    camera_matrix[:3, 0] = torch.tensor(json_content["x"])
    camera_matrix[:3, 1] = -torch.tensor(json_content["y"])
    camera_matrix[:3, 2] = -torch.tensor(json_content["z"])
    camera_matrix[:3, 3] = torch.tensor(json_content["origin"])
    return camera_matrix


def read_camera_instrinsics_single(json_file, h: int, w: int, scale: float = 1.0):
    """Build pinhole intrinsics from the field-of-view values in ``json_file``.

    The image size is ``(int(h * scale), int(w * scale))``. ``"x_fov"`` and
    ``"y_fov"`` are passed to ``np.tan``, so they must be in radians.

    Returns:
        A float32 (3, 2) tensor ``[[fx, fy], [cx, cy], [width, height]]`` in
        pixels. The principal point is the (integer) image centre.
    """
    json_content = _load_json(json_file)
    h = int(h * scale)
    w = int(w * scale)
    fy = h / 2 / np.tan(json_content["y_fov"] / 2)
    fx = w / 2 / np.tan(json_content["x_fov"] / 2)
    cx = w // 2
    cy = h // 2
    return torch.tensor([[fx, fy], [cx, cy], [w, h]], dtype=torch.float32)


def compose_extrinsic_RT(RT: torch.Tensor):
    """
    Compose the standard form extrinsic matrix from RT. Batched I/O.

    Appends the homogeneous row ``[0, 0, 0, 1]`` to each (3, 4) matrix of an
    (N, 3, 4) batch and returns (N, 4, 4). The padding row is created with
    ``RT``'s dtype and device, so GPU and double-precision inputs work too.
    """
    bottom_row = torch.tensor([[[0, 0, 0, 1]]], dtype=RT.dtype, device=RT.device)
    return torch.cat([RT, bottom_row.repeat(RT.shape[0], 1, 1)], dim=1)


def get_normalized_camera_intrinsics(intrinsics: torch.Tensor):
    """
    intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
    Return batched fx, fy, cx, cy, each divided by the image width (x terms)
    or height (y terms).
    """
    fx, fy = intrinsics[:, 0, 0], intrinsics[:, 0, 1]
    cx, cy = intrinsics[:, 1, 0], intrinsics[:, 1, 1]
    width, height = intrinsics[:, 2, 0], intrinsics[:, 2, 1]
    fx, fy = fx / width, fy / height
    cx, cy = cx / width, cy / height
    return fx, fy, cx, cy


def build_camera_standard(RT: torch.Tensor, intrinsics: torch.Tensor):
    """Pack extrinsics and normalised intrinsics into one (N, 25) vector per camera.

    Args:
        RT: (N, 3, 4) extrinsic matrices.
        intrinsics: (N, 3, 2), ``[[fx, fy], [cx, cy], [width, height]]``.

    Returns:
        (N, 25) tensor. The first 16 entries are the flattened 4x4 extrinsic.
        The last 9 are the flattened 3x3 intrinsic matrix
        ``[[fx, 0, cx], [0, fy, cy], [0, 0, 1]]`` in normalised image units.
    """
    E = compose_extrinsic_RT(RT)
    fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics)
    zeros = torch.zeros_like(fx)
    ones = torch.ones_like(fx)
    I = torch.stack(
        [
            torch.stack([fx, zeros, cx], dim=-1),
            torch.stack([zeros, fy, cy], dim=-1),
            torch.stack([zeros, zeros, ones], dim=-1),
        ],
        dim=1,
    )
    return torch.cat([E.reshape(-1, 16), I.reshape(-1, 9)], dim=-1)


def calc_elevation(c2w):
    """Return the elevation angle (radians) of the camera position in ``c2w``.

    Works for a single matrix or a batch (any leading dims). Assumes the world
    up axis is +z.
    """
    pos = c2w[..., :3, 3]
    return np.arcsin(pos[..., 2] / np.linalg.norm(pos, axis=-1, keepdims=False))


# ---------------------------------------------------------------------------
# Image / batch helpers
# ---------------------------------------------------------------------------


def blend_white_bg(image):
    """Composite a PIL ``image`` over an opaque white background; returns RGB.

    The alpha channel is the paste mask. Non-RGBA images are converted to RGBA
    first. RGB input therefore passes through unchanged, and other modes that
    carry transparency (e.g. ``"LA"``, ``"P"``) are handled instead of
    raising ``IndexError`` on the missing alpha band.
    """
    if image.mode != "RGBA":
        image = image.convert("RGBA")
    new_image = Image.new("RGB", image.size, (255, 255, 255))
    new_image.paste(image, mask=image.split()[3])
    return new_image


def flatten_for_video(input):
    """Flatten a collated (B, T) per-frame field to (B * T,)."""
    return input.flatten()


# Per-frame scalar fields that are flattened from (B, T) to (B * T,) when collating.
FLATTEN_FIELDS = ["fps_id", "motion_bucket_id", "cond_aug", "elevation"]


def video_collate_fn(batch: list[dict], *args, **kwargs):
    """Collate per-clip sample dicts into one batch dict.

    * Fields in ``FLATTEN_FIELDS`` are flattened to (B * T,).
    * ``"frames"``, ``"latents"`` and ``"rgb"`` go from (B, T, C, H, W) to
      (B * T, C, H, W). So does ``pixelnerf_input["rgb"]``, when present.
    * ``"num_video_frames"`` is taken from the first sample, not stacked.
    * Everything else goes through ``default_collate`` unchanged.
    """
    out = {}
    for key in batch[0].keys():
        if key == "num_video_frames":
            out[key] = batch[0][key]
            continue
        collated = default_collate([item[key] for item in batch])
        if key in FLATTEN_FIELDS:
            collated = flatten_for_video(collated)
        elif key in ("frames", "latents", "rgb"):
            collated = rearrange(collated, "b t c h w -> (b t) c h w")
        out[key] = collated
    if "pixelnerf_input" in out:
        out["pixelnerf_input"]["rgb"] = rearrange(
            out["pixelnerf_input"]["rgb"], "b t c h w -> (b t) c h w"
        )
    return out


def _sample_cond_aug(mean, std):
    """Draw a conditioning-noise strength ``exp(N(mean, std**2))`` (log-normal)."""
    return np.exp(np.random.randn(1)[0] * std + mean)


def _video_conditioning(cond, cond_aug, n_views, fps_id, motion_bucket_id, num_video_frames):
    """Build the conditioning entries shared by every dataset in this module.

    ``cond_frames`` is ``cond`` plus Gaussian noise scaled by ``cond_aug``.
    Scalar conditioning values are repeated once per view.
    """
    return {
        "cond_aug": torch.as_tensor([cond_aug] * n_views),
        "cond_frames": cond + cond_aug * torch.randn_like(cond),
        "fps_id": torch.as_tensor([fps_id] * n_views),
        "motion_bucket_id": torch.as_tensor([motion_bucket_id] * n_views),
        "num_video_frames": num_video_frames,
        "image_only_indicator": torch.as_tensor([0.0] * n_views),
    }


# ---------------------------------------------------------------------------
# Datasets
# ---------------------------------------------------------------------------


class GObjaverse(Dataset):
    """24-view gobjaverse renders (or precomputed latents) served as video clips.

    Views are read from ``root_dir / "gobjaverse" / <uid> / NNNNN/NNNNN.png``,
    with the camera JSON beside each PNG. ``front_view_selection`` decides
    which view becomes frame 0 (the conditioning frame):

    * ``"random"``: a uniformly random cyclic shift of the views;
    * ``"fixed"``: view 0;
    * ``"clip_score_softmax"``: a view sampled with probability
      ``softmax(clip score)``;
    * ``"clip_score_max"``: the view with the highest clip score.

    Any other value starting with ``"clip_score"`` loads the clip scores but
    keeps view 0 first.
    """

    def __init__(
        self,
        root_dir,
        split="train",
        transform=None,
        random_front=False,
        max_item=None,
        cond_aug_mean=-3.0,
        cond_aug_std=0.5,
        condition_on_elevation=False,
        fps_id=0.0,
        motion_bucket_id=300.0,
        use_latents=False,
        load_caps=False,
        front_view_selection="random",
        load_pixelnerf=False,
        debug_base_idx=None,
        scale_pose: bool = False,
        max_n_cond: int = 1,
        **unused_kwargs,
    ):
        self.root_dir = Path(root_dir)
        self.split = split
        self.random_front = random_front
        self.transform = transform
        self.use_latents = use_latents
        self.ids = _load_json(self.root_dir / "valid_uids.json")
        self.n_views = 24
        self.load_caps = load_caps
        if self.load_caps:
            self.caps = _load_json(self.root_dir / "text_captions_cap3d.json")
        self.cond_aug_mean = cond_aug_mean
        self.cond_aug_std = cond_aug_std
        self.condition_on_elevation = condition_on_elevation
        self.fps_id = fps_id
        self.motion_bucket_id = motion_bucket_id
        self.load_pixelnerf = load_pixelnerf
        self.scale_pose = scale_pose
        self.max_n_cond = max_n_cond
        if self.use_latents:
            self.latents_dir = self.root_dir / "latents256"
            self.clip_dir = self.root_dir / "clip_emb256"

        self.front_view_selection = front_view_selection
        if self.front_view_selection.startswith("clip_score"):
            self.clip_scores = torch.load(self.root_dir / "clip_score_per_view.pt")
            # Only objects that have clip scores are usable in these modes.
            self.ids = list(self.clip_scores.keys())
        elif self.front_view_selection not in ("random", "fixed"):
            raise ValueError(
                f"Unknown front view selection method {self.front_view_selection}"
            )

        if max_item is not None:
            self.ids = self.ids[:max_item]
            ## debug: repeat the small subset so an epoch is long enough to train on
            self.ids = self.ids * 10000

        if debug_base_idx is not None:
            print(f"debug mode with base idx: {debug_base_idx}")
            # Only set when requested; __getitem__ checks for it with hasattr.
            self.debug_base_idx = debug_base_idx

    def _view_file(self, uid, view_idx, suffix):
        """Return the path of one view's ``.png`` / ``.json`` file for object ``uid``."""
        return (
            self.root_dir
            / "gobjaverse"
            / uid
            / f"{view_idx:05d}/{view_idx:05d}.{suffix}"
        )

    def _view_order(self, uid):
        """Return all view indices, cyclically shifted so the chosen front view is first."""
        idx_list = np.arange(self.n_views)
        if self.front_view_selection == "random":
            return np.roll(idx_list, np.random.randint(self.n_views))
        if self.front_view_selection in ("clip_score_softmax", "clip_score_max"):
            # NOTE(review): assumes clip_scores[uid] is a 1-D tensor with one score per view.
            probs = F.softmax(self.clip_scores[uid], dim=-1).cpu().numpy()
            if self.front_view_selection == "clip_score_softmax":
                front = np.random.choice(idx_list, p=probs)
            else:
                front = np.argmax(probs)
            # np.roll(x, k) moves x[-k] to the front; shift by -front so the
            # selected view itself ends up as frame 0.
            return np.roll(idx_list, -front)
        # "fixed" (or another accepted "clip_score*" name): keep view 0 first.
        return idx_list

    def _load_frames(self, uid, idx_list):
        """Open and transform each requested view of ``uid``; returns a list."""
        return [
            self.transform(Image.open(self._view_file(uid, view_idx, "png")))
            for view_idx in idx_list
        ]

    @staticmethod
    def _normalize_camera_positions(cameras):
        """Recentre camera positions on their mean and scale the farthest one to radius 1.5.

        ``cameras`` is the (N, 25) output of ``build_camera_standard``. It is
        modified in place and also returned.
        """
        c2ws = cameras[..., :16].reshape(-1, 4, 4)
        positions = c2ws[:, :3, 3]
        center = positions.mean(0)
        radius = (positions - center).norm(dim=-1).max()
        # Guard against every camera sharing one position (radius 0 -> inf/NaN).
        scale = 1.5 / radius.clamp_min(1e-8)
        c2ws[:, :3, 3] = (positions - center) * scale
        cameras[..., :16] = c2ws.reshape(-1, 16)
        return cameras

    def _load_pixelnerf_input(self, uid, idx_list, frames):
        """Build the ``pixelnerf_input`` sub-dict: cameras, 32x32 RGB views and frames."""
        RTs = []
        intrinsics = []
        for view_idx in idx_list:
            meta = self._view_file(uid, view_idx, "json")
            RTs.append(read_camera_matrix_single(meta)[:3])
            intrinsics.append(read_camera_instrinsics_single(meta, 256, 256))
        cameras = build_camera_standard(
            torch.stack(RTs, dim=0), torch.stack(intrinsics, dim=0)
        )
        if self.scale_pose:
            cameras = self._normalize_camera_positions(cameras)

        downsampled = [
            to_tensor(
                blend_white_bg(
                    Image.open(self._view_file(uid, view_idx, "png")).resize((32, 32))
                )
            )
            for view_idx in idx_list
        ]
        return {
            "cameras": cameras,
            "rgb": torch.stack(downsampled, dim=0),
            "frames": frames,
        }

    def __getitem__(self, idx: int):
        if hasattr(self, "debug_base_idx"):
            idx = (idx + self.debug_base_idx) % len(self.ids)
        uid = self.ids[idx]
        idx_list = self._view_order(uid)

        if not self.use_latents:
            try:
                frames = self._load_frames(uid, idx_list)
            except Exception:
                # A workaround for some broken gobjaverse objects: fall back to
                # object 0. The duplicates are resolved when gathering results;
                # the number of valid items can be checked by the length of
                # the results.
                uid = self.ids[0]
                frames = self._load_frames(uid, idx_list)
            frames = torch.stack(frames, dim=0)
            cond = frames[0]
            cond_aug = _sample_cond_aug(self.cond_aug_mean, self.cond_aug_std)
            data = {"frames": frames, "cond_frames_without_noise": cond}
        else:
            latents = torch.load(self.latents_dir / f"{uid}.pt")[idx_list]
            clip_emb = torch.load(self.clip_dir / f"{uid}.pt")[idx_list][0]
            cond = latents[0]
            cond_aug = _sample_cond_aug(self.cond_aug_mean, self.cond_aug_std)
            data = {"latents": latents, "cond_frames_without_noise": clip_emb}
        data.update(
            _video_conditioning(
                cond, cond_aug, self.n_views, self.fps_id, self.motion_bucket_id, 24
            )
        )

        if self.condition_on_elevation:
            # Same location as the per-view camera JSONs used for pixelnerf.
            sample_c2w = read_camera_matrix_single(self._view_file(uid, 0, "json"))
            elevation = calc_elevation(sample_c2w)
            data["elevation"] = torch.as_tensor([elevation] * self.n_views)

        if self.load_pixelnerf:
            assert "frames" in data, "pixelnerf cannot work with latents only mode"
            data["pixelnerf_input"] = self._load_pixelnerf_input(
                uid, idx_list, data["frames"]
            )

        if self.load_caps:
            data["caption"] = self.caps[uid]
        data["ids"] = uid
        return data

    def __len__(self):
        return len(self.ids)

    def collate_fn(self, batch):
        """Optionally attach extra pixelnerf source views, then apply ``video_collate_fn``.

        When ``max_n_cond > 1``, one ``n_cond`` in ``[1, max_n_cond]`` is drawn
        per batch. If it is above 1, every sample gets view 0 plus
        ``max_n_cond - 1`` distinct random views as sources.
        NOTE(review): the source count is always ``max_n_cond``, not
        ``n_cond``; this presumably pads to a fixed size, with ``n_cond``
        marking how many are used. Confirm with the model code.
        Requires ``load_pixelnerf=True`` when ``max_n_cond > 1``.
        """
        if self.max_n_cond > 1:
            n_cond = np.random.randint(1, self.max_n_cond + 1)
            if n_cond > 1:
                for b in batch:
                    source_index = [0] + np.random.choice(
                        np.arange(1, self.n_views),
                        self.max_n_cond - 1,
                        replace=False,
                    ).tolist()
                    pixelnerf_input = b["pixelnerf_input"]
                    pixelnerf_input["source_index"] = torch.as_tensor(source_index)
                    pixelnerf_input["n_cond"] = n_cond
                    pixelnerf_input["source_images"] = b["frames"][source_index]
                    pixelnerf_input["source_cameras"] = pixelnerf_input["cameras"][
                        source_index
                    ]
        return video_collate_fn(batch)


class ObjaverseSpiral(Dataset):
    """24-view spiral renders stored as ``root_dir / <uid> / NNNNN/NNNNN.png``.

    Object ids come from ``root_dir / f"{split}_ids.json"``; ids without a
    directory entry under ``root_dir`` are dropped.
    """

    def __init__(
        self,
        root_dir,
        split="train",
        transform=None,
        random_front=False,
        max_item=None,
        cond_aug_mean=-3.0,
        cond_aug_std=0.5,
        condition_on_elevation=False,
        **unused_kwargs,
    ):
        self.root_dir = Path(root_dir)
        self.split = split
        self.random_front = random_front
        self.transform = transform
        self.n_views = 24
        self.ids = _keep_existing(
            self.root_dir, _load_json(self.root_dir / f"{split}_ids.json")
        )
        self.cond_aug_mean = cond_aug_mean
        self.cond_aug_std = cond_aug_std
        self.condition_on_elevation = condition_on_elevation
        if max_item is not None:
            self.ids = self.ids[:max_item]
            ## debug: repeat the small subset so an epoch is long enough to train on
            self.ids = self.ids * 10000

    def __getitem__(self, idx: int):
        uid = self.ids[idx]
        idx_list = np.arange(self.n_views)
        if self.random_front:
            idx_list = np.roll(idx_list, np.random.randint(self.n_views))

        frames = torch.stack(
            [
                self.transform(
                    Image.open(self.root_dir / uid / f"{view_idx:05d}/{view_idx:05d}.png")
                )
                for view_idx in idx_list
            ],
            dim=0,
        )  # [T, C, H, W]
        cond = frames[0]
        cond_aug = _sample_cond_aug(self.cond_aug_mean, self.cond_aug_std)
        data = {"frames": frames, "cond_frames_without_noise": cond}
        data.update(_video_conditioning(cond, cond_aug, self.n_views, 1.0, 300.0, 24))

        if self.condition_on_elevation:
            sample_c2w = read_camera_matrix_single(
                self.root_dir / uid / "00000/00000.json"
            )
            elevation = calc_elevation(sample_c2w)
            data["elevation"] = torch.as_tensor([elevation] * self.n_views)
        return data

    def __len__(self):
        return len(self.ids)


def _elevation0_spiral_item(dataset, idx):
    """Shared ``__getitem__`` for the 18-view ``elevations_0`` render layout.

    Reads ``dataset.root_dir / <uid> / "elevations_0" / f"colors_{2 * k}.png"``
    for every view ``k``, and optionally the precomputed latents for ``uid``.
    """
    uid = dataset.ids[idx]
    idx_list = np.arange(dataset.n_views)
    if dataset.random_front:
        idx_list = np.roll(idx_list, np.random.randint(dataset.n_views))

    frames = torch.stack(
        [
            dataset.transform(
                Image.open(
                    dataset.root_dir / uid / "elevations_0" / f"colors_{view_idx * 2}.png"
                )
            )
            for view_idx in idx_list
        ],
        dim=0,
    )
    cond = frames[0]
    cond_aug = _sample_cond_aug(dataset.cond_aug_mean, dataset.cond_aug_std)
    data = {"frames": frames, "cond_frames_without_noise": cond}
    data.update(
        _video_conditioning(cond, cond_aug, dataset.n_views, 0.0, 300.0, dataset.n_views)
    )
    if dataset.use_precomputed_latents:
        data["latents"] = torch.load(dataset.latent_dir / f"{uid}.pt")
    if dataset.condition_on_elevation:
        # Raise explicitly: an `assert False` would be stripped under `python -O`.
        raise NotImplementedError("currently assumes elevation 0")
    return data


class ObjaverseLVISSpiral(Dataset):
    """Objaverse LVIS subset rendered at 18 views on the ``elevations_0`` ring.

    Ids come from ``./assets/lvis_uids.json`` (relative to the working
    directory). Ids without an entry under ``root_dir`` are dropped.
    """

    def __init__(
        self,
        root_dir,
        split="train",
        transform=None,
        random_front=False,
        max_item=None,
        cond_aug_mean=-3.0,
        cond_aug_std=0.5,
        condition_on_elevation=False,
        use_precomputed_latents=False,
        **unused_kwargs,
    ):
        print("Using LVIS subset")
        self.root_dir = Path(root_dir)
        self.latent_dir = Path(_DEFAULT_LATENT_DIR)
        self.split = split
        self.random_front = random_front
        self.transform = transform
        self.use_precomputed_latents = use_precomputed_latents
        self.n_views = 18
        self.ids = _keep_existing(self.root_dir, _load_json("./assets/lvis_uids.json"))
        _report_valid_ids(self.ids)
        self.cond_aug_mean = cond_aug_mean
        self.cond_aug_std = cond_aug_std
        self.condition_on_elevation = condition_on_elevation
        if max_item is not None:
            self.ids = self.ids[:max_item]
            ## debug: repeat the small subset so an epoch is long enough to train on
            self.ids = self.ids * 10000

    def __getitem__(self, idx: int):
        return _elevation0_spiral_item(self, idx)

    def __len__(self):
        return len(self.ids)


class ObjaverseALLSpiral(ObjaverseLVISSpiral):
    """Like ``ObjaverseLVISSpiral`` but over every id in ``./assets/all_ids.json``.

    Only ids that are directories under ``root_dir`` are kept.
    """

    def __init__(
        self,
        root_dir,
        split="train",
        transform=None,
        random_front=False,
        max_item=None,
        cond_aug_mean=-3.0,
        cond_aug_std=0.5,
        condition_on_elevation=False,
        use_precomputed_latents=False,
        **unused_kwargs,
    ):
        print("Using ALL objects in Objaverse")
        self.root_dir = Path(root_dir)
        self.split = split
        self.random_front = random_front
        self.transform = transform
        self.use_precomputed_latents = use_precomputed_latents
        self.latent_dir = Path(_DEFAULT_LATENT_DIR)
        self.n_views = 18
        self.ids = _keep_existing(
            self.root_dir, _load_json("./assets/all_ids.json"), require_dir=True
        )
        _report_valid_ids(self.ids)
        self.cond_aug_mean = cond_aug_mean
        self.cond_aug_std = cond_aug_std
        self.condition_on_elevation = condition_on_elevation
        if max_item is not None:
            self.ids = self.ids[:max_item]
            ## debug: repeat the small subset so an epoch is long enough to train on
            self.ids = self.ids * 10000


class ObjaverseWithPose(Dataset):
    """All Objaverse objects in the 18-view ``elevations_0`` layout.

    ``max_item`` is accepted for interface compatibility but ignored.
    """

    def __init__(
        self,
        root_dir,
        split="train",
        transform=None,
        random_front=False,
        max_item=None,
        cond_aug_mean=-3.0,
        cond_aug_std=0.5,
        condition_on_elevation=False,
        use_precomputed_latents=False,
        **unused_kwargs,
    ):
        print("Using Objaverse with poses")
        self.root_dir = Path(root_dir)
        self.split = split
        self.random_front = random_front
        self.transform = transform
        self.use_precomputed_latents = use_precomputed_latents
        self.latent_dir = Path(_DEFAULT_LATENT_DIR)
        self.n_views = 18
        self.ids = _keep_existing(
            self.root_dir, _load_json("./assets/all_ids.json"), require_dir=True
        )
        _report_valid_ids(self.ids)
        self.cond_aug_mean = cond_aug_mean
        self.cond_aug_std = cond_aug_std
        self.condition_on_elevation = condition_on_elevation

    def __getitem__(self, idx: int):
        return _elevation0_spiral_item(self, idx)

    def __len__(self):
        # Required by DataLoader's default samplers; previously missing.
        return len(self.ids)


class LatentObjaverse(Dataset):
    """Precomputed 18-view latents plus per-view CLIP embeddings.

    Latents are read from ``root_dir / f"{uid}.pt"`` and CLIP embeddings from
    ``root_dir / ".." / "clip_emb512" / f"{uid}.pt"``. Only ids that have
    both files are kept.
    """

    def __init__(
        self,
        root_dir,
        split="train",
        random_front=False,
        subset="lvis",
        fps_id=1.0,
        motion_bucket_id=300.0,
        cond_aug_mean=-3.0,
        cond_aug_std=0.5,
        **unused_kwargs,
    ):
        self.root_dir = Path(root_dir)
        self.split = split
        self.random_front = random_front
        ids = _load_json(Path("./assets") / f"{subset}_ids.json")
        self.clip_emb_dir = self.root_dir / ".." / "clip_emb512"
        self.n_views = 18
        self.fps_id = fps_id
        self.motion_bucket_id = motion_bucket_id
        self.cond_aug_mean = cond_aug_mean
        self.cond_aug_std = cond_aug_std
        if self.random_front:
            print("Using a random view as front view")
        self.ids = [
            uid
            for uid in ids
            if (self.root_dir / f"{uid}.pt").exists()
            and (self.clip_emb_dir / f"{uid}.pt").exists()
        ]
        _report_valid_ids(self.ids)

    def __getitem__(self, idx: int):
        uid = self.ids[idx]
        idx_list = torch.arange(self.n_views)
        latents = torch.load(self.root_dir / f"{uid}.pt")
        clip_emb = torch.load(self.clip_emb_dir / f"{uid}.pt")
        if self.random_front:
            idx_list = torch.roll(idx_list, np.random.randint(self.n_views))
        latents = latents[idx_list]
        # Only the front view's CLIP embedding is used for conditioning.
        clip_emb = clip_emb[idx_list][0]
        cond_aug = _sample_cond_aug(self.cond_aug_mean, self.cond_aug_std)
        cond = latents[0]
        data = {"latents": latents, "cond_frames_without_noise": clip_emb}
        data.update(
            _video_conditioning(
                cond,
                cond_aug,
                self.n_views,
                self.fps_id,
                self.motion_bucket_id,
                self.n_views,
            )
        )
        return data

    def __len__(self):
        return len(self.ids)


class ObjaverseSpiralDataset(LightningDataModule):
    """LightningDataModule wrapping one of the Objaverse datasets above.

    ``dataset_cls`` selects the dataset class. Images are composited onto
    white, resized to ``reso`` x ``reso``, converted to tensors and normalised
    from [0, 1] to [-1, 1]. Extra ``kwargs`` are forwarded to the dataset.
    """

    def __init__(
        self,
        root_dir,
        random_front=False,
        batch_size=2,
        num_workers=10,
        prefetch_factor=2,
        shuffle=True,
        max_item=None,
        dataset_cls="richdreamer",
        reso: int = 256,
        **kwargs,
    ) -> None:
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.shuffle = shuffle
        self.max_item = max_item

        self.transform = Compose(
            [
                blend_white_bg,
                Resize((reso, reso)),
                ToTensor(),
                Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

        data_cls = {
            "richdreamer": ObjaverseSpiral,
            "lvis": ObjaverseLVISSpiral,
            "shengshu_all": ObjaverseALLSpiral,
            "latent": LatentObjaverse,
            "gobjaverse": GObjaverse,
        }[dataset_cls]

        self.train_dataset = data_cls(
            root_dir=root_dir,
            split="train",
            random_front=random_front,
            transform=self.transform,
            max_item=self.max_item,
            **kwargs,
        )
        self.test_dataset = data_cls(
            root_dir=root_dir,
            split="val",
            random_front=random_front,
            transform=self.transform,
            max_item=self.max_item,
            **kwargs,
        )

    def _make_dataloader(self, dataset):
        """Build a DataLoader for ``dataset``, using its own ``collate_fn`` if it has one."""
        loader_kwargs = {}
        # torch >= 2.0 rejects prefetch_factor when num_workers == 0, so only
        # pass it when worker processes are actually used.
        if self.num_workers > 0:
            loader_kwargs["prefetch_factor"] = self.prefetch_factor
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            collate_fn=getattr(dataset, "collate_fn", video_collate_fn),
            **loader_kwargs,
        )

    def train_dataloader(self):
        return self._make_dataloader(self.train_dataset)

    def test_dataloader(self):
        return self._make_dataloader(self.test_dataset)

    def val_dataloader(self):
        return self._make_dataloader(self.test_dataset)