import os
from tqdm import tqdm
import kiui
from kiui.op import recenter
import kornia
import collections
import math
import time
import itertools
import pickle
from typing import Any
import lmdb
import cv2
import trimesh

cv2.setNumThreads(0)  # disable multiprocess

# import imageio
import imageio.v3 as imageio
import numpy as np
from PIL import Image
import Imath
import OpenEXR
from pdb import set_trace as st
from pathlib import Path

import torchvision
from torchvision.transforms import v2
from einops import rearrange, repeat
from functools import partial
import io
from scipy.stats import special_ortho_group
import gzip
import random
import torch
import torch as th
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.utils.data.distributed import DistributedSampler
from pathlib import Path
import lz4.frame

from nsr.volumetric_rendering.ray_sampler import RaySampler
import point_cloud_utils as pcu
import torch.multiprocessing
# torch.multiprocessing.set_sharing_strategy('file_system')

from utils.general_utils import PILtoTorch, matrix_to_quaternion
from guided_diffusion import logger
import json
import webdataset as wds
from webdataset.shardlists import expand_source
# st()

from .shapenet import LMDBDataset, LMDBDataset_MV_Compressed, decompress_and_open_image_gzip, decompress_array
from kiui.op import safe_normalize
from utils.gs_utils.graphics_utils import getWorld2View2, getProjectionMatrix, getView2World
from nsr.camera_utils import generate_input_camera


def random_rotation_matrix():
    # Generate a random rotation matrix in 3D
    random_rotation_3d = special_ortho_group.rvs(3)

    # Embed the 3x3 rotation matrix into a 4x4 matrix
    rotation_matrix_4x4 = np.eye(4)
    rotation_matrix_4x4[:3, :3] = random_rotation_3d

    return rotation_matrix_4x4


def fov2focal(fov, pixels):
    return pixels / (2 * math.tan(fov / 2))


def focal2fov(focal, pixels):
    return 2 * math.atan(pixels / (2 * focal))


def resize_depth_mask(depth_to_resize, resolution):
    depth_resized = cv2.resize(depth_to_resize, (resolution, resolution),
                               interpolation=cv2.INTER_LANCZOS4)
    # interpolation=cv2.INTER_AREA)
    return depth_resized, depth_resized > 0  # type: ignore


def resize_depth_mask_Tensor(depth_to_resize, resolution):

    if depth_to_resize.shape[-1] != resolution:
        depth_resized = torch.nn.functional.interpolate(
            input=depth_to_resize.unsqueeze(1),
            size=(resolution, resolution),
            # mode='bilinear',
            mode='nearest',
            # align_corners=False,
        ).squeeze(1)
    else:
        depth_resized = depth_to_resize

    return depth_resized.float(), depth_resized > 0  # type: ignore


class PostProcess:

    def __init__(
        self,
        reso,
        reso_encoder,
        imgnet_normalize,
        plucker_embedding,
        decode_encode_img_only,
        mv_input,
        split_chunk_input,
        duplicate_sample,
        append_depth,
        gs_cam_format,
        orthog_duplicate,
        frame_0_as_canonical,
        pcd_path=None,
        load_pcd=False,
        split_chunk_size=8,
        append_xyz=False,
    ) -> None:

        self.load_pcd = load_pcd

        if pcd_path is None:  # hard-coded
            pcd_path = '/cpfs01/user/lanyushi.p/data/FPS_PCD/pcd-V=6_256_again/fps-pcd/'

        self.pcd_path = Path(pcd_path)

        self.append_xyz = append_xyz
        if append_xyz:
            assert append_depth is False
        self.frame_0_as_canonical = frame_0_as_canonical
        self.gs_cam_format = gs_cam_format
        self.append_depth = append_depth
        self.plucker_embedding = plucker_embedding
        self.decode_encode_img_only = decode_encode_img_only
        self.duplicate_sample = duplicate_sample
        self.orthog_duplicate = orthog_duplicate

        self.zfar = 100.0
        self.znear = 0.01

        transformations = []
        if not split_chunk_input:
            transformations.append(transforms.ToTensor())

        if imgnet_normalize:
            transformations.append(
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))  # type: ignore
            )
        else:
            transformations.append(
                transforms.Normalize((0.5, 0.5, 0.5),
                                     (0.5, 0.5, 0.5)))  # type: ignore

        self.normalize = transforms.Compose(transformations)

        self.reso_encoder = reso_encoder
        self.reso = reso
        self.instance_data_length = 40
        # self.pair_per_instance = 1 # compat
        self.mv_input = mv_input
        self.split_chunk_input = split_chunk_input  # 8
        self.chunk_size = split_chunk_size if split_chunk_input else 40
        # assert self.chunk_size in [8, 10]
        self.V = self.chunk_size // 2  # 4 views as input
        # else:
        #     assert self.chunk_size == 20
        #     self.V = 12  # 6 + 6 here
        # st()
        assert split_chunk_input
        self.pair_per_instance = 1
        # else:
        #     self.pair_per_instance = 4 if mv_input else 2  # check whether improves IO

        self.ray_sampler = RaySampler()  # load xyz

    def gen_rays(self, c):
        # Generate rays
        intrinsics, c2w = c[16:], c[:16].reshape(4, 4)
        self.h = self.reso_encoder
        self.w = self.reso_encoder
        yy, xx = torch.meshgrid(
            torch.arange(self.h, dtype=torch.float32) + 0.5,
            torch.arange(self.w, dtype=torch.float32) + 0.5,
            indexing='ij')

        # normalize to 0-1 pixel range
        yy = yy / self.h
        xx = xx / self.w

        # K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1]).reshape(3, 3)
        cx, cy, fx, fy = intrinsics[2], intrinsics[5], intrinsics[0], intrinsics[4]
        # cx *= self.w
        # cy *= self.h

        # f_x = f_y = fx * h / res_raw
        c2w = torch.from_numpy(c2w).float()

        xx = (xx - cx) / fx
        yy = (yy - cy) / fy
        zz = torch.ones_like(xx)
        dirs = torch.stack((xx, yy, zz), dim=-1)  # OpenCV convention
        dirs /= torch.norm(dirs, dim=-1, keepdim=True)
        dirs = dirs.reshape(-1, 3, 1)
        del xx, yy, zz
        # st()
        dirs = (c2w[None, :3, :3] @ dirs)[..., 0]

        origins = c2w[None, :3, 3].expand(self.h * self.w, -1).contiguous()
        origins = origins.view(self.h, self.w, 3)
        dirs = dirs.view(self.h, self.w, 3)

        return origins, dirs

    def _post_process_batch_sample(self, sample):  # sample is an instance batch here
        caption, ins = sample[-2:]

        instance_samples = []

        for instance_idx in range(sample[0].shape[0]):
            instance_samples.append(
                self._post_process_sample(item[instance_idx]
                                          for item in sample[:-2]))

        return (*instance_samples, caption, ins)

    def _post_process_sample(self, data_sample):
        # raw_img, depth, c, bbox, caption, ins = data_sample
        # st()
        raw_img, depth, c, bbox = data_sample

        bbox = (bbox * (self.reso / 256)).astype(
            np.uint8)  # normalize bbox to the reso range

        if raw_img.shape[-2] != self.reso_encoder:
            img_to_encoder = cv2.resize(
                raw_img, (self.reso_encoder, self.reso_encoder),
                interpolation=cv2.INTER_LANCZOS4)
        else:
            img_to_encoder = raw_img

        img_to_encoder = self.normalize(img_to_encoder)

        if self.plucker_embedding:
            rays_o, rays_d = self.gen_rays(c)
            rays_plucker = torch.cat(
                [torch.cross(rays_o, rays_d, dim=-1), rays_d],
                dim=-1).permute(2, 0, 1)  # [h, w, 6] -> 6,h,w
            img_to_encoder = torch.cat([img_to_encoder, rays_plucker], 0)

        img = cv2.resize(raw_img, (self.reso, self.reso),
                         interpolation=cv2.INTER_LANCZOS4)

        img = torch.from_numpy(img).permute(2, 0, 1) / 127.5 - 1

        if self.decode_encode_img_only:
            depth_reso, fg_mask_reso = depth, depth
        else:
            depth_reso, fg_mask_reso = resize_depth_mask(depth, self.reso)

        # return {
        #     # **sample,
        #     'img_to_encoder': img_to_encoder,
        #     'img': img,
        #     'depth_mask': fg_mask_reso,
        #     # 'img_sr': img_sr,
        #     'depth': depth_reso,
        #     'c': c,
        #     'bbox': bbox,
        #     'caption': caption,
        #     'ins': ins
        #     # !
no need to load img_sr for now # } # if len(data_sample) == 4: return (img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox) # else: # return (img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox, data_sample[-2], data_sample[-1]) def canonicalize_pts(self, c, pcd, for_encoder=True, canonical_idx=0): # pcd: sampled in world space assert c.shape[0] == self.chunk_size assert for_encoder # st() B = c.shape[0] camera_poses = c[:, :16].reshape(B, 4, 4) # 3x4 cam_radius = np.linalg.norm( c[[0, self.V]][:, :16].reshape(2, 4, 4)[:, :3, 3], axis=-1, keepdims=False) # since g-buffer adopts dynamic radius here. frame1_fixed_pos = np.repeat(np.eye(4)[None], 2, axis=0) frame1_fixed_pos[:, 2, -1] = -cam_radius transform = frame1_fixed_pos @ np.linalg.inv(camera_poses[[0, self.V ]]) # B 4 4 transform = np.expand_dims(transform, axis=1) # B 1 4 4 # from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127 # transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]]) repeated_homo_pcd = np.repeat(np.concatenate( [pcd, np.ones_like(pcd[..., 0:1])], -1)[None], 2, axis=0)[..., None] # B N 4 1 new_pcd = (transform @ repeated_homo_pcd)[..., :3, 0] # 2 N 3 return new_pcd def canonicalize_pts_v6(self, c, pcd, for_encoder=True, canonical_idx=0): exit() # deprecated function # pcd: sampled in world space assert c.shape[0] == self.chunk_size assert for_encoder encoder_canonical_idx = [0, 6, 12, 18] B = c.shape[0] camera_poses = c[:, :16].reshape(B, 4, 4) # 3x4 cam_radius = np.linalg.norm( c[encoder_canonical_idx][:, :16].reshape(4, 4, 4)[:, :3, 3], axis=-1, keepdims=False) # since g-buffer adopts dynamic radius here. frame1_fixed_pos = np.repeat(np.eye(4)[None], 4, axis=0) frame1_fixed_pos[:, 2, -1] = -cam_radius transform = frame1_fixed_pos @ np.linalg.inv( camera_poses[encoder_canonical_idx]) # B 4 4 transform = np.expand_dims(transform, axis=1) # B 1 4 4 # from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127 # transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]]) repeated_homo_pcd = np.repeat(np.concatenate( [pcd, np.ones_like(pcd[..., 0:1])], -1)[None], 4, axis=0)[..., None] # B N 4 1 new_pcd = (transform @ repeated_homo_pcd)[..., :3, 0] # 2 N 3 return new_pcd def normalize_camera(self, c, for_encoder=True, canonical_idx=0): assert c.shape[0] == self.chunk_size # 8 o r10 B = c.shape[0] camera_poses = c[:, :16].reshape(B, 4, 4) # 3x4 if for_encoder: encoder_canonical_idx = [0, self.V] # st() cam_radius = np.linalg.norm( c[encoder_canonical_idx][:, :16].reshape(2, 4, 4)[:, :3, 3], axis=-1, keepdims=False) # since g-buffer adopts dynamic radius here. frame1_fixed_pos = np.repeat(np.eye(4)[None], 2, axis=0) frame1_fixed_pos[:, 2, -1] = -cam_radius transform = frame1_fixed_pos @ np.linalg.inv( camera_poses[encoder_canonical_idx]) # from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127 # transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]]) new_camera_poses = np.repeat( transform, self.V, axis=0 ) @ camera_poses # [V, 4, 4]. 
np.repeat() is th.repeat_interleave() else: cam_radius = np.linalg.norm( c[canonical_idx][:16].reshape(4, 4)[:3, 3], axis=-1, keepdims=False) # since g-buffer adopts dynamic radius here. frame1_fixed_pos = np.eye(4) frame1_fixed_pos[2, -1] = -cam_radius transform = frame1_fixed_pos @ np.linalg.inv( camera_poses[canonical_idx]) # 4,4 # from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127 # transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]]) new_camera_poses = np.repeat(transform[None], self.chunk_size, axis=0) @ camera_poses # [V, 4, 4] c = np.concatenate([new_camera_poses.reshape(B, 16), c[:, 16:]], axis=-1) return c def normalize_camera_v6(self, c, for_encoder=True, canonical_idx=0): B = c.shape[0] camera_poses = c[:, :16].reshape(B, 4, 4) # 3x4 if for_encoder: assert c.shape[0] == 24 encoder_canonical_idx = [0, 6, 12, 18] cam_radius = np.linalg.norm( c[encoder_canonical_idx][:, :16].reshape(4, 4, 4)[:, :3, 3], axis=-1, keepdims=False) # since g-buffer adopts dynamic radius here. frame1_fixed_pos = np.repeat(np.eye(4)[None], 4, axis=0) frame1_fixed_pos[:, 2, -1] = -cam_radius transform = frame1_fixed_pos @ np.linalg.inv( camera_poses[encoder_canonical_idx]) # from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127 # transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]]) new_camera_poses = np.repeat(transform, 6, axis=0) @ camera_poses # [V, 4, 4] else: assert c.shape[0] == 12 cam_radius = np.linalg.norm( c[canonical_idx][:16].reshape(4, 4)[:3, 3], axis=-1, keepdims=False) # since g-buffer adopts dynamic radius here. frame1_fixed_pos = np.eye(4) frame1_fixed_pos[2, -1] = -cam_radius transform = frame1_fixed_pos @ np.linalg.inv( camera_poses[canonical_idx]) # 4,4 # from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127 # transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]]) new_camera_poses = np.repeat(transform[None], 12, axis=0) @ camera_poses # [V, 4, 4] c = np.concatenate([new_camera_poses.reshape(B, 16), c[:, 16:]], axis=-1) return c def get_plucker_ray(self, c): rays_plucker = [] for idx in range(c.shape[0]): rays_o, rays_d = self.gen_rays(c[idx]) rays_plucker.append( torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1).permute(2, 0, 1)) # [h, w, 6] -> 6,h,w rays_plucker = torch.stack(rays_plucker, 0) return rays_plucker def _unproj_depth_given_c(self, c, depth): # get xyz hxw for each pixel, like MCC # img_size = self.reso img_size = depth.shape[-1] B = c.shape[0] cam2world_matrix = c[:, :16].reshape(B, 4, 4) intrinsics = c[:, 16:25].reshape(B, 3, 3) ray_origins, ray_directions = self.ray_sampler( # shape: cam2world_matrix, intrinsics, img_size)[:2] depth = depth.reshape(B, -1).unsqueeze(-1) xyz = ray_origins + depth * ray_directions # BV HW 3, already in the world space xyz = xyz.reshape(B, img_size, img_size, 3).permute(0, 3, 1, 2) # B 3 H W xyz = xyz.clip( -0.45, 0.45) # g-buffer saves depth with anti-alias = True ..... xyz = torch.where(xyz.abs() == 0.45, 0, xyz) # no boundary here? Yes. 
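        # xyz = ray_origin + depth * ray_direction per pixel, with rays generated by
        # self.ray_sampler from (cam2world, intrinsics); the result is already in
        # world space. Points are clipped to the [-0.45, 0.45] cube and values that
        # land exactly on the boundary are zeroed, which drops anti-aliased edge
        # pixels and any background points pushed outside the object box.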
return xyz def _post_process_sample_batch(self, data_sample): # raw_img, depth, c, bbox, caption, ins = data_sample alpha = None if len(data_sample) == 4: raw_img, depth, c, bbox = data_sample else: raw_img, depth, c, alpha, bbox = data_sample # put c to position 2 if isinstance(depth, tuple): self.append_normal = True depth, normal = depth else: self.append_normal = False normal = None # if raw_img.shape[-1] == 4: # depth_reso, _ = resize_depth_mask_Tensor( # torch.from_numpy(depth), self.reso) # raw_img, fg_mask_reso = raw_img[..., :3], raw_img[..., -1] # # st() # ! check has 1 dim in alpha? # else: if not isinstance(depth, torch.Tensor): depth = torch.from_numpy(depth).float() else: depth = depth.float() depth_reso, fg_mask_reso = resize_depth_mask_Tensor(depth, self.reso) if alpha is None: alpha = fg_mask_reso else: # ! resize first # st() alpha = torch.from_numpy(alpha / 255.0).float() if alpha.shape[-1] != self.reso: # bilinear inteprolate reshape alpha = torch.nn.functional.interpolate( input=alpha.unsqueeze(1), size=(self.reso, self.reso), mode='bilinear', align_corners=False, ).squeeze(1) if self.reso < 256: bbox = (bbox * (self.reso / 256)).astype( np.uint8) # normalize bbox to the reso range else: # 3dgs bbox = bbox.astype(np.uint8) # st() # ! shall compat with 320 input # assert raw_img.shape[-2] == self.reso_encoder # img_to_encoder = cv2.resize( # raw_img, (self.reso_encoder, self.reso_encoder), # interpolation=cv2.INTER_LANCZOS4) # else: # img_to_encoder = raw_img raw_img = torch.from_numpy(raw_img).permute(0, 3, 1, 2) / 255.0 # [0,1] if normal is not None: normal = torch.from_numpy(normal).permute(0,3,1,2) # if raw_img.shape[-1] != self.reso: if raw_img.shape[1] != self.reso_encoder: img_to_encoder = torch.nn.functional.interpolate( input=raw_img, size=(self.reso_encoder, self.reso_encoder), mode='bilinear', align_corners=False,) img_to_encoder = self.normalize(img_to_encoder) if normal is not None: normal_for_encoder = torch.nn.functional.interpolate( input=normal, size=(self.reso_encoder, self.reso_encoder), # mode='bilinear', mode='nearest', # align_corners=False, ) else: img_to_encoder = self.normalize(raw_img) normal_for_encoder = normal if raw_img.shape[-1] != self.reso: img = torch.nn.functional.interpolate( input=raw_img, size=(self.reso, self.reso), mode='bilinear', align_corners=False, ) # [-1,1] range img = img * 2 - 1 # as gt if normal is not None: normal = torch.nn.functional.interpolate( input=normal, size=(self.reso, self.reso), # mode='bilinear', mode='nearest', # align_corners=False, ) else: img = raw_img * 2 - 1 # fg_mask_reso = depth[..., -1:] # ! use pad_v6_fn = lambda x: torch.concat([x, x[:4]], 0) if isinstance( x, torch.Tensor) else np.concatenate([x, x[:4]], 0) # ! processing encoder input image. # ! 
normalize camera feats if self.frame_0_as_canonical: # 4 views as input per batch # if self.chunk_size in [8, 10]: if True: # encoder_canonical_idx = [0, 4] # encoder_canonical_idx = [0, self.chunk_size//2] encoder_canonical_idx = [0, self.V] c_for_encoder = self.normalize_camera(c, for_encoder=True) c_for_render = self.normalize_camera( c, for_encoder=False, canonical_idx=encoder_canonical_idx[0] ) # allocated to nv_c, frame0 (in 8 views) as the canonical c_for_render_nv = self.normalize_camera( c, for_encoder=False, canonical_idx=encoder_canonical_idx[1] ) # allocated to nv_c, frame0 (in 8 views) as the canonical c_for_render = np.concatenate([c_for_render, c_for_render_nv], axis=-1) # for compat # st() else: assert self.chunk_size == 20 c_for_encoder = self.normalize_camera_v6(c, for_encoder=True) # paired_c_0 = np.concatenate([c[0:6], c[12:18]]) paired_c_1 = np.concatenate([c[6:12], c[18:24]]) def process_paired_camera(paired_c): c_for_render = self.normalize_camera_v6( paired_c, for_encoder=False, canonical_idx=0 ) # allocated to nv_c, frame0 (in 8 views) as the canonical c_for_render_nv = self.normalize_camera_v6( paired_c, for_encoder=False, canonical_idx=6 ) # allocated to nv_c, frame0 (in 8 views) as the canonical c_for_render = np.concatenate( [c_for_render, c_for_render_nv], axis=-1) # for compat return c_for_render paired_c_for_render_0 = process_paired_camera(paired_c_0) paired_c_for_render_1 = process_paired_camera(paired_c_1) c_for_render = np.empty(shape=(24, 50)) c_for_render[list(range(6)) + list(range(12, 18))] = paired_c_for_render_0 c_for_render[list(range(6, 12)) + list(range(18, 24))] = paired_c_for_render_1 else: # use g-buffer canonical c c_for_encoder, c_for_render = c, c if self.append_normal and normal is not None: img_to_encoder = torch.cat([img_to_encoder, normal_for_encoder], # img_to_encoder = torch.cat([img_to_encoder, normal], 1) # concat in C dim if self.plucker_embedding: # rays_plucker = self.get_plucker_ray(c) rays_plucker = self.get_plucker_ray(c_for_encoder) img_to_encoder = torch.cat([img_to_encoder, rays_plucker], 1) # concat in C dim # torchvision.utils.save_image(raw_img, 'tmp/inp.png', normalize=True, value_range=(0,1), nrow=1, padding=0) # torchvision.utils.save_image(rays_plucker[:,:3], 'tmp/plucker.png', normalize=True, value_range=(-1,1), nrow=1, padding=0) # torchvision.utils.save_image(depth_reso.unsqueeze(1), 'tmp/depth.png', normalize=True, nrow=1, padding=0) c = torch.from_numpy(c_for_render).to(torch.float32) if self.append_depth: normalized_depth = torch.from_numpy(depth_reso).clone().unsqueeze( 1) # min=0 # normalized_depth -= torch.min(normalized_depth) # always 0 here # normalized_depth /= torch.max(normalized_depth) # normalized_depth = normalized_depth.unsqueeze(1) * 2 - 1 # normalize to [-1,1] # st() img_to_encoder = torch.cat([img_to_encoder, normalized_depth], 1) # concat in C dim elif self.append_xyz: depth_for_unproj = depth.clone() depth_for_unproj[depth_for_unproj == 0] = 1e10 # so that rays_o will not appear in the final pcd. 
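            # pushing zero (background) depth to 1e10 makes those pixels unproject far
            # outside the [-0.45, 0.45] clip cube used in _unproj_depth_given_c, so the
            # camera origin never shows up as a spurious point in the xyz map.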
xyz = self._unproj_depth_given_c(c.float(), depth) # pcu.save_mesh_v(f'unproj_xyz_before_Nearest.ply', xyz[0:9].float().detach().permute(0,2,3,1).reshape(-1,3).cpu().numpy(),) if xyz.shape[-1] != self.reso_encoder: xyz = torch.nn.functional.interpolate( input=xyz, # [-1,1] # size=(self.reso, self.reso), size=(self.reso_encoder, self.reso_encoder), mode='nearest', ) # pcu.save_mesh_v(f'unproj_xyz_afterNearest.ply', xyz[0:9].float().detach().permute(0,2,3,1).reshape(-1,3).cpu().numpy(),) # st() img_to_encoder = torch.cat([img_to_encoder, xyz], 1) return (img_to_encoder, img, alpha, depth_reso, c, torch.from_numpy(bbox)) def rand_sample_idx(self): return random.randint(0, self.instance_data_length - 1) def rand_pair(self): return (self.rand_sample_idx() for _ in range(2)) def paired_post_process(self, sample): # repeat n times? all_inp_list = [] all_nv_list = [] caption, ins = sample[-2:] # expanded_return = [] for _ in range(self.pair_per_instance): cano_idx, nv_idx = self.rand_pair() cano_sample = self._post_process_sample(item[cano_idx] for item in sample[:-2]) nv_sample = self._post_process_sample(item[nv_idx] for item in sample[:-2]) all_inp_list.extend(cano_sample) all_nv_list.extend(nv_sample) return (*all_inp_list, *all_nv_list, caption, ins) # return [cano_sample, nv_sample, caption, ins] # return (*cano_sample, *nv_sample, caption, ins) def get_source_cw2wT(self, source_cameras_view_to_world): return matrix_to_quaternion( source_cameras_view_to_world[:3, :3].transpose(0, 1)) def c_to_3dgs_format(self, pose): # TODO, switch to torch version (batched later) c2w = pose[:16].reshape(4, 4) # 3x4 # ! load cam w2c = np.linalg.inv(c2w) R = np.transpose( w2c[:3, :3]) # R is stored transposed due to 'glm' in CUDA code T = w2c[:3, 3] fx = pose[16] FovX = focal2fov(fx, 1) FovY = focal2fov(fx, 1) tanfovx = math.tan(FovX * 0.5) tanfovy = math.tan(FovY * 0.5) assert tanfovx == tanfovy trans = np.array([0.0, 0.0, 0.0]) scale = 1.0 view_world_transform = torch.tensor(getView2World(R, T, trans, scale)).transpose( 0, 1) world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose( 0, 1) projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=FovX, fovY=FovY).transpose(0, 1) full_proj_transform = (world_view_transform.unsqueeze(0).bmm( projection_matrix.unsqueeze(0))).squeeze(0) camera_center = world_view_transform.inverse()[3, :3] # ! check pytorch3d camera system alignment. # item.update(viewpoint_cam=[viewpoint_cam]) c = {} # c["source_cv2wT_quat"] = self.get_source_cw2wT(view_world_transform) c.update( # projection_matrix=projection_matrix, # K R=torch.from_numpy(R), T=torch.from_numpy(T), cam_view=world_view_transform, # world_view_transform cam_view_proj=full_proj_transform, # full_proj_transform cam_pos=camera_center, tanfov=tanfovx, # TODO, fix in the renderer orig_pose=torch.from_numpy(pose), orig_c2w=torch.from_numpy(c2w), orig_w2c=torch.from_numpy(w2c), orig_intrin=torch.from_numpy(pose[16:]).reshape(3,3), # tanfovy=tanfovy, ) return c # dict for gs rendering def paired_post_process_chunk(self, sample): # st() # sample_npz, ins, caption = sample_pyd # three items # sample = *(sample[0][k] for k in ['raw_img', 'depth', 'c', 'bbox']), sample[-1], sample[-2] # repeat n times? 
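        # Each incoming `sample` is one chunk: per-view arrays of length self.chunk_size
        # (raw_img, depth, c, [alpha,] bbox) followed by (caption, ins). With
        # duplicate_sample the views are randomly permuted, processed as a batch, and
        # split so the first self.V views become encoder inputs and the remaining views
        # the novel-view targets; the optional FPS point cloud is appended at the end,
        # roughly: (*inp_views, *nv_views, caption, ins[, fps_pcd])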
all_inp_list = [] all_nv_list = [] auxiliary_sample = list(sample[-2:]) # caption, ins = sample[-2:] ins = sample[-1] assert sample[0].shape[0] == self.chunk_size # random chunks # expanded_return = [] if self.load_pcd: # fps_pcd = pcu.load_mesh_v( # # str(self.pcd_path / ins / 'fps-24576.ply')) # N, 3 # str(self.pcd_path / ins / 'fps-4096.ply')) # N, 3 # # 'fps-4096.ply')) # N, 3 fps_pcd = trimesh.load(str(self.pcd_path / ins / 'fps-4096.ply')).vertices auxiliary_sample += [fps_pcd] assert self.duplicate_sample # st() if self.duplicate_sample: # ! shuffle before process, since frame_0_as_canonical fixed c. if self.chunk_size in [20, 18, 16, 12]: shuffle_sample = sample[:-2] # no order shuffle required else: shuffle_sample = [] # indices = torch.randperm(self.chunk_size) indices = np.random.permutation(self.chunk_size) for _, item in enumerate(sample[:-2]): shuffle_sample.append(item[indices]) # random shuffle processed_sample = self._post_process_sample_batch(shuffle_sample) # ! process pcd if frmae_0 alignment if self.load_pcd: if self.frame_0_as_canonical: # ! normalize camera feats # normalized camera feats as in paper (transform the first pose to a fixed position) # if self.chunk_size == 20: # auxiliary_sample[-1] = self.canonicalize_pts_v6( # c=shuffle_sample[2], # pcd=auxiliary_sample[-1], # for_encoder=True) # B N 3 # else: auxiliary_sample[-1] = self.canonicalize_pts( c=shuffle_sample[2], pcd=auxiliary_sample[-1], for_encoder=True) # B N 3 else: auxiliary_sample[-1] = np.repeat( auxiliary_sample[-1][None], 2, axis=0) # share the same camera syste, just repeat assert not self.orthog_duplicate # if self.chunk_size == 8: all_inp_list.extend(item[:self.V] for item in processed_sample) all_nv_list.extend(item[self.V:] for item in processed_sample) # elif self.chunk_size == 20: # V=6 # # indices_v6 = [np.random.permutation(self.chunk_size)[:12] for _ in range(2)] # random sample 6 views from chunks # all_inp_list.extend(item[:12] for item in processed_sample) # # indices_v6 = np.concatenate([np.arange(12, 20), np.arange(0,4)]) # all_nv_list.extend( # item[12:] for item in # processed_sample) # already repeated inside batch fn # else: # raise NotImplementedError(self.chunk_size) # else: # all_inp_list.extend(item[:8] for item in processed_sample) # all_nv_list.extend(item[8:] for item in processed_sample) # st() return (*all_inp_list, *all_nv_list, *auxiliary_sample) else: processed_sample = self._post_process_sample_batch( # avoid shuffle shorten processing time item[:4] for item in sample[:-2]) all_inp_list.extend(item for item in processed_sample) all_nv_list.extend(item for item in processed_sample) # ! placeholder # return (*all_inp_list, *all_nv_list, caption, ins) return (*all_inp_list, *all_nv_list, *auxiliary_sample) # randomly shuffle 8 views, avoid overfitting def single_sample_create_dict_noBatch(self, sample, prefix=''): # if len(sample) == 1: # sample = sample[0] # assert len(sample) == 6 img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample if self.gs_cam_format: # TODO, can optimize later after model converges B, V, _ = c.shape # B 4 25 c = rearrange(c, 'B V C -> (B V) C').cpu().numpy() # c = c.cpu().numpy() all_gs_c = [self.c_to_3dgs_format(pose) for pose in c] # st() # all_gs_c = self.c_to_3dgs_format(c.cpu().numpy()) c = { k: rearrange(torch.stack([gs_c[k] for gs_c in all_gs_c]), '(B V) ... 
-> B V ...', B=B, V=V) # torch.stack([gs_c[k] for gs_c in all_gs_c]) if isinstance(all_gs_c[0][k], torch.Tensor) else all_gs_c[0][k] for k in all_gs_c[0].keys() } # c = collate_gs_c return { # **sample, f'{prefix}img_to_encoder': img_to_encoder, f'{prefix}img': img, f'{prefix}depth_mask': fg_mask_reso, f'{prefix}depth': depth_reso, f'{prefix}c': c, f'{prefix}bbox': bbox, } def single_sample_create_dict(self, sample, prefix=''): # if len(sample) == 1: # sample = sample[0] # assert len(sample) == 6 img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample if self.gs_cam_format: # TODO, can optimize later after model converges B, V, _ = c.shape # B 4 25 c = rearrange(c, 'B V C -> (B V) C').cpu().numpy() all_gs_c = [self.c_to_3dgs_format(pose) for pose in c] c = { k: rearrange(torch.stack([gs_c[k] for gs_c in all_gs_c]), '(B V) ... -> B V ...', B=B, V=V) if isinstance(all_gs_c[0][k], torch.Tensor) else all_gs_c[0][k] for k in all_gs_c[0].keys() } # c = collate_gs_c return { # **sample, f'{prefix}img_to_encoder': img_to_encoder, f'{prefix}img': img, f'{prefix}depth_mask': fg_mask_reso, f'{prefix}depth': depth_reso, f'{prefix}c': c, f'{prefix}bbox': bbox, } def single_instance_sample_create_dict(self, sample, prfix=''): assert len(sample) == 42 inp_sample_list = [[] for _ in range(6)] for item in sample[:40]: for item_idx in range(6): inp_sample_list[item_idx].append(item[0][item_idx]) inp_sample = self.single_sample_create_dict( (torch.stack(item_list) for item_list in inp_sample_list), prefix='') return { **inp_sample, # 'caption': sample[-2], 'ins': sample[-1] } def decode_gzip(self, sample_pyd, shape=(256, 256)): # sample_npz, ins, caption = sample_pyd # three items # c, bbox, depth, ins, caption, raw_img = sample_pyd[:5], sample_pyd[5:] # wds.to_tuple('raw_img.jpeg', 'depth.jpeg', # 'd_near.npy', # 'd_far.npy', # "c.npy", 'bbox.npy', 'ins.txt', 'caption.txt'), # raw_img, depth, alpha_mask, d_near, d_far, c, bbox, ins, caption = sample_pyd raw_img, depth_alpha, = sample_pyd # return raw_img, depth_alpha # raw_img, caption = sample_pyd # return raw_img, caption # st() raw_img = rearrange(raw_img, 'h (b w) c -> b h w c', b=self.chunk_size) depth = rearrange(depth, 'h (b w) c -> b h w c', b=self.chunk_size) alpha_mask = rearrange( alpha_mask, 'h (b w) c -> b h w c', b=self.chunk_size) / 255.0 d_far = d_far.reshape(self.chunk_size, 1, 1, 1) d_near = d_near.reshape(self.chunk_size, 1, 1, 1) # d = 1 / ( (d_normalized / 255) * (far-near) + near) depth = 1 / ((depth / 255) * (d_far - d_near) + d_near) depth = depth[..., 0] # decoded from jpeg # depth = decompress_array(depth['depth'], (self.chunk_size, *shape), # np.float32, # decompress=True, # decompress_fn=lz4.frame.decompress) # return raw_img, depth, d_near, d_far, c, bbox, caption, ins raw_img = np.concatenate([raw_img, alpha_mask[..., 0:1]], -1) return raw_img, depth, c, bbox, caption, ins def decode_zip( self, sample_pyd, ): shape = (self.reso_encoder, self.reso_encoder) if isinstance(sample_pyd, tuple): sample_pyd = sample_pyd[0] assert isinstance(sample_pyd, dict) raw_img = decompress_and_open_image_gzip( sample_pyd['raw_img'], is_img=True, decompress=True, decompress_fn=lz4.frame.decompress) caption = sample_pyd['caption'].decode('utf-8') ins = sample_pyd['ins'].decode('utf-8') c = decompress_array(sample_pyd['c'], ( self.chunk_size, 25, ), np.float32, decompress=True, decompress_fn=lz4.frame.decompress) bbox = decompress_array( sample_pyd['bbox'], ( self.chunk_size, 4, ), np.float32, # decompress=False) decompress=True, 
decompress_fn=lz4.frame.decompress) if self.decode_encode_img_only: depth = np.zeros(shape=(self.chunk_size, *shape)) # save loading time else: depth = decompress_array(sample_pyd['depth'], (self.chunk_size, *shape), np.float32, decompress=True, decompress_fn=lz4.frame.decompress) # return {'raw_img': raw_img, 'depth': depth, 'bbox': bbox, 'caption': caption, 'ins': ins, 'c': c} # return raw_img, depth, c, bbox, caption, ins # return raw_img, bbox, caption, ins # return bbox, caption, ins return raw_img, depth, c, bbox, caption, ins # ! run single-instance pipeline first # return raw_img[0], depth[0], c[0], bbox[0], caption, ins def create_dict_nobatch(self, sample): # sample = [item[0] for item in sample] # wds wrap items in [] sample_length = 6 # if self.load_pcd: # sample_length += 1 cano_sample_list = [[] for _ in range(sample_length)] nv_sample_list = [[] for _ in range(sample_length)] # st() # bs = (len(sample)-2) // 6 for idx in range(0, self.pair_per_instance): cano_sample = sample[sample_length * idx:sample_length * (idx + 1)] nv_sample = sample[sample_length * self.pair_per_instance + sample_length * idx:sample_length * self.pair_per_instance + sample_length * (idx + 1)] for item_idx in range(sample_length): if self.frame_0_as_canonical: # ! cycle input/output view for more pairs if item_idx == 4: cano_sample_list[item_idx].append( cano_sample[item_idx][..., :25]) nv_sample_list[item_idx].append( nv_sample[item_idx][..., :25]) cano_sample_list[item_idx].append( nv_sample[item_idx][..., 25:]) nv_sample_list[item_idx].append( cano_sample[item_idx][..., 25:]) else: cano_sample_list[item_idx].append( cano_sample[item_idx]) nv_sample_list[item_idx].append(nv_sample[item_idx]) cano_sample_list[item_idx].append(nv_sample[item_idx]) nv_sample_list[item_idx].append(cano_sample[item_idx]) else: cano_sample_list[item_idx].append(cano_sample[item_idx]) nv_sample_list[item_idx].append(nv_sample[item_idx]) cano_sample_list[item_idx].append(nv_sample[item_idx]) nv_sample_list[item_idx].append(cano_sample[item_idx]) cano_sample = self.single_sample_create_dict_noBatch( (torch.stack(item_list, 0) for item_list in cano_sample_list), prefix='' ) # torch.Size([5, 10, 256, 256]). Since no batch dim here for now. nv_sample = self.single_sample_create_dict_noBatch( (torch.stack(item_list, 0) for item_list in nv_sample_list), prefix='nv_') ret_dict = { **cano_sample, **nv_sample, } if not self.load_pcd: ret_dict.update({'caption': sample[-2], 'ins': sample[-1]}) else: # if self.frame_0_as_canonical: # # fps_pcd = rearrange( sample[-1], 'B V ... -> (B V) ...') # ! wrong order. # # if self.chunk_size == 8: # fps_pcd = rearrange( # sample[-1], 'B V ... -> (V B) ...') # mimic torch.repeat # # else: # # fps_pcd = rearrange( sample[-1], 'B V ... -> (B V) ...') # ugly code to match the input format... # else: # fps_pcd = sample[-1].repeat( # 2, 1, # 1) # mimic torch.cat(), from torch.Size([3, 4096, 3]) # ! 
TODO, check fps_pcd order ret_dict.update({ 'caption': sample[-3], 'ins': sample[-2], 'fps_pcd': sample[-1] }) return ret_dict def create_dict(self, sample): # sample = [item[0] for item in sample] # wds wrap items in [] # st() sample_length = 6 # if self.load_pcd: # sample_length += 1 cano_sample_list = [[] for _ in range(sample_length)] nv_sample_list = [[] for _ in range(sample_length)] # st() # bs = (len(sample)-2) // 6 for idx in range(0, self.pair_per_instance): cano_sample = sample[sample_length * idx:sample_length * (idx + 1)] nv_sample = sample[sample_length * self.pair_per_instance + sample_length * idx:sample_length * self.pair_per_instance + sample_length * (idx + 1)] for item_idx in range(sample_length): if self.frame_0_as_canonical: # ! cycle input/output view for more pairs if item_idx == 4: cano_sample_list[item_idx].append( cano_sample[item_idx][..., :25]) nv_sample_list[item_idx].append( nv_sample[item_idx][..., :25]) cano_sample_list[item_idx].append( nv_sample[item_idx][..., 25:]) nv_sample_list[item_idx].append( cano_sample[item_idx][..., 25:]) else: cano_sample_list[item_idx].append( cano_sample[item_idx]) nv_sample_list[item_idx].append(nv_sample[item_idx]) cano_sample_list[item_idx].append(nv_sample[item_idx]) nv_sample_list[item_idx].append(cano_sample[item_idx]) else: cano_sample_list[item_idx].append(cano_sample[item_idx]) nv_sample_list[item_idx].append(nv_sample[item_idx]) cano_sample_list[item_idx].append(nv_sample[item_idx]) nv_sample_list[item_idx].append(cano_sample[item_idx]) # if self.split_chunk_input: # cano_sample = self.single_sample_create_dict( # (torch.cat(item_list, 0) for item_list in cano_sample_list), # prefix='') # nv_sample = self.single_sample_create_dict( # (torch.cat(item_list, 0) for item_list in nv_sample_list), # prefix='nv_') # else: # st() cano_sample = self.single_sample_create_dict( (torch.cat(item_list, 0) for item_list in cano_sample_list), prefix='') # torch.Size([4, 4, 10, 256, 256]) nv_sample = self.single_sample_create_dict( (torch.cat(item_list, 0) for item_list in nv_sample_list), prefix='nv_') ret_dict = { **cano_sample, **nv_sample, } if not self.load_pcd: ret_dict.update({'caption': sample[-2], 'ins': sample[-1]}) else: if self.frame_0_as_canonical: # fps_pcd = rearrange( sample[-1], 'B V ... -> (B V) ...') # ! wrong order. # if self.chunk_size == 8: fps_pcd = rearrange( sample[-1], 'B V ... -> (V B) ...') # mimic torch.repeat # else: # fps_pcd = rearrange( sample[-1], 'B V ... -> (B V) ...') # ugly code to match the input format... else: fps_pcd = sample[-1].repeat( 2, 1, 1) # mimic torch.cat(), from torch.Size([3, 4096, 3]) ret_dict.update({ 'caption': sample[-3], 'ins': sample[-2], 'fps_pcd': fps_pcd }) return ret_dict def prepare_mv_input(self, sample): # sample = [item[0] for item in sample] # wds wrap items in [] bs = len(sample['caption']) # number of instances chunk_size = sample['img'].shape[0] // bs assert self.split_chunk_input for k, v in sample.items(): if isinstance(v, torch.Tensor) and k != 'fps_pcd': sample[k] = rearrange(v, "b f c ... -> (b f) c ...", f=self.V).contiguous() # # ! 
shift nv # else: # for k, v in sample.items(): # if k not in ['ins', 'caption']: # rolled_idx = torch.LongTensor( # list( # itertools.chain.from_iterable( # list(range(i, sample['img'].shape[0], bs)) # for i in range(bs)))) # v = torch.index_select(v, dim=0, index=rolled_idx) # sample[k] = v # # img = sample['img'] # # gt = sample['nv_img'] # # torchvision.utils.save_image(img[0], 'inp.jpg', normalize=True) # # torchvision.utils.save_image(gt[0], 'nv.jpg', normalize=True) # for k, v in sample.items(): # if 'nv' in k: # rolled_idx = torch.LongTensor( # list( # itertools.chain.from_iterable( # list( # np.roll( # np.arange(i * chunk_size, (i + 1) * # chunk_size), 4) # for i in range(bs))))) # v = torch.index_select(v, dim=0, index=rolled_idx) # sample[k] = v # torchvision.utils.save_image(sample['nv_img'], 'nv.png', normalize=True) # torchvision.utils.save_image(sample['img'], 'inp.png', normalize=True) return sample def load_dataset( file_path="", reso=64, reso_encoder=224, batch_size=1, # shuffle=True, num_workers=6, load_depth=False, preprocess=None, imgnet_normalize=True, dataset_size=-1, trainer_name='input_rec', use_lmdb=False, use_wds=False, use_chunk=False, use_lmdb_compressed=False, infi_sampler=True): # st() # dataset_cls = { # 'input_rec': MultiViewDataset, # 'nv': NovelViewDataset, # }[trainer_name] # st() if use_wds: return load_wds_data(file_path, reso, reso_encoder, batch_size, num_workers) if use_lmdb: logger.log('using LMDB dataset') # dataset_cls = LMDBDataset_MV # 2.5-3iter/s, but unstable, drops to 1 later. if use_lmdb_compressed: if 'nv' in trainer_name: dataset_cls = Objv_LMDBDataset_NV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. else: dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. else: if 'nv' in trainer_name: dataset_cls = Objv_LMDBDataset_NV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later. else: dataset_cls = Objv_LMDBDataset_MV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later. # dataset = dataset_cls(file_path) elif use_chunk: dataset_cls = ChunkObjaverseDataset else: if 'nv' in trainer_name: dataset_cls = NovelViewObjverseDataset else: dataset_cls = MultiViewObjverseDataset # 1.5-2iter/s dataset = dataset_cls(file_path, reso, reso_encoder, test=False, preprocess=preprocess, load_depth=load_depth, imgnet_normalize=imgnet_normalize, dataset_size=dataset_size) logger.log('dataset_cls: {}, dataset size: {}'.format( trainer_name, len(dataset))) if use_chunk: def chunk_collate_fn(sample): # st() default_collate_sample = torch.utils.data.default_collate( sample[0]) st() return default_collate_sample collate_fn = chunk_collate_fn else: collate_fn = None loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, drop_last=False, pin_memory=True, persistent_workers=num_workers > 0, shuffle=use_chunk, collate_fn=collate_fn) return loader def chunk_collate_fn(sample): sample = torch.utils.data.default_collate(sample) # ! change from stack to cat # sample = self.post_process.prepare_mv_input(sample) bs = len(sample['caption']) # number of instances # chunk_size = sample['img'].shape[0] // bs def merge_internal_batch(sample, merge_b_only=False): for k, v in sample.items(): if isinstance(v, torch.Tensor): if v.ndim > 1: if k == 'fps_pcd' or merge_b_only: sample[k] = rearrange( v, "b1 b2 ... -> (b1 b2) ...").float().contiguous() else: sample[k] = rearrange( v, "b1 b2 f c ... -> (b1 b2 f) c ...").float( ).contiguous() elif k == 'tanfov': sample[k] = v[0].float().item() # tanfov. 
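    # default_collate stacks the per-worker chunks into [b1, b2, ...] tensors; the
    # helper above flattens them back into a flat view batch, e.g. an image tensor of
    # shape (b1, b2, f, C, H, W) becomes (b1*b2*f, C, H, W), while 3dgs camera dicts
    # only merge the two leading batch dims and 'tanfov' collapses to a python float.
    # A minimal usage sketch (hypothetical names, assuming a chunk-style dataset and
    # the 3dgs camera format):
    #
    #   loader = DataLoader(chunk_dataset, batch_size=2, collate_fn=chunk_collate_fn)
    #   batch = next(iter(loader))
    #   batch['img'].shape        # e.g. (2 * chunks_per_item * V, C, H, W)
    #   batch['c']['tanfov']      # a shared python float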
if isinstance(sample['c'], dict): # 3dgs merge_internal_batch(sample['c'], merge_b_only=True) merge_internal_batch(sample['nv_c'], merge_b_only=True) merge_internal_batch(sample) return sample def chunk_ddpm_collate_fn(sample): sample = torch.utils.data.default_collate(sample) # ! change from stack to cat # sample = self.post_process.prepare_mv_input(sample) # bs = len(sample['caption']) # number of instances # chunk_size = sample['img'].shape[0] // bs def merge_internal_batch(sample, merge_b_only=False): for k, v in sample.items(): if isinstance(v, torch.Tensor): if v.ndim > 1: # if k in ['c', 'latent']: sample[k] = rearrange( v, "b1 b2 ... -> (b1 b2) ...").float().contiguous() # else: # img # sample[k] = rearrange( # v, "b1 b2 f ... -> (b1 b2 f) ...").float( # ).contiguous() else: # caption & ins v = [v[i][0] for i in range(len(v))] merge_internal_batch(sample) # if 'caption' in sample: # sample['caption'] = sample['caption'][0] + sample['caption'][1] return sample def load_data_cls( file_path="", reso=64, reso_encoder=224, batch_size=1, # shuffle=True, num_workers=6, load_depth=False, preprocess=None, imgnet_normalize=True, dataset_size=-1, trainer_name='input_rec', use_lmdb=False, use_wds=False, use_chunk=False, use_lmdb_compressed=False, # plucker_embedding=False, # frame_0_as_canonical=False, infi_sampler=True, load_latent=False, return_dataset=False, load_caption_dataset=False, load_mv_dataset=False, **kwargs): # st() # dataset_cls = { # 'input_rec': MultiViewDataset, # 'nv': NovelViewDataset, # }[trainer_name] # st() # if use_lmdb: # logger.log('using LMDB dataset') # # dataset_cls = LMDBDataset_MV # 2.5-3iter/s, but unstable, drops to 1 later. # if 'nv' in trainer_name: # dataset_cls = Objv_LMDBDataset_NV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. # else: # dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. # # dataset = dataset_cls(file_path) collate_fn = None if use_lmdb: logger.log('using LMDB dataset') # dataset_cls = LMDBDataset_MV # 2.5-3iter/s, but unstable, drops to 1 later. if use_lmdb_compressed: if 'nv' in trainer_name: dataset_cls = Objv_LMDBDataset_NV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. else: dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. else: if 'nv' in trainer_name: dataset_cls = Objv_LMDBDataset_NV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later. else: dataset_cls = Objv_LMDBDataset_MV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later. elif use_chunk: if load_latent: # if 'gs_cam_format' in kwargs: if kwargs['gs_cam_format']: if load_caption_dataset: dataset_cls = ChunkObjaverseDatasetDDPMgsT23D collate_fn = chunk_ddpm_collate_fn else: if load_mv_dataset: # dataset_cls = ChunkObjaverseDatasetDDPMgsMV23D # ! if multi-view dataset_cls = ChunkObjaverseDatasetDDPMgsMV23DSynthetic # ! 
if multi-view # collate_fn = chunk_ddpm_collate_fn collate_fn = None else: dataset_cls = ChunkObjaverseDatasetDDPMgsI23D collate_fn = None else: dataset_cls = ChunkObjaverseDatasetDDPM collate_fn = chunk_ddpm_collate_fn else: dataset_cls = ChunkObjaverseDataset collate_fn = chunk_collate_fn else: if 'nv' in trainer_name: dataset_cls = NovelViewObjverseDataset # 1.5-2iter/s else: dataset_cls = MultiViewObjverseDataset dataset = dataset_cls(file_path, reso, reso_encoder, test=False, preprocess=preprocess, load_depth=load_depth, imgnet_normalize=imgnet_normalize, dataset_size=dataset_size, **kwargs # plucker_embedding=plucker_embedding ) logger.log('dataset_cls: {}, dataset size: {}'.format( trainer_name, len(dataset))) # st() return dataset def load_data( file_path="", reso=64, reso_encoder=224, batch_size=1, # shuffle=True, num_workers=6, load_depth=False, preprocess=None, imgnet_normalize=True, dataset_size=-1, trainer_name='input_rec', use_lmdb=False, use_wds=False, use_chunk=False, use_lmdb_compressed=False, # plucker_embedding=False, # frame_0_as_canonical=False, infi_sampler=True, load_latent=False, return_dataset=False, load_caption_dataset=False, load_mv_dataset=False, **kwargs): # st() # dataset_cls = { # 'input_rec': MultiViewDataset, # 'nv': NovelViewDataset, # }[trainer_name] # st() # if use_lmdb: # logger.log('using LMDB dataset') # # dataset_cls = LMDBDataset_MV # 2.5-3iter/s, but unstable, drops to 1 later. # if 'nv' in trainer_name: # dataset_cls = Objv_LMDBDataset_NV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. # else: # dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. # # dataset = dataset_cls(file_path) collate_fn = None if use_lmdb: logger.log('using LMDB dataset') # dataset_cls = LMDBDataset_MV # 2.5-3iter/s, but unstable, drops to 1 later. if use_lmdb_compressed: if 'nv' in trainer_name: dataset_cls = Objv_LMDBDataset_NV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. else: dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. else: if 'nv' in trainer_name: dataset_cls = Objv_LMDBDataset_NV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later. else: dataset_cls = Objv_LMDBDataset_MV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later. elif use_chunk: # st() if load_latent: if kwargs['gs_cam_format']: if load_caption_dataset: dataset_cls = ChunkObjaverseDatasetDDPMgsT23D # collate_fn = chunk_ddpm_collate_fn collate_fn = None else: if load_mv_dataset: # dataset_cls = ChunkObjaverseDatasetDDPMgsMV23D dataset_cls = ChunkObjaverseDatasetDDPMgsMV23DSynthetic # ! 
if multi-view # collate_fn = chunk_ddpm_collate_fn collate_fn = None else: # dataset_cls = ChunkObjaverseDatasetDDPMgsI23D # load i23d # collate_fn = None # load mv dataset for i23d dataset_cls = ChunkObjaverseDatasetDDPMgsI23D_loadMV collate_fn = chunk_ddpm_collate_fn else: dataset_cls = ChunkObjaverseDatasetDDPM collate_fn = chunk_ddpm_collate_fn else: dataset_cls = ChunkObjaverseDataset collate_fn = chunk_collate_fn else: if 'nv' in trainer_name: dataset_cls = NovelViewObjverseDataset # 1.5-2iter/s else: dataset_cls = MultiViewObjverseDataset dataset = dataset_cls(file_path, reso, reso_encoder, test=False, preprocess=preprocess, load_depth=load_depth, imgnet_normalize=imgnet_normalize, dataset_size=dataset_size, **kwargs # plucker_embedding=plucker_embedding ) logger.log('dataset_cls: {}, dataset size: {}'.format( trainer_name, len(dataset))) # st() if return_dataset: return dataset assert infi_sampler if infi_sampler: train_sampler = DistributedSampler(dataset=dataset, shuffle=True, drop_last=True) loader = DataLoader( dataset, batch_size=batch_size, num_workers=num_workers, drop_last=True, pin_memory=True, persistent_workers=num_workers > 0, sampler=train_sampler, collate_fn=collate_fn, # prefetch_factor=3 if num_workers>0 else None, ) while True: yield from loader # else: # # loader = DataLoader(dataset, # # batch_size=batch_size, # # num_workers=num_workers, # # drop_last=False, # # pin_memory=True, # # persistent_workers=num_workers > 0, # # shuffle=False) # st() # return dataset def load_eval_data( file_path="", reso=64, reso_encoder=224, batch_size=1, num_workers=1, load_depth=False, preprocess=None, imgnet_normalize=True, interval=1, use_lmdb=False, plucker_embedding=False, load_real=False, load_mv_real=False, load_gso=False, four_view_for_latent=False, shuffle_across_cls=False, load_extra_36_view=False, gs_cam_format=False, single_view_for_i23d=False, use_chunk=False, **kwargs, ): collate_fn = None if use_lmdb: logger.log('using LMDB dataset') dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. 
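        # LMDB evaluation path: the compressed multi-view dataset is built in test
        # mode below, with the `interval` argument forwarded to the dataset.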
dataset = dataset_cls(file_path, reso, reso_encoder, test=True, preprocess=preprocess, load_depth=load_depth, imgnet_normalize=imgnet_normalize, interval=interval) elif use_chunk: dataset = ChunkObjaverseDataset( file_path, reso, reso_encoder, test=False, preprocess=preprocess, load_depth=load_depth, imgnet_normalize=imgnet_normalize, # dataset_size=dataset_size, gs_cam_format=gs_cam_format, plucker_embedding=plucker_embedding, wds_split_all=2, # frame_0_as_canonical=frame_0_as_canonical, **kwargs) collate_fn = chunk_collate_fn elif load_real: if load_mv_real: dataset_cls = RealMVDataset elif load_gso: # st() dataset_cls = RealDataset_GSO else: # single-view i23d dataset_cls = RealDataset dataset = dataset_cls(file_path, reso, reso_encoder, preprocess=preprocess, load_depth=load_depth, test=True, imgnet_normalize=imgnet_normalize, interval=interval, plucker_embedding=plucker_embedding) else: dataset = MultiViewObjverseDataset( file_path, reso, reso_encoder, preprocess=preprocess, load_depth=load_depth, test=True, imgnet_normalize=imgnet_normalize, interval=interval, plucker_embedding=plucker_embedding, four_view_for_latent=four_view_for_latent, load_extra_36_view=load_extra_36_view, shuffle_across_cls=shuffle_across_cls, gs_cam_format=gs_cam_format, single_view_for_i23d=single_view_for_i23d, **kwargs) print('eval dataset size: {}'.format(len(dataset))) # train_sampler = DistributedSampler(dataset=dataset) loader = DataLoader( dataset, batch_size=batch_size, num_workers=num_workers, drop_last=False, shuffle=False, collate_fn=collate_fn, ) # sampler=train_sampler) # return loader return iter(loader) def load_data_for_lmdb( file_path="", reso=64, reso_encoder=224, batch_size=1, # shuffle=True, num_workers=6, load_depth=False, preprocess=None, imgnet_normalize=True, dataset_size=-1, trainer_name='input_rec', shuffle_across_cls=False, four_view_for_latent=False, wds_split=1): # st() # dataset_cls = { # 'input_rec': MultiViewDataset, # 'nv': NovelViewDataset, # }[trainer_name] # if 'nv' in trainer_name: # dataset_cls = NovelViewDataset # else: # dataset_cls = MultiViewDataset # st() # dataset_cls = MultiViewObjverseDatasetforLMDB dataset_cls = MultiViewObjverseDatasetforLMDB_nocaption dataset = dataset_cls(file_path, reso, reso_encoder, test=False, preprocess=preprocess, load_depth=load_depth, imgnet_normalize=imgnet_normalize, dataset_size=dataset_size, shuffle_across_cls=shuffle_across_cls, wds_split=wds_split, four_view_for_latent=four_view_for_latent) logger.log('dataset_cls: {}, dataset size: {}'.format( trainer_name, len(dataset))) # train_sampler = DistributedSampler(dataset=dataset, shuffle=True, drop_last=True) loader = DataLoader( dataset, shuffle=False, batch_size=batch_size, num_workers=num_workers, drop_last=False, # prefetch_factor=2, # prefetch_factor=3, pin_memory=True, persistent_workers=num_workers > 0, ) # sampler=train_sampler) # while True: # yield from loader return loader, dataset.dataset_name, len(dataset) def load_lmdb_for_lmdb( file_path="", reso=64, reso_encoder=224, batch_size=1, # shuffle=True, num_workers=6, load_depth=False, preprocess=None, imgnet_normalize=True, dataset_size=-1, trainer_name='input_rec'): # st() # dataset_cls = { # 'input_rec': MultiViewDataset, # 'nv': NovelViewDataset, # }[trainer_name] # if 'nv' in trainer_name: # dataset_cls = NovelViewDataset # else: # dataset_cls = MultiViewDataset # st() dataset_cls = Objv_LMDBDataset_MV_Compressed_for_lmdb dataset = dataset_cls(file_path, reso, reso_encoder, test=False, preprocess=preprocess, 
load_depth=load_depth, imgnet_normalize=imgnet_normalize, dataset_size=dataset_size) logger.log('dataset_cls: {}, dataset size: {}'.format( trainer_name, len(dataset))) # train_sampler = DistributedSampler(dataset=dataset, shuffle=True, drop_last=True) loader = DataLoader( dataset, shuffle=False, batch_size=batch_size, num_workers=num_workers, drop_last=False, # prefetch_factor=2, # prefetch_factor=3, pin_memory=True, persistent_workers=True, ) # sampler=train_sampler) # while True: # yield from loader return loader, len(dataset) def load_memory_data( file_path="", reso=64, reso_encoder=224, batch_size=1, num_workers=1, # load_depth=True, preprocess=None, imgnet_normalize=True, use_chunk=True, **kwargs): # load a single-instance into the memory to speed up training IO # dataset = MultiViewObjverseDataset(file_path, collate_fn = None if use_chunk: dataset_cls = ChunkObjaverseDataset collate_fn = chunk_collate_fn else: dataset_cls = NovelViewObjverseDataset dataset = dataset_cls(file_path, reso, reso_encoder, preprocess=preprocess, load_depth=True, test=False, overfitting=True, imgnet_normalize=imgnet_normalize, overfitting_bs=batch_size, **kwargs) logger.log('!!!!!!! memory dataset size: {} !!!!!!'.format(len(dataset))) # train_sampler = DistributedSampler(dataset=dataset) loader = DataLoader( dataset, batch_size=len(dataset), num_workers=num_workers, drop_last=False, shuffle=False, collate_fn = collate_fn ) all_data: dict = next( iter(loader) ) # torchvision.utils.save_image(all_data['img'], 'gt.jpg', normalize=True, value_range=(-1,1)) # st() if kwargs.get('gs_cam_format', False): # gs rendering pipeline # ! load V=4 images for training in a batch. while True: # st() # indices = torch.randperm(len(dataset))[:4] indices = torch.randperm( len(dataset) * 2)[:batch_size] # all instances # indices2 = torch.randperm(len(dataset))[:] # all instances batch_c = collections.defaultdict(dict) V = all_data['c']['source_cv2wT_quat'].shape[1] for k in ['c', 'nv_c']: for k_c, v_c in all_data[k].items(): if k_c == 'tanfov': continue try: batch_c[k][ k_c] = torch.index_select( # ! chunk data reading pipeline v_c, dim=0, index=indices ).reshape(batch_size, V, *v_c.shape[2:]).float( ) if isinstance( v_c, torch.Tensor) else v_c # float except Exception as e: st() print(e) # ! 
read chunk not required, already float batch_c['c']['tanfov'] = all_data['c']['tanfov'] batch_c['nv_c']['tanfov'] = all_data['nv_c']['tanfov'] indices_range = torch.arange(indices[0]*V, (indices[0]+1)*V) batch_data = {} for k, v in all_data.items(): if k not in ['c', 'nv_c']: try: if k == 'fps_pcd': batch_data[k] = torch.index_select( v, dim=0, index=indices).float() if isinstance( v, torch.Tensor) else v # float else: batch_data[k] = torch.index_select( v, dim=0, index=indices_range).float() if isinstance( v, torch.Tensor) else v # float except: st() print(e) memory_batch_data = { **batch_data, **batch_c, } yield memory_batch_data else: while True: start_idx = np.random.randint(0, len(dataset) - batch_size + 1) yield { k: v[start_idx:start_idx + batch_size] for k, v in all_data.items() } def read_dnormal(normald_path, cond_pos, h=None, w=None): cond_cam_dis = np.linalg.norm(cond_pos, 2) near = 0.867 #sqrt(3) * 0.5 near_distance = cond_cam_dis - near normald = cv2.imread(normald_path, cv2.IMREAD_UNCHANGED).astype(np.float32) normal, depth = normald[..., :3], normald[..., 3:] depth[depth < near_distance] = 0 if h is not None: assert w is not None if depth.shape[1] != h: depth = cv2.resize(depth, (h, w), interpolation=cv2.INTER_NEAREST ) # 512,512, 1 -> self.reso, self.reso # depth = cv2.resize(depth, (h, w), interpolation=cv2.INTER_LANCZOS4 # ) # ! may fail if nearest. dirty data. # st() else: depth = depth[..., 0] if normal.shape[1] != h: normal = cv2.resize(normal, (h, w), interpolation=cv2.INTER_NEAREST ) # 512,512, 1 -> self.reso, self.reso else: depth = depth[..., 0] return torch.from_numpy(depth).float(), torch.from_numpy(normal).float() def get_intri(target_im=None, h=None, w=None, normalize=False): if target_im is None: assert (h is not None and w is not None) else: h, w = target_im.shape[:2] fx = fy = 1422.222 res_raw = 1024 f_x = f_y = fx * h / res_raw K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1]).reshape(3, 3) if normalize: # center is [0.5, 0.5], eg3d renderer tradition K[:6] /= h # print("intr: ", K) return K def convert_pose(C2W): # https://github.com/modelscope/richdreamer/blob/c3d9a77fa15fc42dbae12c2d41d64aaec14efd37/dataset/gobjaverse/depth_warp_example.py#L402 flip_yz = np.eye(4) flip_yz[1, 1] = -1 flip_yz[2, 2] = -1 C2W = np.matmul(C2W, flip_yz) return torch.from_numpy(C2W) def read_camera_matrix_single(json_file): with open(json_file, 'r', encoding='utf8') as reader: json_content = json.load(reader) ''' # NOTE that different from unity2blender experiments. 
camera_matrix = np.eye(4) camera_matrix[:3, 0] = np.array(json_content['x']) camera_matrix[:3, 1] = -np.array(json_content['y']) camera_matrix[:3, 2] = -np.array(json_content['z']) camera_matrix[:3, 3] = np.array(json_content['origin']) ''' camera_matrix = np.eye(4) # blender-based camera_matrix[:3, 0] = np.array(json_content['x']) camera_matrix[:3, 1] = np.array(json_content['y']) camera_matrix[:3, 2] = np.array(json_content['z']) camera_matrix[:3, 3] = np.array(json_content['origin']) # print(camera_matrix) # ''' # return convert_pose(camera_matrix) return camera_matrix def unity2blender(normal): normal_clone = normal.copy() normal_clone[..., 0] = -normal[..., -1] normal_clone[..., 1] = -normal[..., 0] normal_clone[..., 2] = normal[..., 1] return normal_clone def unity2blender_fix(normal): # up blue, left green, front (towards inside) red normal_clone = normal.copy() # normal_clone[..., 0] = -normal[..., 2] # normal_clone[..., 1] = -normal[..., 0] normal_clone[..., 0] = -normal[..., 0] # swap r and g normal_clone[..., 1] = -normal[..., 2] normal_clone[..., 2] = normal[..., 1] return normal_clone def unity2blender_th(normal): assert normal.shape[1] == 3 # B 3 H W... normal_clone = normal.clone() normal_clone[:, 0, ...] = -normal[:, -1, ...] normal_clone[:, 1, ...] = -normal[:, 0, ...] normal_clone[:, 2, ...] = normal[:, 1, ...] return normal_clone def blender2midas(img): '''Blender: rub midas: lub ''' img[..., 0] = -img[..., 0] img[..., 1] = -img[..., 1] img[..., -1] = -img[..., -1] return img def current_milli_time(): return round(time.time() * 1000) # modified from ShapeNet class class MultiViewObjverseDataset(Dataset): def __init__( self, file_path, reso, reso_encoder, preprocess=None, classes=False, load_depth=False, test=False, scene_scale=1, overfitting=False, imgnet_normalize=True, dataset_size=-1, overfitting_bs=-1, interval=1, plucker_embedding=False, shuffle_across_cls=False, wds_split=1, # 4 splits to accelerate preprocessing four_view_for_latent=False, single_view_for_i23d=False, load_extra_36_view=False, gs_cam_format=False, frame_0_as_canonical=False, **kwargs): self.load_extra_36_view = load_extra_36_view # st() self.gs_cam_format = gs_cam_format self.frame_0_as_canonical = frame_0_as_canonical self.four_view_for_latent = four_view_for_latent # export 0 12 30 36, 4 views for reconstruction self.single_view_for_i23d = single_view_for_i23d self.file_path = file_path self.overfitting = overfitting self.scene_scale = scene_scale self.reso = reso self.reso_encoder = reso_encoder self.classes = False self.load_depth = load_depth self.preprocess = preprocess self.plucker_embedding = plucker_embedding self.intrinsics = get_intri(h=self.reso, w=self.reso, normalize=True).reshape(9) assert not self.classes, "Not support class condition now." dataset_name = Path(self.file_path).stem.split('_')[0] self.dataset_name = dataset_name self.zfar = 100.0 self.znear = 0.01 # if test: # self.ins_list = sorted(os.listdir(self.file_path))[0:1] # the first 1 instance for evaluation reference. # else: # ! TODO, read from list? def load_single_cls_instances(file_path): ins_list = [] # the first 1 instance for evaluation reference. 
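            # walk <file_path>/<dict_dir>/<ins_dir> and keep only instances whose
            # 'campos_512_v2' folder exists, contains view '00025/00025.png' and holds
            # exactly 40 rendered views; everything else is skipped as dirty data.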
# ''' # for dict_dir in os.listdir(file_path)[:]: # for dict_dir in os.listdir(file_path)[:]: for dict_dir in os.listdir(file_path): # for dict_dir in os.listdir(file_path)[:2]: for ins_dir in os.listdir(os.path.join(file_path, dict_dir)): # self.ins_list.append(os.path.join(self.file_path, dict_dir, ins_dir,)) # /nas/shared/V2V/yslan/logs/nips23/Reconstruction/final/objav/vae/MV/170K/infer-latents/189w/v=6-rotate/latent_dir # st() # check latent whether saved # root = '/nas/shared/V2V/yslan/logs/nips23/Reconstruction/final/objav/vae/MV/170K/infer-latents/189w/v=6-rotate/latent_dir' # if os.path.exists(os.path.join(root,file_path.split('/')[-1], dict_dir, ins_dir, 'latent.npy') ): # continue # pcd_root = '/nas/shared/V2V/yslan/logs/nips23/Reconstruction/pcd-V=8_24576_polish' # pcd_root = '/nas/shared/V2V/yslan/logs/nips23/Reconstruction/pcd-V=10_4096_polish' # if os.path.exists( # os.path.join(pcd_root, 'fps-pcd', # file_path.split('/')[-1], dict_dir, # ins_dir, 'fps-4096.ply')): # continue # ! split=8 has some missing instances # root = '/cpfs01/user/lanyushi.p/data/chunk-jpeg-normal/bs_16_fixsave3/170K/384/' # if os.path.exists(os.path.join(root,file_path.split('/')[-1], dict_dir, ins_dir,) ): # continue # else: # ins_list.append( # os.path.join(file_path, dict_dir, ins_dir, # 'campos_512_v4')) # filter out some data if not os.path.exists(os.path.join(file_path, dict_dir, ins_dir, 'campos_512_v2')): continue if not os.path.exists(os.path.join(file_path, dict_dir, ins_dir, 'campos_512_v2', '00025', '00025.png')): continue if len(os.listdir(os.path.join(file_path, dict_dir, ins_dir, 'campos_512_v2'))) != 40: continue ins_list.append( os.path.join(file_path, dict_dir, ins_dir, 'campos_512_v2')) # ''' # check pcd performnace # ins_list.append( # os.path.join(file_path, '0', '10634', # 'campos_512_v4')) return ins_list # st() self.ins_list = [] # for subset in ['Animals', 'Transportations_tar', 'Furnitures']: # for subset in ['Furnitures']: # selected subset for training # if False: if True: for subset in [ # ! around 17W instances in total. 
# 'Animals', # 'BuildingsOutdoor', # 'daily-used', # 'Furnitures', # 'Food', # 'Plants', # 'Electronics', # 'Transportations_tar', # 'Human-Shape', 'gobjaverse_alignment_unzip', ]: # selected subset for training # if os.path.exists(f'{self.file_path}/{subset}.txt'): # dataset_list = f'{self.file_path}/{subset}_filtered.txt' dataset_list = f'{self.file_path}/{subset}_filtered_more.txt' assert os.path.exists(dataset_list) if os.path.exists(dataset_list): with open(dataset_list, 'r') as f: self.ins_list += [os.path.join(self.file_path, item.strip()) for item in f.readlines()] else: self.ins_list += load_single_cls_instances( os.path.join(self.file_path, subset)) # st() # current_time = int(current_milli_time() # ) # randomly shuffle given current time # random.seed(current_time) # random.shuffle(self.ins_list) else: # preprocess single class self.ins_list = load_single_cls_instances(self.file_path) self.ins_list = sorted(self.ins_list) if overfitting: self.ins_list = self.ins_list[:1] self.rgb_list = [] self.frame0_pose_list = [] self.pose_list = [] self.depth_list = [] self.data_ins_list = [] self.instance_data_length = -1 # self.pcd_path = Path('/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/Reconstruction/pcd-V=6/fps-pcd') self.pcd_path = Path( '/nas/shared/V2V/yslan/logs/nips23/Reconstruction/pcd-V=6/fps-pcd') with open( '/nas/shared/public/yslan/data/text_captions_cap3d.json') as f: # '/nas/shared/V2V/yslan/aigc3d/text_captions_cap3d.json') as f: self.caption_data = json.load(f) self.shuffle_across_cls = shuffle_across_cls # for ins in self.ins_list[47000:]: if four_view_for_latent: # also saving dense pcd # self.wds_split_all = 1 # ! when dumping latent # self.wds_split_all = 2 # ! when dumping latent # self.wds_split_all = 4 # self.wds_split_all = 6 # self.wds_split_all = 4 # self.wds_split_all = 5 # self.wds_split_all = 6 # self.wds_split_all = 7 # self.wds_split_all = 1 self.wds_split_all = 8 # self.wds_split_all = 2 # ins_list_to_process = self.ins_list all_ins_size = len(self.ins_list) ratio_size = all_ins_size // self.wds_split_all + 1 # ratio_size = int(all_ins_size / self.wds_split_all) + 1 ins_list_to_process = self.ins_list[ratio_size * (wds_split):ratio_size * (wds_split + 1)] else: # ! create shards dataset # self.wds_split_all = 4 self.wds_split_all = 8 # self.wds_split_all = 1 all_ins_size = len(self.ins_list) random.seed(0) random.shuffle(self.ins_list) # avoid same category appears in the same shard ratio_size = all_ins_size // self.wds_split_all + 1 ins_list_to_process = self.ins_list[ratio_size * # 1 - 8 (wds_split - 1):ratio_size * wds_split] # uniform_sample = False uniform_sample = True # st() for ins in tqdm(ins_list_to_process): # ins = os.path.join( # # self.file_path, ins , 'campos_512_v4' # self.file_path, ins , # # 'compos_512_v4' # ) # cur_rgb_path = os.path.join(self.file_path, ins, 'compos_512_v4') # cur_pose_path = os.path.join(self.file_path, ins, 'pose') # st() # ][:27]) if self.four_view_for_latent: # cur_all_fname = [t.split('.')[0] for t in os.listdir(ins) # ] # use full set for training # cur_all_fname = [f'{idx:05d}' for idx in [0, 12, 30, 36] # cur_all_fname = [f'{idx:05d}' for idx in [6,12,18,24] # cur_all_fname = [f'{idx:05d}' for idx in [7,16,24,25] # cur_all_fname = [f'{idx:05d}' for idx in [25,26,0,9,18,27,33,39]] cur_all_fname = [ f'{idx:05d}' for idx in [25, 26, 6, 12, 18, 24, 27, 31, 35, 39] # ! for extracting PCD ] # cur_all_fname = [f'{idx:05d}' for idx in [25,26,0,9,18,27,30,33,36,39]] # more down side for better bottom coverage. 
# cur_all_fname = [f'{idx:05d}' for idx in [25,0, 7,15]] # cur_all_fname = [f'{idx:05d}' for idx in [4,12,20,25,26] # cur_all_fname = [f'{idx:05d}' for idx in [6,12,18,24,25,26] # cur_all_fname = [f'{idx:05d}' for idx in [6,12,18,24,25,26, 39, 33, 27] # cur_all_fname = [f'{idx:05d}' for idx in [6,12,18,24,25,26, 39, 33, 27] # cur_all_fname = [ # f'{idx:05d}' for idx in [25, 26, 27, 30, 33, 36] # ] # for pcd unprojection # cur_all_fname = [ # f'{idx:05d}' for idx in [25, 26, 27, 30] # ! for infer latents # ] # # cur_all_fname = [ # f'{idx:05d}' for idx in [25, 27, 29, 31, 33, 35, 37 # ] # ! for infer latents # ] # # cur_all_fname = [ # f'{idx:05d}' for idx in [25, 27, 31, 35 # ] # ! for infer latents # ] # # cur_all_fname += [f'{idx:05d}' for idx in range(40) if idx not in [0,12,30,36]] # ! four views for inference elif self.single_view_for_i23d: # cur_all_fname = [f'{idx:05d}' # for idx in [16]] # 20 is also fine cur_all_fname = [f'{idx:05d}' for idx in [2]] # ! furniture side view else: cur_all_fname = [t.split('.')[0] for t in os.listdir(ins) ] # use full set for training if shuffle_across_cls: if uniform_sample: cur_all_fname = sorted(cur_all_fname) # 0-24, 25 views # 25,26, 2 views # 27-39, 13 views uniform_all_fname = [] # !!!! if bs=9 or 8 for idx in range(6): if idx % 2 == 0: chunk_all_fname = [25] else: chunk_all_fname = [26] # chunk_all_fname = [25] # no bottom view required as input # start_1 = np.random.randint(0,5) # for first 24 views # chunk_all_fname += [start_1+uniform_idx for uniform_idx in range(0,25,5)] start_1 = np.random.randint(0,4) # for first 24 views, v=8 chunk_all_fname += [start_1+uniform_idx for uniform_idx in range(0,25,7)] # [0-21] start_2 = np.random.randint(0,5) + 27 # for first 24 views chunk_all_fname += [start_2, start_2 + 4, start_2 + 8] assert len(chunk_all_fname) == 8, len(chunk_all_fname) uniform_all_fname += [cur_all_fname[fname] for fname in chunk_all_fname] # ! if bs=6 # for idx in range(8): # if idx % 2 == 0: # chunk_all_fname = [ # 25 # ] # no bottom view required as input # else: # chunk_all_fname = [ # 26 # ] # no bottom view required as input # start_1 = np.random.randint( # 0, 7) # for first 24 views # # chunk_all_fname += [start_1+uniform_idx for uniform_idx in range(0,25,5)] # chunk_all_fname += [ # start_1 + uniform_idx # for uniform_idx in range(0, 25, 9) # ] # 0 9 18 # start_2 = np.random.randint( # 0, 7) + 27 # for first 24 views # # chunk_all_fname += [start_2, start_2 + 4, start_2 + 8] # chunk_all_fname += [start_2, # start_2 + 6] # 2 frames # assert len(chunk_all_fname) == 6 # uniform_all_fname += [ # cur_all_fname[fname] # for fname in chunk_all_fname # ] cur_all_fname = uniform_all_fname else: current_time = int(current_milli_time( )) # randomly shuffle given current time random.seed(current_time) random.shuffle(cur_all_fname) else: cur_all_fname = sorted(cur_all_fname) # ! skip the check # if self.instance_data_length == -1: # self.instance_data_length = len(cur_all_fname) # else: # try: # data missing? 
# assert len(cur_all_fname) == self.instance_data_length # except: # # with open('error_log.txt', 'a') as f: # # f.write(str(e) + '\n') # with open('missing_ins_new2.txt', 'a') as f: # f.write(str(Path(ins.parent)) + # '\n') # remove the "campos_512_v4" # continue # if test: # use middle image as the novel view model input # mid_index = len(cur_all_fname) // 3 * 2 # cur_all_fname.insert(0, cur_all_fname[mid_index]) self.frame0_pose_list += ([ os.path.join(ins, fname, fname + '.json') for fname in [cur_all_fname[0]] ] * len(cur_all_fname)) self.pose_list += ([ os.path.join(ins, fname, fname + '.json') for fname in cur_all_fname ]) self.rgb_list += ([ os.path.join(ins, fname, fname + '.png') for fname in cur_all_fname ]) self.depth_list += ([ os.path.join(ins, fname, fname + '_nd.exr') for fname in cur_all_fname ]) self.data_ins_list += ([ins] * len(cur_all_fname)) # check # ! setup normalizataion transformations = [ transforms.ToTensor(), # [0,1] range ] if imgnet_normalize: transformations.append( transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) # type: ignore ) else: transformations.append( transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))) # type: ignore # st() self.normalize = transforms.Compose(transformations) def get_source_cw2wT(self, source_cameras_view_to_world): return matrix_to_quaternion( source_cameras_view_to_world[:3, :3].transpose(0, 1)) def c_to_3dgs_format(self, pose): # TODO, switch to torch version (batched later) c2w = pose[:16].reshape(4, 4) # 3x4 # ! load cam w2c = np.linalg.inv(c2w) R = np.transpose( w2c[:3, :3]) # R is stored transposed due to 'glm' in CUDA code T = w2c[:3, 3] fx = pose[16] FovX = focal2fov(fx, 1) FovY = focal2fov(fx, 1) tanfovx = math.tan(FovX * 0.5) tanfovy = math.tan(FovY * 0.5) assert tanfovx == tanfovy trans = np.array([0.0, 0.0, 0.0]) scale = 1.0 world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose( 0, 1) projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=FovX, fovY=FovY).transpose(0, 1) full_proj_transform = (world_view_transform.unsqueeze(0).bmm( projection_matrix.unsqueeze(0))).squeeze(0) camera_center = world_view_transform.inverse()[3, :3] view_world_transform = torch.tensor(getView2World(R, T, trans, scale)).transpose( 0, 1) # item.update(viewpoint_cam=[viewpoint_cam]) c = {} c["source_cv2wT_quat"] = self.get_source_cw2wT(view_world_transform) c.update( # projection_matrix=projection_matrix, # K cam_view=world_view_transform, # world_view_transform cam_view_proj=full_proj_transform, # full_proj_transform cam_pos=camera_center, tanfov=tanfovx, # TODO, fix in the renderer # orig_c2w=c2w, # orig_w2c=w2c, orig_pose=torch.from_numpy(pose), orig_c2w=torch.from_numpy(c2w), orig_w2c=torch.from_numpy(w2c), # tanfovy=tanfovy, ) return c # dict for gs rendering def __len__(self): return len(self.rgb_list) def load_bbox(self, mask): # st() nonzero_value = torch.nonzero(mask) height, width = nonzero_value.max(dim=0)[0] top, left = nonzero_value.min(dim=0)[0] bbox = torch.tensor([top, left, height, width], dtype=torch.float32) return bbox def __getitem__(self, idx): # try: data = self._read_data(idx) return data # except Exception as e: # # with open('error_log_pcd.txt', 'a') as f: # with open('error_log_pcd.txt', 'a') as f: # f.write(str(e) + '\n') # with open('error_idx_pcd.txt', 'a') as f: # f.write(str(self.data_ins_list[idx]) + '\n') # print(e, flush=True) # return {} def gen_rays(self, c2w): # Generate rays self.h = self.reso_encoder self.w = self.reso_encoder yy, 
xx = torch.meshgrid( torch.arange(self.h, dtype=torch.float32) + 0.5, torch.arange(self.w, dtype=torch.float32) + 0.5, indexing='ij') # normalize to 0-1 pixel range yy = yy / self.h xx = xx / self.w # K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1]).reshape(3, 3) cx, cy, fx, fy = self.intrinsics[2], self.intrinsics[ 5], self.intrinsics[0], self.intrinsics[4] # cx *= self.w # cy *= self.h # f_x = f_y = fx * h / res_raw c2w = torch.from_numpy(c2w).float() xx = (xx - cx) / fx yy = (yy - cy) / fy zz = torch.ones_like(xx) dirs = torch.stack((xx, yy, zz), dim=-1) # OpenCV convention dirs /= torch.norm(dirs, dim=-1, keepdim=True) dirs = dirs.reshape(-1, 3, 1) del xx, yy, zz # st() dirs = (c2w[None, :3, :3] @ dirs)[..., 0] origins = c2w[None, :3, 3].expand(self.h * self.w, -1).contiguous() origins = origins.view(self.h, self.w, 3) dirs = dirs.view(self.h, self.w, 3) return origins, dirs def normalize_camera(self, c, c_frame0): # assert c.shape[0] == self.chunk_size # 8 o r10 B = c.shape[0] camera_poses = c[:, :16].reshape(B, 4, 4) # 3x4 canonical_camera_poses = c_frame0[:, :16].reshape(B, 4, 4) # if for_encoder: # encoder_canonical_idx = [0, self.V] # st() cam_radius = np.linalg.norm( c_frame0[:, :16].reshape(1, 4, 4)[:, :3, 3], axis=-1, keepdims=False) # since g-buffer adopts dynamic radius here. frame1_fixed_pos = np.repeat(np.eye(4)[None], 1, axis=0) frame1_fixed_pos[:, 2, -1] = -cam_radius transform = frame1_fixed_pos @ np.linalg.inv(canonical_camera_poses) # from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127 # transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]]) new_camera_poses = np.repeat( transform, 1, axis=0 ) @ camera_poses # [V, 4, 4]. np.repeat() is th.repeat_interleave() # else: # cam_radius = np.linalg.norm( # c[canonical_idx][:16].reshape(4, 4)[:3, 3], # axis=-1, # keepdims=False # ) # since g-buffer adopts dynamic radius here. # frame1_fixed_pos = np.eye(4) # frame1_fixed_pos[2, -1] = -cam_radius # transform = frame1_fixed_pos @ np.linalg.inv( # camera_poses[canonical_idx]) # 4,4 # # from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127 # # transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]]) # new_camera_poses = np.repeat( # transform[None], self.chunk_size, # axis=0) @ camera_poses # [V, 4, 4] # st() c = np.concatenate([new_camera_poses.reshape(B, 16), c[:, 16:]], axis=-1) # st() return c def _read_data( self, idx, ): rgb_fname = self.rgb_list[idx] pose_fname = self.pose_list[idx] raw_img = imageio.imread(rgb_fname) # ! 
RGBD alpha_mask = raw_img[..., -1:] / 255 raw_img = alpha_mask * raw_img[..., :3] + ( 1 - alpha_mask) * np.ones_like(raw_img[..., :3]) * 255 raw_img = raw_img.astype( np.uint8) # otherwise, float64 won't call ToTensor() # return raw_img # st() if self.preprocess is None: img_to_encoder = cv2.resize(raw_img, (self.reso_encoder, self.reso_encoder), interpolation=cv2.INTER_LANCZOS4) # interpolation=cv2.INTER_AREA) img_to_encoder = img_to_encoder[ ..., :3] #[3, reso_encoder, reso_encoder] img_to_encoder = self.normalize(img_to_encoder) else: img_to_encoder = self.preprocess(Image.open(rgb_fname)) # clip # return img_to_encoder img = cv2.resize(raw_img, (self.reso, self.reso), interpolation=cv2.INTER_LANCZOS4) # interpolation=cv2.INTER_AREA) # img_sr = cv2.resize(raw_img, (512, 512), interpolation=cv2.INTER_AREA) # img_sr = cv2.resize(raw_img, (256, 256), interpolation=cv2.INTER_AREA) # just as refinement, since eg3d uses 64->128 final resolution # img_sr = cv2.resize(raw_img, (128, 128), interpolation=cv2.INTER_AREA) # just as refinement, since eg3d uses 64->128 final resolution # img_sr = cv2.resize( # raw_img, (128, 128), interpolation=cv2.INTER_LANCZOS4 # ) # just as refinement, since eg3d uses 64->128 final resolution # img = torch.from_numpy(img)[..., :3].permute( # 2, 0, 1) / 255.0 #[3, reso, reso] img = torch.from_numpy(img)[..., :3].permute( 2, 0, 1 ) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range # img_sr = torch.from_numpy(img_sr)[..., :3].permute( # 2, 0, 1 # ) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range c2w = read_camera_matrix_single(pose_fname) #[1, 4, 4] -> [1, 16] # c = np.concatenate([c2w, self.intrinsics], axis=0).reshape(25) # 25, no '1' dim needed. # return c2w # if self.load_depth: # depth, depth_mask, depth_mask_sr = read_dnormal(self.depth_list[idx], # try: depth, normal = read_dnormal(self.depth_list[idx], c2w[:3, 3:], self.reso, self.reso) # ! frame0 alignment # if self.frame_0_as_canonical: # return depth # except: # # print(self.depth_list[idx]) # raise NotImplementedError(self.depth_list[idx]) # if depth try: bbox = self.load_bbox(depth > 0) except: print(rgb_fname, flush=True) with open('error_log.txt', 'a') as f: f.write(str(rgb_fname + '\n')) bbox = self.load_bbox(torch.ones_like(depth)) # plucker # ! normalize camera c = np.concatenate([c2w.reshape(16), self.intrinsics], axis=0).reshape(25).astype( np.float32) # 25, no '1' dim needed. if self.frame_0_as_canonical: # 4 views as input per batch frame0_pose_name = self.frame0_pose_list[idx] c2w_frame0 = read_camera_matrix_single( frame0_pose_name) #[1, 4, 4] -> [1, 16] c = self.normalize_camera(c[None], c2w_frame0[None])[0] c2w = c[:16].reshape(4, 4) # ! # st() # pass rays_o, rays_d = self.gen_rays(c2w) rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6] img_to_encoder = torch.cat( [img_to_encoder, rays_plucker.permute(2, 0, 1)], 0).float() # concat in C dim # ! 
add depth as input depth, normal = read_dnormal(self.depth_list[idx], c2w[:3, 3:], self.reso_encoder, self.reso_encoder) normalized_depth = depth.unsqueeze(0) # min=0 img_to_encoder = torch.cat([img_to_encoder, normalized_depth], 0) # concat in C dim if self.gs_cam_format: c = self.c_to_3dgs_format(c) else: c = torch.from_numpy(c) ret_dict = { # 'rgb_fname': rgb_fname, 'img_to_encoder': img_to_encoder, 'img': img, 'c': c, # 'img_sr': img_sr, # 'ins_name': self.data_ins_list[idx] } # ins = str( # (Path(self.data_ins_list[idx]).relative_to(self.file_path)).parent) pcd_ins = Path(self.data_ins_list[idx]).relative_to( Path(self.file_path).parent).parent # load pcd # fps_pcd = pcu.load_mesh_v( # str(self.pcd_path / pcd_ins / 'fps-10000.ply')) ins = str( # for compat (Path(self.data_ins_list[idx]).relative_to(self.file_path)).parent) # if self.shuffle_across_cls: caption = self.caption_data['/'.join(ins.split('/')[1:])] # else: # caption = self.caption_data[ins] ret_dict.update({ 'depth': depth, 'normal': normal, 'alpha_mask': alpha_mask, 'depth_mask': depth > 0, # 'depth_mask_sr': depth_mask_sr, 'bbox': bbox, 'caption': caption, 'rays_plucker': rays_plucker, # cam embedding used in lgm 'ins': ins, # placeholder # 'fps_pcd': fps_pcd, }) return ret_dict # class MultiViewObjverseDatasetChunk(MultiViewObjverseDataset): # def __init__(self, # file_path, # reso, # reso_encoder, # preprocess=None, # classes=False, # load_depth=False, # test=False, # scene_scale=1, # overfitting=False, # imgnet_normalize=True, # dataset_size=-1, # overfitting_bs=-1, # interval=1, # plucker_embedding=False, # shuffle_across_cls=False, # wds_split=1, # four_view_for_latent=False, # single_view_for_i23d=False, # load_extra_36_view=False, # gs_cam_format=False, # **kwargs): # super().__init__(file_path, reso, reso_encoder, preprocess, classes, # load_depth, test, scene_scale, overfitting, # imgnet_normalize, dataset_size, overfitting_bs, # interval, plucker_embedding, shuffle_across_cls, # wds_split, four_view_for_latent, single_view_for_i23d, # load_extra_36_view, gs_cam_format, **kwargs) # # load 40 views at a time, for inferring latents. 
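# A minimal usage sketch (not part of the original training code): how the
# chunked dataset defined below can be wrapped into a PyTorch DataLoader. The
# file_path, resolutions, chunk size, batch size and worker count are
# placeholder assumptions; ChunkObjaverseDataset expects a dataset.json index
# under file_path when load_raw is False.
def _example_build_chunk_dataloader(file_path='/path/to/chunked-dataset',
                                    reso=256,
                                    reso_encoder=256,
                                    split_chunk_size=12,
                                    batch_size=4,
                                    num_workers=2):
    # ChunkObjaverseDataset is defined just below; the name is resolved at call
    # time, so declaring this helper first is harmless.
    dataset = ChunkObjaverseDataset(file_path,
                                    reso,
                                    reso_encoder,
                                    split_chunk_size=split_chunk_size,
                                    mv_input=True,
                                    frame_0_as_canonical=True)
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=True,
                      num_workers=num_workers,
                      drop_last=True)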
# TODO merge all the useful APIs together class ChunkObjaverseDataset(Dataset): def __init__( self, file_path, reso, reso_encoder, preprocess=None, classes=False, load_depth=False, test=False, scene_scale=1, overfitting=False, imgnet_normalize=True, dataset_size=-1, overfitting_bs=-1, interval=1, plucker_embedding=False, shuffle_across_cls=False, wds_split=1, # 4 splits to accelerate preprocessing four_view_for_latent=False, single_view_for_i23d=False, load_extra_36_view=False, gs_cam_format=False, frame_0_as_canonical=True, split_chunk_size=10, mv_input=True, append_depth=False, append_xyz=False, wds_split_all=1, pcd_path=None, load_pcd=False, read_normal=False, load_raw=False, load_instance_only=False, mv_latent_dir='', perturb_pcd_scale=0.0, # shards_folder_num=4, # eval=False, **kwargs): super().__init__() # st() self.mv_latent_dir = mv_latent_dir self.load_raw = load_raw self.load_instance_only = load_instance_only self.read_normal = read_normal self.file_path = file_path self.chunk_size = split_chunk_size self.gs_cam_format = gs_cam_format self.frame_0_as_canonical = frame_0_as_canonical self.four_view_for_latent = four_view_for_latent # export 0 12 30 36, 4 views for reconstruction self.overfitting = overfitting self.scene_scale = scene_scale self.reso = reso self.reso_encoder = reso_encoder self.classes = False self.load_depth = load_depth self.preprocess = preprocess self.plucker_embedding = plucker_embedding self.intrinsics = get_intri(h=self.reso, w=self.reso, normalize=True).reshape(9) self.perturb_pcd_scale = perturb_pcd_scale assert not self.classes, "Not support class condition now." dataset_name = Path(self.file_path).stem.split('_')[0] self.dataset_name = dataset_name self.ray_sampler = RaySampler() self.zfar = 100.0 self.znear = 0.01 # ! load all chunk paths self.chunk_list = [] # if dataset_size != -1: # predefined instance # self.chunk_list = self.fetch_chunk_list(os.path.join(self.file_path, 'debug')) # else: # # for shard_idx in range(1, 5): # shard_dir 1-4 by default # for shard_idx in os.listdir(self.file_path): # self.chunk_list += self.fetch_chunk_list(os.path.join(self.file_path, shard_idx)) def load_single_cls_instances(file_path): ins_list = [] # the first 1 instance for evaluation reference. for dict_dir in os.listdir(file_path)[:]: # ! for debugging for ins_dir in os.listdir(os.path.join(file_path, dict_dir)): ins_list.append( os.path.join(file_path, dict_dir, ins_dir, 'campos_512_v4')) return ins_list # st() if self.load_raw: with open( # '/nas/shared/V2V/yslan/aigc3d/text_captions_cap3d.json') as f: # '/nas/shared/public/yslan//data/text_captions_cap3d.json') as f: './dataset/text_captions_3dtopia.json') as f: self.caption_data = json.load(f) # with open # # '/nas/shared/V2V/yslan/aigc3d/text_captions_cap3d.json') as f: # '/nas/shared/public/yslan//data/text_captions_cap3d.json') as f: # # '/cpfs01/shared/public/yhluo/Projects/threed/3D-Enhancer/develop/text_captions_3dtopia.json') as f: # self.old_caption_data = json.load(f) for subset in [ # ! around 17.6 W instances in total. 
'Animals', # 'daily-used', # 'BuildingsOutdoor', # 'Furnitures', # 'Food', # 'Plants', # 'Electronics', # 'Transportations_tar', # 'Human-Shape', ]: # selected subset for training # self.chunk_list += load_single_cls_instances( # os.path.join(self.file_path, subset)) with open(f'shell_scripts/raw_img_list/{subset}.txt', 'r') as f: self.chunk_list += [os.path.join(subset, item.strip()) for item in f.readlines()] # st() # save to local # with open('/cpfs01/user/lanyushi.p/Repo/diffusion-3d/shell_scripts/shards_list/chunk_list.txt', 'w') as f: # f.writelines(self.chunk_list) # load raw g-objv dataset # self.img_ext = 'png' # ln3diff # for k, v in dataset_json.items(): # directly load from folders instead # self.chunk_list.extend(v) else: # ! direclty load from json with open(f'{self.file_path}/dataset.json', 'r') as f: dataset_json = json.load(f) # dataset_json = {'Animals': ['Animals/0/10017/1']} if self.chunk_size == 12: self.img_ext = 'png' # ln3diff for k, v in dataset_json.items(): self.chunk_list.extend(v) else: # extract latent assert self.chunk_size in [16,18, 20] self.img_ext = 'jpg' # more views for k, v in dataset_json.items(): # if k != 'BuildingsOutdoor': # cannot be handled by gs self.chunk_list.extend(v) # filter # st() # root = '/nas/shared/V2V/yslan/logs/nips23/Reconstruction/final/objav/vae/gs/infer-latents/768/8x8/animals-gs-latent/latent_dir' # root = '/nas/shared/V2V/yslan/logs/nips23/Reconstruction/final/objav/vae/gs/infer-latents/768/8x8/animals-gs-latent-dim=10-fullset/latent_dir' # filtered_chunk_list = [] # for v in self.chunk_list: # if os.path.exists(os.path.join(root, v[:-2], 'gaussians.npy') ): # continue # filtered_chunk_list.append(v) # self.chunk_list = filtered_chunk_list dataset_size = len(self.chunk_list) self.chunk_list = sorted(self.chunk_list) # self.chunk_list, self.eval_list = self.chunk_list[:int(dataset_size*0.95)], self.chunk_list[int(dataset_size*0.95):] # self.chunk_list = self.eval_list # self.wds_split_all = wds_split_all # for # self.wds_split_all = 1 # self.wds_split_all = 7 # self.wds_split_all = 4 self.wds_split_all = 1 # ! filter # st() if wds_split_all != 1: # ! 
retrieve the right wds split all_ins_size = len(self.chunk_list) ratio_size = all_ins_size // self.wds_split_all + 1 # ratio_size = int(all_ins_size / self.wds_split_all) + 1 print('ratio_size: ', ratio_size, 'all_ins_size: ', all_ins_size) self.chunk_list = self.chunk_list[ratio_size * (wds_split):ratio_size * (wds_split + 1)] # st() # load images from raw self.rgb_list = [] if self.load_instance_only: for ins in tqdm(self.chunk_list): ins_name = str(Path(ins).parent) # cur_all_fname = [f'{t:05d}' for t in range(40)] # load all instances for now self.rgb_list += ([ os.path.join(self.file_path, ins, fname + '.png') for fname in [f'{t}' for t in range(2)] # for fname in [f'{t:05d}' for t in range(2)] ]) # synthetic mv data # index mapping of mvi data to objv single-view data self.mvi_objv_mapping = { '0': '00000', '1': '00012', } # load gt mv data self.gt_chunk_list = [] self.gt_mv_file_path = '/cpfs01/user/lanyushi.p/data/chunk-jpeg-normal/bs_16_fixsave3/170K/512/' assert self.chunk_size in [16,18, 20] with open(f'{self.gt_mv_file_path}/dataset.json', 'r') as f: dataset_json = json.load(f) # dataset_json = {'Animals': dataset_json['Animals'] } # self.img_ext = 'jpg' # more views for k, v in dataset_json.items(): # if k != 'BuildingsOutdoor': # cannot be handled by gs self.gt_chunk_list.extend(v) elif self.load_raw: for ins in tqdm(self.chunk_list): # # st() # ins = ins[len('/nas/shared/V2V/yslan/aigc3d/unzip4/'):] # ins_name = str(Path(ins).relative_to(self.file_path).parent) ins_name = str(Path(ins).parent) # latent_path = os.path.join(self.mv_latent_dir, ins_name, 'latent.npz') # if not os.path.exists(latent_path): # continue cur_all_fname = [f'{t:05d}' for t in range(40)] # load all instances for now self.rgb_list += ([ os.path.join(self.file_path, ins, fname, fname + '.png') for fname in cur_all_fname ]) self.post_process = PostProcess( reso, reso_encoder, imgnet_normalize=imgnet_normalize, plucker_embedding=plucker_embedding, decode_encode_img_only=False, mv_input=mv_input, split_chunk_input=split_chunk_size, duplicate_sample=True, append_depth=append_depth, append_xyz=append_xyz, gs_cam_format=gs_cam_format, orthog_duplicate=False, frame_0_as_canonical=frame_0_as_canonical, pcd_path=pcd_path, load_pcd=load_pcd, split_chunk_size=split_chunk_size, ) self.kernel = torch.tensor([[0, 1, 0], [1, 1, 1], [0, 1, 0]]) # self.no_bottom = True # avoid loading bottom vew def fetch_chunk_list(self, file_path): if os.path.isdir(file_path): chunks = [ os.path.join(file_path, fname) for fname in os.listdir(file_path) if fname.isdigit() ] return chunks else: return [] def _pre_process_chunk(self): # e.g., remove bottom view pass def read_chunk(self, chunk_path): # equivalent to decode_zip() in wds # reshape chunk raw_img = imageio.imread( os.path.join(chunk_path, f'raw_img.{self.img_ext}')) h, bw, c = raw_img.shape raw_img = raw_img.reshape(h, self.chunk_size, -1, c).transpose( (1, 0, 2, 3)) c = np.load(os.path.join(chunk_path, 'c.npy')) with open(os.path.join(chunk_path, 'caption.txt'), 'r', encoding="utf-8") as f: caption = f.read() with open(os.path.join(chunk_path, 'ins.txt'), 'r', encoding="utf-8") as f: ins = f.read() bbox = np.load(os.path.join(chunk_path, 'bbox.npy')) if self.chunk_size > 16: depth_alpha = imageio.imread( os.path.join(chunk_path, 'depth_alpha.jpg')) # 2h 10w depth_alpha = depth_alpha.reshape(h * 2, self.chunk_size, -1).transpose((1, 0, 2)) depth, alpha = np.split(depth_alpha, 2, axis=1) d_near_far = np.load(os.path.join(chunk_path, 'd_near_far.npy')) d_near = 
d_near_far[0].reshape(self.chunk_size, 1, 1) d_far = d_near_far[1].reshape(self.chunk_size, 1, 1) # d = 1 / ( (d_normalized / 255) * (far-near) + near) depth = 1 / ((depth / 255) * (d_far - d_near) + d_near) depth[depth > 2.9] = 0.0 # background as 0, follow old tradition # ! filter anti-alias artifacts erode_mask = kornia.morphology.erosion( torch.from_numpy(alpha == 255).float().unsqueeze(1), self.kernel) # B 1 H W depth = (torch.from_numpy(depth).unsqueeze(1) * erode_mask).squeeze( 1) # shrink anti-alias bug else: # load separate alpha and depth map alpha = imageio.imread( os.path.join(chunk_path, f'alpha.{self.img_ext}')) alpha = alpha.reshape(h, self.chunk_size, h).transpose( (1, 0, 2)) depth = np.load(os.path.join(chunk_path, 'depth.npz'))['depth'] # depth = depth * (alpha==255) # mask out background # depth = np.stack([depth, alpha], -1) # rgba # if self.no_bottom: # raw_img # pass if self.read_normal: normal = imageio.imread(os.path.join( chunk_path, 'normal.png')).astype(np.float32) / 255.0 normal = (normal * 2 - 1).reshape(h, self.chunk_size, -1, 3).transpose((1, 0, 2, 3)) # fix g-buffer normal rendering coordinate # normal = unity2blender(normal) # ! still wrong normal = unity2blender_fix(normal) # ! depth = (depth, normal) # ? return raw_img, depth, c, alpha, bbox, caption, ins def __len__(self): return len(self.chunk_list) def __getitem__(self, index) -> Any: sample = self.read_chunk( os.path.join(self.file_path, self.chunk_list[index])) sample = self.post_process.paired_post_process_chunk(sample) sample = self.post_process.create_dict_nobatch(sample) # aug pcd # st() if self.perturb_pcd_scale > 0: if random.random() > 0.5: t = np.random.rand(sample['fps_pcd'].shape[0], 1, 1) * self.perturb_pcd_scale sample['fps_pcd'] = sample['fps_pcd'] + t * np.random.randn(*sample['fps_pcd'].shape) # type: ignore sample['fps_pcd'] = np.clip(sample['fps_pcd'], -0.45, 0.45) # truncate noisy augmentation return sample class ChunkObjaverseDatasetDDPM(ChunkObjaverseDataset): def __init__( self, file_path, reso, reso_encoder, preprocess=None, classes=False, load_depth=False, test=False, scene_scale=1, overfitting=False, imgnet_normalize=True, dataset_size=-1, overfitting_bs=-1, interval=1, plucker_embedding=False, shuffle_across_cls=False, wds_split=1, # 4 splits to accelerate preprocessing four_view_for_latent=False, single_view_for_i23d=False, load_extra_36_view=False, gs_cam_format=False, frame_0_as_canonical=True, split_chunk_size=10, mv_input=True, append_depth=False, append_xyz=False, pcd_path=None, load_pcd=False, read_normal=False, mv_latent_dir='', load_raw=False, # shards_folder_num=4, # eval=False, **kwargs): super().__init__( file_path, reso, reso_encoder, preprocess=None, classes=False, load_depth=False, test=False, scene_scale=1, overfitting=False, imgnet_normalize=True, dataset_size=-1, overfitting_bs=-1, interval=1, plucker_embedding=False, shuffle_across_cls=False, wds_split=1, # 4 splits to accelerate preprocessing four_view_for_latent=False, single_view_for_i23d=False, load_extra_36_view=False, gs_cam_format=False, frame_0_as_canonical=True, split_chunk_size=split_chunk_size, mv_input=True, append_depth=False, append_xyz=False, pcd_path=None, load_pcd=False, read_normal=False, load_raw=load_raw, mv_latent_dir=mv_latent_dir, # shards_folder_num=4, # eval=False, **kwargs) self.n_cond_frames = 6 self.perspective_transformer = v2.RandomPerspective(distortion_scale=0.4, p=0.15, fill=1, interpolation=torchvision.transforms.InterpolationMode.NEAREST) self.mv_resize_cls = 
torchvision.transforms.Resize(320, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, max_size=None, antialias=True) # ! read img c, caption. def get_plucker_ray(self, c): rays_plucker = [] for idx in range(c.shape[0]): rays_o, rays_d = self.gen_rays(c[idx]) rays_plucker.append( torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1).permute(2, 0, 1)) # [h, w, 6] -> 6,h,w rays_plucker = torch.stack(rays_plucker, 0) return rays_plucker def read_chunk(self, chunk_path): # equivalent to decode_zip() in wds # reshape chunk raw_img = imageio.imread( os.path.join(chunk_path, f'raw_img.{self.img_ext}')).astype(np.float32) h, bw, c = raw_img.shape raw_img = raw_img.reshape(h, self.chunk_size, -1, c).transpose( (1, 0, 2, 3)) c = np.load(os.path.join(chunk_path, 'c.npy')).astype(np.float32) with open(os.path.join(chunk_path, 'caption.txt'), 'r', encoding="utf-8") as f: caption = f.read() with open(os.path.join(chunk_path, 'ins.txt'), 'r', encoding="utf-8") as f: ins = f.read() return raw_img, c, caption, ins def _load_latent(self, ins): # if 'adv' in self.mv_latent_dir: # new latent codes saved have 3 augmentations # idx = random.choice([0,1,2]) # latent = np.load(os.path.join(self.mv_latent_dir, ins, f'latent-{idx}.npy')) # pre-calculated VAE latent # else: latent = np.load(os.path.join(self.mv_latent_dir, ins, 'latent.npy')) # pre-calculated VAE latent latent = repeat(latent, 'C H W -> B C H W', B=2) # return {'latent': latent} return latent def normalize_camera(self, c, c_frame0): # assert c.shape[0] == self.chunk_size # 8 o r10 B = c.shape[0] camera_poses = c[:, :16].reshape(B, 4, 4) # 3x4 canonical_camera_poses = c_frame0[:, :16].reshape(1, 4, 4) inverse_canonical_pose = np.linalg.inv(canonical_camera_poses) inverse_canonical_pose = np.repeat(inverse_canonical_pose, B, 0) cam_radius = np.linalg.norm( c_frame0[:, :16].reshape(1, 4, 4)[:, :3, 3], axis=-1, keepdims=False) # since g-buffer adopts dynamic radius here. frame1_fixed_pos = np.repeat(np.eye(4)[None], 1, axis=0) frame1_fixed_pos[:, 2, -1] = -cam_radius transform = frame1_fixed_pos @ inverse_canonical_pose new_camera_poses = np.repeat( transform, 1, axis=0 ) @ camera_poses # [V, 4, 4]. np.repeat() is th.repeat_interleave() c = np.concatenate([new_camera_poses.reshape(B, 16), c[:, 16:]], axis=-1) return c # @autocast # def plucker_embedding(self, c): # rays_o, rays_d = self.gen_rays(c) # rays_plucker = torch.cat( # [torch.cross(rays_o, rays_d, dim=-1), rays_d], # dim=-1).permute(2, 0, 1) # [h, w, 6] -> 6,h,w # return rays_plucker def __getitem__(self, index) -> Any: raw_img, c, caption, ins = self.read_chunk( os.path.join(self.file_path, self.chunk_list[index])) # sample = self.post_process.paired_post_process_chunk(sample) # ! random zoom in (scale augmentation) # for i in range(img.shape[0]): # for v in range(img.shape[1]): # if random.random() > 0.8: # rand_bg_scale = random.randint(60,99) / 100 # st() # img[i,v] = recenter(img[i,v], np.ones_like(img[i,v]), border_ratio=rand_bg_scale) # ! process raw_img = torch.from_numpy(raw_img).permute(0, 3, 1, 2) / 255.0 # [0,1] if raw_img.shape[-1] != self.reso: raw_img = torch.nn.functional.interpolate( input=raw_img, size=(self.reso, self.reso), mode='bilinear', align_corners=False, ) img = raw_img * 2 - 1 # as gt # ! load latent latent, _ = self._load_latent(ins) # ! shuffle indices = np.random.permutation(self.chunk_size) img = img[indices] c = c[indices] img = self.perspective_transformer(img) # create 3D inconsistency # ! 
split along V and repeat other stuffs accordingly img = rearrange(img, '(B V) ... -> B V ...', B=2)[:, :self.n_cond_frames] c = rearrange(c, '(B V) ... -> B V ...', B=2)[:, :self.n_cond_frames] # 2 6 25 # rand perspective aug caption = [caption, caption] ins = [ins, ins] # load plucker coord # st() # plucker_c = self.get_plucker_ray(rearrange(c[:, 1:1+self.n_cond_frames], "b t ... -> (b t) ...")) # plucker_c = rearrange(c, '(B V) ... -> B V ...', B=2) # 2 6 25 # use view-space camera tradition c[0] = self.normalize_camera(c[0], c[0,0:1]) c[1] = self.normalize_camera(c[1], c[1,0:1]) # https://github.com/TencentARC/InstantMesh/blob/7fe95627cf819748f7830b2b278f302a9d798d17/src/model.py#L70 # c = np.concatenate([c[..., :12], c[..., 16:17], c[..., 20:21], c[..., 18:19], c[..., 21:22]], axis=-1) # c = c + np.random.randn(*c.shape) * 0.04 - 0.02 # ! to dict # sample = self.post_process.create_dict_nobatch(sample) ret_dict = { 'caption': caption, 'ins': ins, 'c': c, 'img': img, # fix inp img range to [-1,1] 'latent': latent, # **latent } # st() return ret_dict class ChunkObjaverseDatasetDDPMgs(ChunkObjaverseDatasetDDPM): def __init__( self, file_path, reso, reso_encoder, preprocess=None, classes=False, load_depth=False, test=False, scene_scale=1, overfitting=False, imgnet_normalize=True, dataset_size=-1, overfitting_bs=-1, interval=1, plucker_embedding=False, shuffle_across_cls=False, wds_split=1, # 4 splits to accelerate preprocessing four_view_for_latent=False, single_view_for_i23d=False, load_extra_36_view=False, gs_cam_format=False, frame_0_as_canonical=True, split_chunk_size=10, mv_input=True, append_depth=False, append_xyz=False, pcd_path=None, load_pcd=False, read_normal=False, mv_latent_dir='', load_raw=False, # shards_folder_num=4, # eval=False, **kwargs): super().__init__( file_path, reso, reso_encoder, preprocess=preprocess, classes=classes, load_depth=load_depth, test=test, scene_scale=scene_scale, overfitting=overfitting, imgnet_normalize=imgnet_normalize, dataset_size=dataset_size, overfitting_bs=overfitting_bs, interval=interval, plucker_embedding=plucker_embedding, shuffle_across_cls=shuffle_across_cls, wds_split=wds_split, # 4 splits to accelerate preprocessing four_view_for_latent=four_view_for_latent, single_view_for_i23d=single_view_for_i23d, load_extra_36_view=load_extra_36_view, gs_cam_format=gs_cam_format, frame_0_as_canonical=frame_0_as_canonical, split_chunk_size=split_chunk_size, mv_input=mv_input, append_depth=append_depth, append_xyz=append_xyz, pcd_path=pcd_path, load_pcd=load_pcd, read_normal=read_normal, mv_latent_dir=mv_latent_dir, load_raw=load_raw, # shards_folder_num=4, # eval=False, **kwargs) self.avoid_loading_first = False # self.feat_scale_factor = torch.Tensor([0.99227685, 1.014337 , 0.20842505, 0.98727155, 0.3305389 , # 0.38729668, 1.0155401 , 0.9728264 , 1.0009694 , 0.97328585, # 0.2881106 , 0.1652732 , 0.3482468 , 0.9971449 , 0.99895126, # 0.18491288]).float().reshape(1,1,-1) # stat for normalization # self.xyz_mean = torch.Tensor([-0.00053714, 0.08095618, -0.01914407] ).reshape(1, 3).float() # self.xyz_std = np.array([0.14593576, 0.15753542, 0.18873914] ).reshape(1,3).astype(np.float32) # self.xyz_std = np.array([0.14593576, 0.15753542, 0.18873914] ).reshape(1,3).astype(np.float32) self.xyz_std = 0.164 # a global scaler self.kl_mean = np.array([ 0.0184, 0.0024, 0.0926, 0.0517, 0.1781, 0.7137, -0.0355, 0.0267, 0.0183, 0.0164, -0.5090, 0.2406, 0.2733, -0.0256, -0.0285, 0.0761]).reshape(1,16).astype(np.float32) self.kl_std = np.array([1.0018, 1.0309, 
1.3001, 1.0160, 0.8182, 0.8023, 1.0591, 0.9789, 0.9966, 0.9448, 0.8908, 1.4595, 0.7957, 0.9871, 1.0236, 1.2923]).reshape(1,16).astype(np.float32) def normalize_pcd_act(self, x): return x / self.xyz_std def normalize_kl_feat(self, latent): # return latent / self.feat_scale_factor return (latent-self.kl_mean) / self.kl_std def _load_latent(self, ins, rand_pick_one=False, pick_both=False): if 'adv' in self.mv_latent_dir: # new latent codes saved have 3 augmentations idx = random.choice([0,1,2]) # idx = random.choice([0]) latent = np.load(os.path.join(self.mv_latent_dir, ins, f'latent-{idx}.npz')) # pre-calculated VAE latent else: latent = np.load(os.path.join(self.mv_latent_dir, ins, 'latent.npz')) # pre-calculated VAE latent latent, fps_xyz = latent['latent_normalized'], latent['query_pcd_xyz'] # 2,768,16; 2,768,3 if not pick_both: if rand_pick_one: rand_idx = random.randint(0,1) else: rand_idx = 0 latent, fps_xyz = latent[rand_idx:rand_idx+1], fps_xyz[rand_idx:rand_idx+1] # per-channel normalize to std=1 & concat # latent_pcd = np.concatenate([self.normalize_kl_feat(latent), self.normalize_pcd_act(fps_xyz)], -1) # latent_pcd = np.concatenate([latent, self.normalize_pcd_act(fps_xyz)], -1) # return latent_pcd, fps_xyz return latent, fps_xyz def __getitem__(self, index) -> Any: raw_img, c, caption, ins = self.read_chunk( os.path.join(self.file_path, self.chunk_list[index])) # sample = self.post_process.paired_post_process_chunk(sample) # ! random zoom in (scale augmentation) # for i in range(img.shape[0]): # for v in range(img.shape[1]): # if random.random() > 0.8: # rand_bg_scale = random.randint(60,99) / 100 # st() # img[i,v] = recenter(img[i,v], np.ones_like(img[i,v]), border_ratio=rand_bg_scale) # ! process raw_img = torch.from_numpy(raw_img).permute(0, 3, 1, 2) / 255.0 # [0,1] if raw_img.shape[-1] != self.reso: raw_img = torch.nn.functional.interpolate( input=raw_img, size=(self.reso, self.reso), mode='bilinear', align_corners=False, ) img = raw_img * 2 - 1 # as gt # ! load latent # latent, _ = self._load_latent(ins) latent, fps_xyz = self._load_latent(ins, pick_both=True) # analyzing xyz/latent disentangled diffusion # latent, fps_xyz = latent[0], fps_xyz[0] # remove batch dim here # fps_xyz = fps_xyz / self.scaling_factor # for xyz training normalized_fps_xyz = self.normalize_pcd_act(fps_xyz) if self.avoid_loading_first: # for training mv model index = list(range(1,6)) + list(range(7,12)) img = img[index] c = c[index] # ! shuffle indices = np.random.permutation(img.shape[0]) img = img[indices] c = c[indices] img = self.perspective_transformer(img) # create 3D inconsistency # ! split along V and repeat other stuffs accordingly img = rearrange(img, '(B V) ... -> B V ...', B=2)[:, :self.n_cond_frames] c = rearrange(c, '(B V) ... -> B V ...', B=2)[:, :self.n_cond_frames] # 2 6 25 # rand perspective aug caption = [caption, caption] ins = [ins, ins] # load plucker coord # st() # plucker_c = self.get_plucker_ray(rearrange(c[:, 1:1+self.n_cond_frames], "b t ... -> (b t) ...")) # plucker_c = rearrange(c, '(B V) ... -> B V ...', B=2) # 2 6 25 # use view-space camera tradition c[0] = self.normalize_camera(c[0], c[0,0:1]) c[1] = self.normalize_camera(c[1], c[1,0:1]) # ! 
to dict # sample = self.post_process.create_dict_nobatch(sample) ret_dict = { 'caption': caption, 'ins': ins, 'c': c, 'img': img, # fix inp img range to [-1,1] 'latent': latent, 'normalized-fps-xyz': normalized_fps_xyz # **latent } # st() return ret_dict class ChunkObjaverseDatasetDDPMgsT23D(ChunkObjaverseDatasetDDPMgs): def __init__( self, file_path, reso, reso_encoder, preprocess=None, classes=False, load_depth=False, test=False, scene_scale=1, overfitting=False, imgnet_normalize=True, dataset_size=-1, overfitting_bs=-1, interval=1, plucker_embedding=False, shuffle_across_cls=False, wds_split=1, # 4 splits to accelerate preprocessing four_view_for_latent=False, single_view_for_i23d=False, load_extra_36_view=False, gs_cam_format=False, frame_0_as_canonical=True, split_chunk_size=10, mv_input=True, append_depth=False, append_xyz=False, pcd_path=None, load_pcd=False, read_normal=False, mv_latent_dir='', # shards_folder_num=4, # eval=False, **kwargs): super().__init__( file_path, reso, reso_encoder, preprocess=preprocess, classes=classes, load_depth=load_depth, test=test, scene_scale=scene_scale, overfitting=overfitting, imgnet_normalize=imgnet_normalize, dataset_size=dataset_size, overfitting_bs=overfitting_bs, interval=interval, plucker_embedding=plucker_embedding, shuffle_across_cls=shuffle_across_cls, wds_split=wds_split, # 4 splits to accelerate preprocessing four_view_for_latent=four_view_for_latent, single_view_for_i23d=single_view_for_i23d, load_extra_36_view=load_extra_36_view, gs_cam_format=gs_cam_format, frame_0_as_canonical=frame_0_as_canonical, split_chunk_size=split_chunk_size, mv_input=mv_input, append_depth=append_depth, append_xyz=append_xyz, pcd_path=pcd_path, load_pcd=load_pcd, read_normal=read_normal, mv_latent_dir=mv_latent_dir, load_raw=True, # shards_folder_num=4, # eval=False, **kwargs) # def __len__(self): # return 40 def __len__(self): return len(self.rgb_list) def __getitem__(self, index) -> Any: rgb_path = self.rgb_list[index] ins = str(Path(rgb_path).relative_to(self.file_path).parent.parent.parent) # load caption caption = self.caption_data['/'.join(ins.split('/')[1:])] # chunk_path = os.path.join(self.file_path, self.chunk_list[index]) # # load caption # with open(os.path.join(chunk_path, 'caption.txt'), # 'r', # encoding="utf-8") as f: # caption = f.read() # # load latent # with open(os.path.join(chunk_path, 'ins.txt'), 'r', # encoding="utf-8") as f: # ins = f.read() latent, fps_xyz = self._load_latent(ins, True) # analyzing xyz/latent disentangled diffusion latent, fps_xyz = latent[0], fps_xyz[0] # remove batch dim here # fps_xyz = fps_xyz / self.scaling_factor # for xyz training normalized_fps_xyz = self.normalize_pcd_act(fps_xyz) # ! 
to dict ret_dict = { # 'caption': caption, 'latent': latent, # 'img': img, 'fps-xyz': fps_xyz, 'normalized-fps-xyz': normalized_fps_xyz, 'caption': caption } return ret_dict class ChunkObjaverseDatasetDDPMgsI23D(ChunkObjaverseDatasetDDPMgs): def __init__( self, file_path, reso, reso_encoder, preprocess=None, classes=False, load_depth=False, test=False, scene_scale=1, overfitting=False, imgnet_normalize=True, dataset_size=-1, overfitting_bs=-1, interval=1, plucker_embedding=False, shuffle_across_cls=False, wds_split=1, # 4 splits to accelerate preprocessing four_view_for_latent=False, single_view_for_i23d=False, load_extra_36_view=False, gs_cam_format=False, frame_0_as_canonical=True, split_chunk_size=10, mv_input=True, append_depth=False, append_xyz=False, pcd_path=None, load_pcd=False, read_normal=False, mv_latent_dir='', # shards_folder_num=4, # eval=False, **kwargs): super().__init__( file_path, reso, reso_encoder, preprocess=preprocess, classes=classes, load_depth=load_depth, test=test, scene_scale=scene_scale, overfitting=overfitting, imgnet_normalize=imgnet_normalize, dataset_size=dataset_size, overfitting_bs=overfitting_bs, interval=interval, plucker_embedding=plucker_embedding, shuffle_across_cls=shuffle_across_cls, wds_split=wds_split, # 4 splits to accelerate preprocessing four_view_for_latent=four_view_for_latent, single_view_for_i23d=single_view_for_i23d, load_extra_36_view=load_extra_36_view, gs_cam_format=gs_cam_format, frame_0_as_canonical=frame_0_as_canonical, split_chunk_size=split_chunk_size, mv_input=mv_input, append_depth=append_depth, append_xyz=append_xyz, pcd_path=pcd_path, load_pcd=load_pcd, read_normal=read_normal, mv_latent_dir=mv_latent_dir, load_raw=True, # shards_folder_num=4, # eval=False, **kwargs) assert self.load_raw self.scaling_factor = np.array([0.14593576, 0.15753542, 0.18873914]) def __len__(self): return len(self.rgb_list) # def __len__(self): # return 40 def __getitem__(self, index) -> Any: rgb_path = self.rgb_list[index] ins = str(Path(rgb_path).relative_to(self.file_path).parent.parent.parent) raw_img = imageio.imread(rgb_path).astype(np.float32) alpha_mask = raw_img[..., -1:] / 255 raw_img = alpha_mask * raw_img[..., :3] + ( 1 - alpha_mask) * np.ones_like(raw_img[..., :3]) * 255 raw_img = cv2.resize(raw_img, (self.reso, self.reso), interpolation=cv2.INTER_CUBIC) raw_img = torch.from_numpy(raw_img).permute(2,0,1).clip(0,255) # [0,1] img = raw_img / 127.5 - 1 # with open(os.path.join(chunk_path, 'caption.txt'), # 'r', # encoding="utf-8") as f: # caption = f.read() # latent = self._load_latent(ins, True)[0] latent, fps_xyz = self._load_latent(ins, True) # analyzing xyz/latent disentangled diffusion latent, fps_xyz = latent[0], fps_xyz[0] # fps_xyz = fps_xyz / self.scaling_factor # for xyz training normalized_fps_xyz = self.normalize_pcd_act(fps_xyz) # load caption caption = self.caption_data['/'.join(ins.split('/')[1:])] # ! 
to dict ret_dict = { # 'caption': caption, 'latent': latent, 'img': img.numpy(), # no idea whether loading Tensor leads to 'too many files opened' 'fps-xyz': fps_xyz, 'normalized-fps-xyz': normalized_fps_xyz, 'caption': caption } return ret_dict class ChunkObjaverseDatasetDDPMgsMV23D(ChunkObjaverseDatasetDDPMgs): def __init__( self, file_path, reso, reso_encoder, preprocess=None, classes=False, load_depth=False, test=False, scene_scale=1, overfitting=False, imgnet_normalize=True, dataset_size=-1, overfitting_bs=-1, interval=1, plucker_embedding=False, shuffle_across_cls=False, wds_split=1, # 4 splits to accelerate preprocessing four_view_for_latent=False, single_view_for_i23d=False, load_extra_36_view=False, gs_cam_format=False, frame_0_as_canonical=True, split_chunk_size=10, mv_input=True, append_depth=False, append_xyz=False, pcd_path=None, load_pcd=False, read_normal=False, mv_latent_dir='', # shards_folder_num=4, # eval=False, **kwargs): super().__init__( file_path, reso, reso_encoder, preprocess=preprocess, classes=classes, load_depth=load_depth, test=test, scene_scale=scene_scale, overfitting=overfitting, imgnet_normalize=imgnet_normalize, dataset_size=dataset_size, overfitting_bs=overfitting_bs, interval=interval, plucker_embedding=plucker_embedding, shuffle_across_cls=shuffle_across_cls, wds_split=wds_split, # 4 splits to accelerate preprocessing four_view_for_latent=four_view_for_latent, single_view_for_i23d=single_view_for_i23d, load_extra_36_view=load_extra_36_view, gs_cam_format=gs_cam_format, frame_0_as_canonical=frame_0_as_canonical, split_chunk_size=split_chunk_size, mv_input=mv_input, append_depth=append_depth, append_xyz=append_xyz, pcd_path=pcd_path, load_pcd=load_pcd, read_normal=read_normal, mv_latent_dir=mv_latent_dir, load_raw=False, # shards_folder_num=4, # eval=False, **kwargs) assert not self.load_raw # self.scaling_factor = np.array([0.14593576, 0.15753542, 0.18873914]) self.n_cond_frames = 4 # a easy version for now. self.avoid_loading_first = True def __getitem__(self, index) -> Any: raw_img, c, caption, ins = self.read_chunk( os.path.join(self.file_path, self.chunk_list[index])) # ! process raw_img = torch.from_numpy(raw_img).permute(0, 3, 1, 2) / 255.0 # [0,1] if raw_img.shape[-1] != self.reso: raw_img = torch.nn.functional.interpolate( input=raw_img, size=(self.reso, self.reso), mode='bilinear', align_corners=False, ) img = raw_img * 2 - 1 # as gt # ! load latent # latent, _ = self._load_latent(ins) latent, fps_xyz = self._load_latent(ins, pick_both=True) # analyzing xyz/latent disentangled diffusion # latent, fps_xyz = latent[0], fps_xyz[0] # remove batch dim here # fps_xyz = fps_xyz / self.scaling_factor # for xyz training normalized_fps_xyz = self.normalize_pcd_act(fps_xyz) if self.avoid_loading_first: # for training mv model index = list(range(1,self.chunk_size//2)) + list(range(self.chunk_size//2+1, self.chunk_size)) img = img[index] c = c[index] # ! shuffle indices = np.random.permutation(img.shape[0]) img = img[indices] c = c[indices] aug_img = self.perspective_transformer(img) # create 3D inconsistency # ! split along V and repeat other stuffs accordingly img = rearrange(img, '(B V) ... -> B V ...', B=2)[:, 0:1] # only return first view (randomly sampled) aug_img = rearrange(aug_img, '(B V) ... -> B V ...', B=2)[:, 1:self.n_cond_frames+1] c = rearrange(c, '(B V) ... 
-> B V ...', B=2)[:, 1:self.n_cond_frames+1] # 2 6 25 # use view-space camera tradition c[0] = self.normalize_camera(c[0], c[0,0:1]) c[1] = self.normalize_camera(c[1], c[1,0:1]) caption = [caption, caption] ins = [ins, ins] # ! to dict # sample = self.post_process.create_dict_nobatch(sample) ret_dict = { 'caption': caption, 'ins': ins, 'c': c, 'img': img, # fix inp img range to [-1,1] 'mv_img': aug_img, 'latent': latent, 'normalized-fps-xyz': normalized_fps_xyz # **latent } # st() return ret_dict class ChunkObjaverseDatasetDDPMgsMV23DSynthetic(ChunkObjaverseDatasetDDPMgs): def __init__( self, file_path, reso, reso_encoder, preprocess=None, classes=False, load_depth=False, test=False, scene_scale=1, overfitting=False, imgnet_normalize=True, dataset_size=-1, overfitting_bs=-1, interval=1, plucker_embedding=False, shuffle_across_cls=False, wds_split=1, # 4 splits to accelerate preprocessing four_view_for_latent=False, single_view_for_i23d=False, load_extra_36_view=False, gs_cam_format=False, frame_0_as_canonical=True, split_chunk_size=10, mv_input=True, append_depth=False, append_xyz=False, pcd_path=None, load_pcd=False, read_normal=False, mv_latent_dir='', # shards_folder_num=4, # eval=False, **kwargs): super().__init__( file_path, reso, reso_encoder, preprocess=preprocess, classes=classes, load_depth=load_depth, test=test, scene_scale=scene_scale, overfitting=overfitting, imgnet_normalize=imgnet_normalize, dataset_size=dataset_size, overfitting_bs=overfitting_bs, interval=interval, plucker_embedding=plucker_embedding, shuffle_across_cls=shuffle_across_cls, wds_split=wds_split, # 4 splits to accelerate preprocessing four_view_for_latent=four_view_for_latent, single_view_for_i23d=single_view_for_i23d, load_extra_36_view=load_extra_36_view, gs_cam_format=gs_cam_format, frame_0_as_canonical=frame_0_as_canonical, split_chunk_size=split_chunk_size, mv_input=mv_input, append_depth=append_depth, append_xyz=append_xyz, pcd_path=pcd_path, load_pcd=load_pcd, read_normal=read_normal, mv_latent_dir=mv_latent_dir, load_raw=True, load_instance_only=True, # shards_folder_num=4, # eval=False, **kwargs) # assert not self.load_raw # self.scaling_factor = np.array([0.14593576, 0.15753542, 0.18873914]) self.n_cond_frames = 6 # a easy version for now. self.avoid_loading_first = True self.indices = np.array([0,1,2,3,4,5]) self.img_root_dir = '/cpfs01/user/lanyushi.p/data/unzip4_img' azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float) elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float) zero123pp_pose, _ = generate_input_camera(1.8, [[elevations[i], azimuths[i]] for i in range(6)], fov=30) K = torch.Tensor([1.3889, 0.0000, 0.5000, 0.0000, 1.3889, 0.5000, 0.0000, 0.0000, 0.0039]).to(zero123pp_pose) # keeps the same zero123pp_pose = torch.cat([zero123pp_pose.reshape(6,-1), K.unsqueeze(0).repeat(6,1)], dim=-1) eval_camera = zero123pp_pose[self.indices].float().cpu().numpy() # for normalization self.eval_camera = self.normalize_camera(eval_camera, eval_camera[0:1]) # the first img is not used. 
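        # Descriptive note (inferred from the code above and from
        # _getitem_synthetic below): the synthetic multi-view branch uses a
        # fixed zero123++-style rig of 6 cameras at radius 1.8, azimuths
        # (30, 90, ..., 330) deg, alternating elevations (20, -10) deg and
        # fov=30; the rendered views are stored as a 3x2 image grid. The rig is
        # canonicalized against its first camera via normalize_camera(), so
        # self.eval_camera is already expressed in that view-space convention.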
# self.load_synthetic_only = False self.load_synthetic_only = True def __len__(self): return len(self.rgb_list) def _getitem_synthetic(self, index) -> Any: rgb_fname = Path(self.rgb_list[index]) # ins = self.mvi_objv_mapping(rgb_fname.parent.parent.stem) # ins = str(Path(rgb_fname).parent.parent.stem) ins = str((Path(rgb_fname).relative_to(self.file_path)).parent.parent) mv_img = imageio.imread(rgb_fname) # st() mv_img = rearrange(mv_img, '(n h) (m w) c -> (n m) h w c', n=3, m=2)[self.indices] # (6, 3, 320, 320) mv_img = np.stack([recenter(img, np.ones_like(img), border_ratio=0.1) for img in mv_img], axis=0) mv_img = rearrange(mv_img, 'b h w c -> b c h w') # to torch tradition mv_img = torch.from_numpy(mv_img) / 127.5 - 1 # ! load single-view image here img_idx = self.mvi_objv_mapping[rgb_fname.stem] img_path = os.path.join(self.img_root_dir, rgb_fname.parent.relative_to(self.file_path), img_idx, f'{img_idx}.png') raw_img = imageio.imread(img_path).astype(np.float32) alpha_mask = raw_img[..., -1:] / 255 raw_img = alpha_mask * raw_img[..., :3] + ( 1 - alpha_mask) * np.ones_like(raw_img[..., :3]) * 255 raw_img = cv2.resize(raw_img, (self.reso, self.reso), interpolation=cv2.INTER_CUBIC) raw_img = torch.from_numpy(raw_img).permute(2,0,1).clip(0,255) # [0,1] img = raw_img / 127.5 - 1 latent, fps_xyz = self._load_latent(ins, pick_both=False) # analyzing xyz/latent disentangled diffusion latent, fps_xyz = latent[0], fps_xyz[0] normalized_fps_xyz = self.normalize_pcd_act(fps_xyz) # for stage-1 # use view-space camera tradition # ins = [ins, ins] # st() caption = self.caption_data['/'.join(ins.split('/')[1:])] # ! to dict # sample = self.post_process.create_dict_nobatch(sample) ret_dict = { 'caption': caption, # 'ins': ins, 'c': self.eval_camera, 'img': img, # fix inp img range to [-1,1] 'mv_img': mv_img, 'latent': latent, 'normalized-fps-xyz': normalized_fps_xyz, 'fps-xyz': fps_xyz, } return ret_dict def _getitem_gt(self, index) -> Any: raw_img, c, caption, ins = self.read_chunk( os.path.join(self.gt_mv_file_path, self.gt_chunk_list[index])) # ! process raw_img = torch.from_numpy(raw_img).permute(0, 3, 1, 2) / 255.0 # [0,1] if raw_img.shape[-1] != self.reso: raw_img = torch.nn.functional.interpolate( input=raw_img, size=(self.reso, self.reso), mode='bilinear', align_corners=False, ) img = raw_img * 2 - 1 # as gt # ! load latent # latent, _ = self._load_latent(ins) latent, fps_xyz = self._load_latent(ins, pick_both=True) # analyzing xyz/latent disentangled diffusion # latent, fps_xyz = latent[0], fps_xyz[0] # remove batch dim here # fps_xyz = fps_xyz / self.scaling_factor # for xyz training normalized_fps_xyz = self.normalize_pcd_act(fps_xyz) if self.avoid_loading_first: # for training mv model index = list(range(1,self.chunk_size//2)) + list(range(self.chunk_size//2+1, self.chunk_size)) img = img[index] c = c[index] # ! shuffle indices = np.random.permutation(img.shape[0]) img = img[indices] c = c[indices] # st() aug_img = self.mv_resize_cls(img) aug_img = self.perspective_transformer(aug_img) # create 3D inconsistency # ! split along V and repeat other stuffs accordingly img = rearrange(img, '(B V) ... -> B V ...', B=2)[:, 0:1] # only return first view (randomly sampled) aug_img = rearrange(aug_img, '(B V) ... -> B V ...', B=2)[:, 1:self.n_cond_frames+1] c = rearrange(c, '(B V) ... 
-> B V ...', B=2)[:, 1:self.n_cond_frames+1] # 2 6 25 # use view-space camera tradition c[0] = self.normalize_camera(c[0], c[0,0:1]) c[1] = self.normalize_camera(c[1], c[1,0:1]) caption = [caption, caption] ins = [ins, ins] # ! to dict # sample = self.post_process.create_dict_nobatch(sample) ret_dict = { 'caption': caption, 'ins': ins, 'c': c, 'img': img, # fix inp img range to [-1,1] 'mv_img': aug_img, 'latent': latent, 'normalized-fps-xyz': normalized_fps_xyz, 'fps-xyz': fps_xyz, } return ret_dict def __getitem__(self, index) -> Any: # load synthetic version try: synthetic_mv = self._getitem_synthetic(index) except Exception as e: # logger.log(Path(self.rgb_list[index]), 'missing') synthetic_mv = self._getitem_synthetic(random.randint(0, len(self.rgb_list)//2)) if self.load_synthetic_only: return synthetic_mv else: # load gt mv chunk gt_chunk_index = random.randint(0, len(self.gt_chunk_list)-1) gt_mv = self._getitem_gt(gt_chunk_index) # merge them together along batch dim merged_mv = {} for k, v in synthetic_mv.items(): # merge, synthetic - gt order if k not in ['caption', 'ins']: if k == 'img': merged_mv[k] = np.concatenate([v[None], gt_mv[k][:, 0]], axis=0).astype(np.float32) else: merged_mv[k] = np.concatenate([v[None], gt_mv[k]], axis=0).astype(np.float32) else: merged_mv[k] = [v] + gt_mv[k] # list return merged_mv class ChunkObjaverseDatasetDDPMgsI23D_loadMV(ChunkObjaverseDatasetDDPMgs): def __init__( self, file_path, reso, reso_encoder, preprocess=None, classes=False, load_depth=False, test=False, scene_scale=1, overfitting=False, imgnet_normalize=True, dataset_size=-1, overfitting_bs=-1, interval=1, plucker_embedding=False, shuffle_across_cls=False, wds_split=1, # 4 splits to accelerate preprocessing four_view_for_latent=False, single_view_for_i23d=False, load_extra_36_view=False, gs_cam_format=False, frame_0_as_canonical=True, split_chunk_size=10, mv_input=True, append_depth=False, append_xyz=False, pcd_path=None, load_pcd=False, read_normal=False, mv_latent_dir='', canonicalize_pcd=False, # shards_folder_num=4, # eval=False, **kwargs): super().__init__( file_path, reso, reso_encoder, preprocess=preprocess, classes=classes, load_depth=load_depth, test=test, scene_scale=scene_scale, overfitting=overfitting, imgnet_normalize=imgnet_normalize, dataset_size=dataset_size, overfitting_bs=overfitting_bs, interval=interval, plucker_embedding=plucker_embedding, shuffle_across_cls=shuffle_across_cls, wds_split=wds_split, # 4 splits to accelerate preprocessing four_view_for_latent=four_view_for_latent, single_view_for_i23d=single_view_for_i23d, load_extra_36_view=load_extra_36_view, gs_cam_format=gs_cam_format, frame_0_as_canonical=frame_0_as_canonical, split_chunk_size=split_chunk_size, mv_input=mv_input, append_depth=append_depth, append_xyz=append_xyz, pcd_path=pcd_path, load_pcd=load_pcd, read_normal=read_normal, mv_latent_dir=mv_latent_dir, load_raw=False, # shards_folder_num=4, # eval=False, **kwargs) assert not self.load_raw # self.scaling_factor = np.array([0.14593576, 0.15753542, 0.18873914]) self.n_cond_frames = 5 # a easy version for now. 
self.avoid_loading_first = True # self.canonicalize_pcd = canonicalize_pcd # self.canonicalize_pcd = True self.canonicalize_pcd = False def canonicalize_xyz(self, c, pcd): B = c.shape[0] camera_poses_rot = c[:, :16].reshape(B, 4, 4)[:, :3, :3] R_inv = np.transpose(camera_poses_rot, (0,2,1)) # w2c rotation new_pcd = (R_inv @ np.transpose(pcd, (0,2,1))) # B 3 3 @ B 3 N new_pcd = np.transpose(new_pcd, (0,2,1)) return new_pcd def __getitem__(self, index) -> Any: raw_img, c, caption, ins = self.read_chunk( os.path.join(self.file_path, self.chunk_list[index])) # ! process raw_img = torch.from_numpy(raw_img).permute(0, 3, 1, 2) / 255.0 # [0,1] if raw_img.shape[-1] != self.reso: raw_img = torch.nn.functional.interpolate( input=raw_img, size=(self.reso, self.reso), mode='bilinear', align_corners=False, ) img = raw_img * 2 - 1 # as gt # ! load latent # latent, _ = self._load_latent(ins) if self.avoid_loading_first: # for training mv model index = list(range(1,self.chunk_size//2)) + list(range(self.chunk_size//2+1, self.chunk_size)) img = img[index] c = c[index] # ! shuffle indices = np.random.permutation(img.shape[0])[:self.n_cond_frames*2] img = img[indices] c = c[indices] latent, fps_xyz = self._load_latent(ins, pick_both=True) # analyzing xyz/latent disentangled diffusion # latent, fps_xyz = latent[0], fps_xyz[0] # remove batch dim here fps_xyz = np.repeat(fps_xyz, self.n_cond_frames, 0) latent = np.repeat(latent, self.n_cond_frames, 0) normalized_fps_xyz = self.normalize_pcd_act(fps_xyz) if self.canonicalize_pcd: normalized_fps_xyz = self.canonicalize_xyz(c, normalized_fps_xyz) # repeat caption = [caption] * self.n_cond_frames * 2 ins = [ins] * self.n_cond_frames * 2 ret_dict = { 'caption': caption, 'ins': ins, 'c': c, 'img': img, # fix inp img range to [-1,1] 'latent': latent, 'normalized-fps-xyz': normalized_fps_xyz, 'fps-xyz': fps_xyz, # **latent } return ret_dict class RealDataset(Dataset): def __init__( self, file_path, reso, reso_encoder, preprocess=None, classes=False, load_depth=False, test=False, scene_scale=1, overfitting=False, imgnet_normalize=True, dataset_size=-1, overfitting_bs=-1, interval=1, plucker_embedding=False, shuffle_across_cls=False, wds_split=1, # 4 splits to accelerate preprocessing ) -> None: super().__init__() self.file_path = file_path self.overfitting = overfitting self.scene_scale = scene_scale self.reso = reso self.reso_encoder = reso_encoder self.classes = False self.load_depth = load_depth self.preprocess = preprocess self.plucker_embedding = plucker_embedding self.rgb_list = [] all_fname = [ t for t in os.listdir(self.file_path) if t.split('.')[1] in ['png', 'jpg'] ] all_fname = [name for name in all_fname if '-input' in name ] self.rgb_list += ([ os.path.join(self.file_path, fname) for fname in all_fname ]) # st() # if len(self.rgb_list) == 1: # # placeholder # self.rgb_list = self.rgb_list * 40 # ! setup normalizataion transformations = [ transforms.ToTensor(), # [0,1] range ] assert imgnet_normalize if imgnet_normalize: transformations.append( transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) # type: ignore ) else: transformations.append( transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))) # type: ignore self.normalize = transforms.Compose(transformations) # camera = torch.load('eval_pose.pt', map_location='cpu') # self.eval_camera = camera # pre-cache # self.calc_rays_plucker() def __len__(self): return len(self.rgb_list) def __getitem__(self, index) -> Any: # return super().__getitem__(index) rgb_fname = self.rgb_list[index] # ! 
preprocess, normalize raw_img = imageio.imread(rgb_fname) # interpolation=cv2.INTER_AREA) if raw_img.shape[-1] == 4: alpha_mask = raw_img[..., 3:4] / 255.0 bg_white = np.ones_like(alpha_mask) * 255.0 raw_img = raw_img[..., :3] * alpha_mask + ( 1 - alpha_mask) * bg_white #[3, reso_encoder, reso_encoder] raw_img = raw_img.astype(np.uint8) # raw_img = recenter(raw_img, np.ones_like(raw_img), border_ratio=0.2) # log gt img = cv2.resize(raw_img, (self.reso, self.reso), interpolation=cv2.INTER_LANCZOS4) img = torch.from_numpy(img)[..., :3].permute( 2, 0, 1 ) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range ret_dict = { # 'rgb_fname': rgb_fname, # 'img_to_encoder': # img_to_encoder.unsqueeze(0).repeat_interleave(40, 0), 'img': img, # 'c': self.eval_camera, # TODO, get pre-calculated samples # 'ins': 'placeholder', # 'bbox': 'placeholder', # 'caption': 'placeholder', } # ! repeat as a intance return ret_dict class RealDataset_GSO(Dataset): def __init__( self, file_path, reso, reso_encoder, preprocess=None, classes=False, load_depth=False, test=False, scene_scale=1, overfitting=False, imgnet_normalize=True, dataset_size=-1, overfitting_bs=-1, interval=1, plucker_embedding=False, shuffle_across_cls=False, wds_split=1, # 4 splits to accelerate preprocessing ) -> None: super().__init__() self.file_path = file_path self.overfitting = overfitting self.scene_scale = scene_scale self.reso = reso self.reso_encoder = reso_encoder self.classes = False self.load_depth = load_depth self.preprocess = preprocess self.plucker_embedding = plucker_embedding self.rgb_list = [] # ! for gso-rendering all_objs = os.listdir(self.file_path) all_objs.sort() if True: # instant-mesh picked images # if False: all_instances = os.listdir(self.file_path) # all_fname = [ # t for t in all_instances # if t.split('.')[1] in ['png', 'jpg'] # ] # all_fname = [name for name in all_fname if '-input' in name ] # all_fname = ['house2-input.png', 'plant-input.png'] all_fname = ['house2-input.png'] self.rgb_list = [os.path.join(self.file_path, name) for name in all_fname] if False: for obj_folder in tqdm(all_objs[515:]): # for obj_folder in tqdm(all_objs[:515]): # for obj_folder in tqdm(all_objs[:]): # for obj_folder in tqdm(sorted(os.listdir(self.file_path))[515:]): # for idx in range(0,25,5): for idx in [0]: # only query frontal view is enough self.rgb_list.append(os.path.join(self.file_path, obj_folder, 'rgba', f'{idx:03d}.png')) # for free-3d rendering if False: # if True: # all_instances = sorted(os.listdir(self.file_path)) all_instances = ['BAGEL_WITH_CHEESE', 'BALANCING_CACTUS', 'Baby_Elements_Stacking_Cups', 'Breyer_Horse_Of_The_Year_2015', 'COAST_GUARD_BOAT', 'CONE_SORTING', 'CREATIVE_BLOCKS_35_MM', 'Cole_Hardware_Mini_Honey_Dipper', 'FAIRY_TALE_BLOCKS', 'FIRE_ENGINE', 'FOOD_BEVERAGE_SET', 'GEOMETRIC_PEG_BOARD', 'Great_Dinos_Triceratops_Toy', 'JUICER_SET', 'STACKING_BEAR', 'STACKING_RING', 'Schleich_African_Black_Rhino'] for instance in all_instances: self.rgb_list += ([ # os.path.join(self.file_path, instance, 'rgb', f'{fname:06d}.png') for fname in range(0,250,50) # os.path.join(self.file_path, instance, 'rgb', f'{fname:06d}.png') for fname in range(0,250,100) # os.path.join(self.file_path, instance, f'{fname:03d}.png') for fname in range(0,25,5) os.path.join(self.file_path, instance, 'render_mvs_25', 'model', f'{fname:03d}.png') for fname in range(0,25,4) ]) # if True: # g-objv animals images for i23d eval if False: # if True: objv_dataset = 
'/mnt/sfs-common/yslan/Dataset/Obajverse/chunk-jpeg-normal/bs_16_fixsave3/170K/512/' dataset_json = os.path.join(objv_dataset, 'dataset.json') with open(dataset_json, 'r') as f: dataset_json = json.load(f) # all_objs = dataset_json['Animals'][::3][:6250] all_objs = dataset_json['Animals'][::3][1100:2200][:600] for obj_folder in tqdm(all_objs[:]): for idx in [0]: # only query frontal view is enough self.rgb_list.append(os.path.join(self.file_path, obj_folder, f'{idx}.jpg')) # ! setup normalizataion transformations = [ transforms.ToTensor(), # [0,1] range ] assert imgnet_normalize if imgnet_normalize: transformations.append( transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) # type: ignore ) else: transformations.append( transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))) # type: ignore self.normalize = transforms.Compose(transformations) # camera = torch.load('eval_pose.pt', map_location='cpu') # self.eval_camera = camera # pre-cache # self.calc_rays_plucker() def __len__(self): return len(self.rgb_list) def __getitem__(self, index) -> Any: # return super().__getitem__(index) rgb_fname = self.rgb_list[index] # ! preprocess, normalize raw_img = imageio.imread(rgb_fname) # interpolation=cv2.INTER_AREA) if raw_img.shape[-1] == 4: alpha_mask = raw_img[..., 3:4] / 255.0 bg_white = np.ones_like(alpha_mask) * 255.0 raw_img = raw_img[..., :3] * alpha_mask + ( 1 - alpha_mask) * bg_white #[3, reso_encoder, reso_encoder] raw_img = raw_img.astype(np.uint8) # raw_img = recenter(raw_img, np.ones_like(raw_img), border_ratio=0.2) # log gt img = cv2.resize(raw_img, (self.reso, self.reso), interpolation=cv2.INTER_LANCZOS4) img = torch.from_numpy(img)[..., :3].permute( 2, 0, 1 ) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range ret_dict = { 'img': img, # 'ins': str(Path(rgb_fname).parent.parent.stem), # for gso-rendering 'ins': str(Path(rgb_fname).relative_to(self.file_path)), # for gso-rendering # 'ins': rgb_fname, } return ret_dict class RealMVDataset(Dataset): def __init__( self, file_path, reso, reso_encoder, preprocess=None, classes=False, load_depth=False, test=False, scene_scale=1, overfitting=False, imgnet_normalize=True, dataset_size=-1, overfitting_bs=-1, interval=1, plucker_embedding=False, shuffle_across_cls=False, wds_split=1, # 4 splits to accelerate preprocessing ) -> None: super().__init__() self.file_path = file_path self.overfitting = overfitting self.scene_scale = scene_scale self.reso = reso self.reso_encoder = reso_encoder self.classes = False self.load_depth = load_depth self.preprocess = preprocess self.plucker_embedding = plucker_embedding self.rgb_list = [] all_fname = [ t for t in os.listdir(self.file_path) if t.split('.')[1] in ['png', 'jpg'] ] all_fname = [name for name in all_fname if '-input' in name ] # all_fname = [name for name in all_fname if 'sorting_board-input' in name ] # all_fname = [name for name in all_fname if 'teasure_chest-input' in name ] # all_fname = [name for name in all_fname if 'bubble_mart_blue-input' in name ] # all_fname = [name for name in all_fname if 'chair_comfort-input' in name ] self.rgb_list += ([ os.path.join(self.file_path, fname) for fname in all_fname ]) # if len(self.rgb_list) == 1: # # placeholder # self.rgb_list = self.rgb_list * 40 # ! 
setup normalization
        transformations = [
            transforms.ToTensor(),  # [0,1] range
        ]

        azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float)
        elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float)

        # zero123pp_pose, _ = generate_input_camera(1.6, [[elevations[i], azimuths[i]] for i in range(6)], fov=30)
        zero123pp_pose, _ = generate_input_camera(
            1.8, [[elevations[i], azimuths[i]] for i in range(6)], fov=30)
        K = torch.Tensor([1.3889, 0.0000, 0.5000, 0.0000, 1.3889, 0.5000,
                          0.0000, 0.0000, 0.0039]).to(zero123pp_pose)  # keeps the same
        # st()
        zero123pp_pose = torch.cat(
            [zero123pp_pose.reshape(6, -1), K.unsqueeze(0).repeat(6, 1)], dim=-1)

        # ! directly adopt gt input
        # self.indices = np.array([0,2,4,5])
        # eval_camera = zero123pp_pose[self.indices]
        # self.eval_camera = torch.cat([torch.zeros_like(eval_camera[0:1]), eval_camera], 0)  # first c not used as condition here, just placeholder

        # ! adopt mv-diffusion output as input.
        # self.indices = np.array([1,0,2,4,5])
        self.indices = np.array([0, 1, 2, 3, 4, 5])
        eval_camera = zero123pp_pose[self.indices].float().cpu().numpy()  # for normalization
        # eval_camera = zero123pp_pose[self.indices]
        # self.eval_camera = eval_camera
        # self.eval_camera = torch.cat([torch.zeros_like(eval_camera[0:1]), eval_camera], 0)  # first c not used as condition here, just placeholder

        # * normalize here
        self.eval_camera = self.normalize_camera(eval_camera, eval_camera[0:1])  # the first img is not used.

        # self.mv_resize_cls = torchvision.transforms.Resize(320, interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
        #                                                    max_size=None, antialias=True)

    def normalize_camera(self, c, c_frame0):
        # assert c.shape[0] == self.chunk_size  # 8 or 10
        B = c.shape[0]
        camera_poses = c[:, :16].reshape(B, 4, 4)  # B, 4, 4

        canonical_camera_poses = c_frame0[:, :16].reshape(1, 4, 4)
        inverse_canonical_pose = np.linalg.inv(canonical_camera_poses)
        inverse_canonical_pose = np.repeat(inverse_canonical_pose, B, 0)

        cam_radius = np.linalg.norm(
            c_frame0[:, :16].reshape(1, 4, 4)[:, :3, 3],
            axis=-1,
            keepdims=False)  # since g-buffer adopts dynamic radius here.
        frame1_fixed_pos = np.repeat(np.eye(4)[None], 1, axis=0)
        frame1_fixed_pos[:, 2, -1] = -cam_radius

        transform = frame1_fixed_pos @ inverse_canonical_pose
        new_camera_poses = np.repeat(
            transform, 1, axis=0
        ) @ camera_poses  # [V, 4, 4]. np.repeat() is th.repeat_interleave()

        c = np.concatenate([new_camera_poses.reshape(B, 16), c[:, 16:]],
                           axis=-1)
        return c

    def __len__(self):
        return len(self.rgb_list)

    def __getitem__(self, index) -> Any:
        # return super().__getitem__(index)

        rgb_fname = self.rgb_list[index]

        raw_img = imageio.imread(rgb_fname)[..., :3]
        raw_img = cv2.resize(raw_img, (self.reso, self.reso),
                             interpolation=cv2.INTER_CUBIC)
        raw_img = torch.from_numpy(raw_img).permute(2, 0, 1).clip(0, 255)  # [0, 255]
        img = raw_img / 127.5 - 1

        # ! if loading mv-diff output views
        mv_img = imageio.imread(rgb_fname.replace('-input', ''))
        mv_img = rearrange(mv_img, '(n h) (m w) c -> (n m) h w c', n=3,
                           m=2)[self.indices]  # (6, 3, 320, 320)
        mv_img = np.stack([
            recenter(img, np.ones_like(img), border_ratio=0.1)
            for img in mv_img
        ], axis=0)
        mv_img = rearrange(mv_img, 'b h w c -> b c h w')  # to torch tradition
        mv_img = torch.from_numpy(mv_img) / 127.5 - 1

        ret_dict = {
            'img': img,
            'mv_img': mv_img,
            'c': self.eval_camera,
            'caption': 'null',
        }

        return ret_dict


class NovelViewObjverseDataset(MultiViewObjverseDataset):
    """novel view prediction version.
""" def __init__(self, file_path, reso, reso_encoder, preprocess=None, classes=False, load_depth=False, test=False, scene_scale=1, overfitting=False, imgnet_normalize=True, dataset_size=-1, overfitting_bs=-1, **kwargs): super().__init__(file_path, reso, reso_encoder, preprocess, classes, load_depth, test, scene_scale, overfitting, imgnet_normalize, dataset_size, overfitting_bs, **kwargs) def __getitem__(self, idx): input_view = super().__getitem__( idx) # get previous input view results # get novel view of the same instance novel_view = super().__getitem__( (idx // self.instance_data_length) * self.instance_data_length + random.randint(0, self.instance_data_length - 1)) # assert input_view['ins_name'] == novel_view['ins_name'], 'should sample novel view from the same instance' input_view.update({f'nv_{k}': v for k, v in novel_view.items()}) return input_view class MultiViewObjverseDatasetforLMDB(MultiViewObjverseDataset): def __init__( self, file_path, reso, reso_encoder, preprocess=None, classes=False, load_depth=False, test=False, scene_scale=1, overfitting=False, imgnet_normalize=True, dataset_size=-1, overfitting_bs=-1, shuffle_across_cls=False, wds_split=1, four_view_for_latent=False, ): super().__init__(file_path, reso, reso_encoder, preprocess, classes, load_depth, test, scene_scale, overfitting, imgnet_normalize, dataset_size, overfitting_bs, shuffle_across_cls=shuffle_across_cls, wds_split=wds_split, four_view_for_latent=four_view_for_latent) # assert self.reso == 256 self.load_caption = True with open( # '/cpfs01/shared/V2V/V2V_hdd/yslan/aigc3d/text_captions_cap3d.json' '/nas/shared/public/yslan/data/text_captions_cap3d.json') as f: # '/nas/shared/V2V/yslan/aigc3d/text_captions_cap3d.json') as f: self.caption_data = json.load(f) # lmdb_path = '/cpfs01/user/yangpeiqing.p/yslan/data/Furnitures_uncompressed/' # with open(os.path.join(lmdb_path, 'idx_to_ins_mapping.json')) as f: # self.idx_to_ins_mapping = json.load(f) def __len__(self): return super().__len__() # return 100 # for speed debug def quantize_depth(self, depth): # https://developers.google.com/depthmap-metadata/encoding # RangeInverse encoding bg = depth == 0 depth[bg] = 3 # no need to allocate capacity to it disparity = 1 / depth far = disparity.max().item() # np array here near = disparity.min().item() # d_normalized = (far * (depth-near) / (depth * far - near)) # [0,1] range d_normalized = (disparity - near) / (far - near) # [0,1] range # imageio.imwrite('depth_negative.jpeg', (((depth - near) / (far - near) * 255)<0).numpy().astype(np.uint8)) # imageio.imwrite('depth_negative.jpeg', ((depth <0)*255).numpy().astype(np.uint8)) d_normalized = np.nan_to_num(d_normalized.cpu().numpy()) d_normalized = (np.clip(d_normalized, 0, 1) * 255).astype(np.uint8) # imageio.imwrite('depth.png', d_normalized) # d = 1 / ( (d_normalized / 255) * (far-near) + near) # diff = (d[~bg.numpy()] - depth[~bg].numpy()).sum() return d_normalized, near, far # return disp def __getitem__(self, idx): # ret_dict = super().__getitem__(idx) rgb_fname = self.rgb_list[idx] pose_fname = self.pose_list[idx] raw_img = imageio.imread(rgb_fname) # [..., :3] assert raw_img.shape[-1] == 4 # st() # cv2.imwrite('img_CV2_90.jpg', a, [int(cv2.IMWRITE_JPEG_QUALITY), 90]) # if raw_img.shape[-1] == 4: # ! 
set bg to white alpha_mask = raw_img[..., -1:] / 255 # [0,1] raw_img = alpha_mask * raw_img[..., :3] + ( 1 - alpha_mask) * np.ones_like(raw_img[..., :3]) * 255 raw_img = np.concatenate([raw_img, alpha_mask * 255], -1) raw_img = raw_img.astype(np.uint8) raw_img = cv2.resize(raw_img, (self.reso, self.reso), interpolation=cv2.INTER_LANCZOS4) alpha_mask = raw_img[..., -1] / 255 raw_img = raw_img[..., :3] # alpha_mask = cv2.resize(alpha_mask, (self.reso, self.reso), # interpolation=cv2.INTER_LANCZOS4) c2w = read_camera_matrix_single(pose_fname) #[1, 4, 4] -> [1, 16] c = np.concatenate([c2w.reshape(16), self.intrinsics], axis=0).reshape(25).astype( np.float32) # 25, no '1' dim needed. c = torch.from_numpy(c) # c = np.concatenate([c2w, self.intrinsics], axis=0).reshape(25) # 25, no '1' dim needed. # if self.load_depth: # depth, depth_mask, depth_mask_sr = read_dnormal(self.depth_list[idx], # try: depth, normal = read_dnormal(self.depth_list[idx], c2w[:3, 3:], self.reso, self.reso) # ! quantize depth for fast decoding # d_normalized, d_near, d_far = self.quantize_depth(depth) # ! add frame_0 alignment # try: ins = str( (Path(self.data_ins_list[idx]).relative_to(self.file_path)).parent) # if self.shuffle_across_cls: if self.load_caption: caption = self.caption_data['/'.join(ins.split('/')[1:])] bbox = self.load_bbox(torch.from_numpy(alpha_mask) > 0) else: caption = '' # since in g-alignment-xl, some instances will fail. bbox = self.load_bbox(torch.from_numpy(np.ones_like(alpha_mask)) > 0) # else: # caption = self.caption_data[ins] ret_dict = { 'normal': normal, 'raw_img': raw_img, 'c': c, # 'depth_mask': depth_mask, # 64x64 here? 'bbox': bbox, 'ins': ins, 'caption': caption, 'alpha_mask': alpha_mask, 'depth': depth, # return for pcd creation # 'd_normalized': d_normalized, # 'd_near': d_near, # 'd_far': d_far, # 'fname': rgb_fname, } return ret_dict class MultiViewObjverseDatasetforLMDB_nocaption(MultiViewObjverseDatasetforLMDB): def __init__( self, file_path, reso, reso_encoder, preprocess=None, classes=False, load_depth=False, test=False, scene_scale=1, overfitting=False, imgnet_normalize=True, dataset_size=-1, overfitting_bs=-1, shuffle_across_cls=False, wds_split=1, four_view_for_latent=False, ): super().__init__(file_path, reso, reso_encoder, preprocess, classes, load_depth, test, scene_scale, overfitting, imgnet_normalize, dataset_size, overfitting_bs, shuffle_across_cls=shuffle_across_cls, wds_split=wds_split, four_view_for_latent=four_view_for_latent) self.load_caption = False class Objv_LMDBDataset_MV_Compressed(LMDBDataset_MV_Compressed): def __init__(self, lmdb_path, reso, reso_encoder, imgnet_normalize=True, dataset_size=-1, test=False, **kwargs): super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize, dataset_size=dataset_size, **kwargs) self.instance_data_length = 40 # ! 
could save some key attributes in LMDB if test: self.length = self.instance_data_length elif dataset_size > 0: self.length = dataset_size * self.instance_data_length # load caption data, and idx-to-ins mapping with open( '/cpfs01/shared/V2V/V2V_hdd/yslan/aigc3d/text_captions_cap3d.json' ) as f: self.caption_data = json.load(f) with open(os.path.join(lmdb_path, 'idx_to_ins_mapping.json')) as f: self.idx_to_ins_mapping = json.load(f) def _load_data(self, idx): # ''' raw_img, depth, c, bbox = self._load_lmdb_data(idx) # raw_img, depth, c, bbox = self._load_lmdb_data_no_decompress(idx) # resize depth and bbox caption = self.caption_data[self.idx_to_ins_mapping[str(idx)]] return { **self._post_process_sample(raw_img, depth), 'c': c, 'bbox': (bbox * (self.reso / 512.0)).astype(np.uint8), # 'bbox': (bbox*(self.reso/256.0)).astype(np.uint8), # TODO, double check 512 in wds? 'caption': caption } # ''' # raw_img, depth, c, bbox = self._load_lmdb_data_no_decompress(idx) # st() # return {} def __getitem__(self, idx): return self._load_data(idx) class Objv_LMDBDataset_MV_NoCompressed(Objv_LMDBDataset_MV_Compressed): def __init__(self, lmdb_path, reso, reso_encoder, imgnet_normalize=True, dataset_size=-1, test=False, **kwargs): super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize, dataset_size, test, **kwargs) def _load_data(self, idx): # ''' raw_img, depth, c, bbox = self._load_lmdb_data_no_decompress(idx) # resize depth and bbox caption = self.caption_data[self.idx_to_ins_mapping[str(idx)]] return { **self._post_process_sample(raw_img, depth), 'c': c, 'bbox': (bbox * (self.reso / 512.0)).astype(np.uint8), 'caption': caption } return {} class Objv_LMDBDataset_NV_NoCompressed(Objv_LMDBDataset_MV_NoCompressed): def __init__(self, lmdb_path, reso, reso_encoder, imgnet_normalize=True, dataset_size=-1, test=False, **kwargs): super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize, dataset_size, test, **kwargs) def __getitem__(self, idx): input_view = self._load_data(idx) # get previous input view results # get novel view of the same instance try: novel_view = self._load_data( (idx // self.instance_data_length) * self.instance_data_length + random.randint(0, self.instance_data_length - 1)) except Exception as e: raise NotImplementedError(idx) # assert input_view['ins_name'] == novel_view['ins_name'], 'should sample novel view from the same instance' input_view.update({f'nv_{k}': v for k, v in novel_view.items()}) return input_view class Objv_LMDBDataset_MV_Compressed_for_lmdb(LMDBDataset_MV_Compressed): def __init__(self, lmdb_path, reso, reso_encoder, imgnet_normalize=True, dataset_size=-1, test=False, **kwargs): super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize, dataset_size=dataset_size, **kwargs) self.instance_data_length = 40 # ! 
could save some key attributes in LMDB if test: self.length = self.instance_data_length elif dataset_size > 0: self.length = dataset_size * self.instance_data_length # load caption data, and idx-to-ins mapping with open( '/cpfs01/shared/V2V/V2V_hdd/yslan/aigc3d/text_captions_cap3d.json' ) as f: self.caption_data = json.load(f) with open(os.path.join(lmdb_path, 'idx_to_ins_mapping.json')) as f: self.idx_to_ins_mapping = json.load(f) # def _load_data(self, idx): # # ''' # raw_img, depth, c, bbox = self._load_lmdb_data(idx) # # resize depth and bbox # caption = self.caption_data[self.idx_to_ins_mapping[str(idx)]] # # st() # return { # **self._post_process_sample(raw_img, depth), 'c': c, # 'bbox': (bbox*(self.reso/512.0)).astype(np.uint8), # 'caption': caption # } # # ''' # # raw_img, depth, c, bbox = self._load_lmdb_data_no_decompress(idx) # # st() # # return {} def load_bbox(self, mask): nonzero_value = torch.nonzero(mask) height, width = nonzero_value.max(dim=0)[0] top, left = nonzero_value.min(dim=0)[0] bbox = torch.tensor([top, left, height, width], dtype=torch.float32) return bbox def __getitem__(self, idx): raw_img, depth, c, bbox = self._load_lmdb_data(idx) return {'raw_img': raw_img, 'depth': depth, 'c': c, 'bbox': bbox} class Objv_LMDBDataset_NV_Compressed(Objv_LMDBDataset_MV_Compressed): def __init__(self, lmdb_path, reso, reso_encoder, imgnet_normalize=True, dataset_size=-1, **kwargs): super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize, dataset_size, **kwargs) def __getitem__(self, idx): input_view = self._load_data(idx) # get previous input view results # get novel view of the same instance try: novel_view = self._load_data( (idx // self.instance_data_length) * self.instance_data_length + random.randint(0, self.instance_data_length - 1)) except Exception as e: raise NotImplementedError(idx) # assert input_view['ins_name'] == novel_view['ins_name'], 'should sample novel view from the same instance' input_view.update({f'nv_{k}': v for k, v in novel_view.items()}) return input_view # # test tar loading def load_wds_ResampledShard(file_path, batch_size, num_workers, reso, reso_encoder, test=False, preprocess=None, imgnet_normalize=True, plucker_embedding=False, decode_encode_img_only=False, load_instance=False, mv_input=False, split_chunk_input=False, duplicate_sample=True, append_depth=False, append_normal=False, gs_cam_format=False, orthog_duplicate=False, **kwargs): # return raw_img, depth, c, bbox, sample_pyd['ins.pyd'], sample_pyd['fname.pyd'] post_process_cls = PostProcess( reso, reso_encoder, imgnet_normalize=imgnet_normalize, plucker_embedding=plucker_embedding, decode_encode_img_only=decode_encode_img_only, mv_input=mv_input, split_chunk_input=split_chunk_input, duplicate_sample=duplicate_sample, append_depth=append_depth, gs_cam_format=gs_cam_format, orthog_duplicate=orthog_duplicate, append_normal=append_normal, ) # ! 
add shuffling
    if isinstance(file_path, list):  # lst of shard urls
        all_shards = []
        for url_path in file_path:
            all_shards.extend(wds.shardlists.expand_source(url_path))
        logger.log('all_shards', all_shards)
    else:
        all_shards = file_path  # to be expanded

    if not load_instance:  # during reconstruction training, load pair

        if not split_chunk_input:
            dataset = wds.DataPipeline(
                wds.ResampledShards(all_shards),  # url_shard
                # at this point we have an iterator over all the shards
                wds.shuffle(50),
                wds.split_by_worker,  # if multi-node
                wds.tarfile_to_samples(),
                # add wds.split_by_node here if you are using multiple nodes
                wds.shuffle(1000),  # shuffles in the memory, leverage large RAM for more efficient loading
                wds.decode(wds.autodecode.basichandlers),  # TODO
                wds.to_tuple("sample.pyd"),  # extract the pyd from top level dict
                wds.map(post_process_cls.decode_zip),
                wds.map(post_process_cls.paired_post_process),  # create input-novelview paired samples
                # wds.map(post_process_cls._post_process_sample),
                # wds.detshuffle(1000),  # shuffles in the memory, leverage large RAM for more efficient loading
                wds.batched(
                    16,
                    partial=True,
                    # collation_fn=collate
                )  # streaming more data at once, and rebatch later
            )

        elif kwargs.get('load_gzip', False):  # deprecated, no performance improve
            # `load_gzip` is not a named parameter of this function, so it is read
            # from **kwargs here to avoid a NameError on this deprecated path.
            dataset = wds.DataPipeline(
                wds.ResampledShards(all_shards),  # url_shard
                # at this point we have an iterator over all the shards
                wds.shuffle(10),
                wds.split_by_worker,  # if multi-node
                wds.tarfile_to_samples(),
                # add wds.split_by_node here if you are using multiple nodes
                # wds.shuffle(
                #     100
                # ),  # shuffles in the memory, leverage large RAM for more efficient loading
                wds.decode('rgb8'),  # TODO
                # wds.decode(wds.autodecode.basichandlers),  # TODO
                # wds.to_tuple('raw_img.jpeg', 'depth.jpeg', 'alpha_mask.jpeg',
                #              'd_near.npy', 'd_far.npy', "c.npy", 'bbox.npy',
                #              'ins.txt', 'caption.txt'),
                wds.to_tuple('raw_img.png', 'depth_alpha.png'),
                # wds.to_tuple('raw_img.jpg', "c.npy", 'bbox.npy', 'depth.pyd', 'ins.txt', 'caption.txt'),
                # wds.to_tuple('raw_img.jpg', "c.npy", 'bbox.npy', 'ins.txt', 'caption.txt'),
                wds.map(post_process_cls.decode_gzip),
                # wds.map(post_process_cls.paired_post_process_chunk
                #         ),  # create input-novelview paired samples
                wds.batched(
                    20,
                    partial=True,
                    # collation_fn=collate
                )  # streaming more data at once, and rebatch later
            )

        else:
            # note: `split_chunk_size` is not an argument of this function; it
            # appears to rely on a definition from the surrounding module.
            dataset = wds.DataPipeline(
                wds.ResampledShards(all_shards),  # url_shard
                # at this point we have an iterator over all the shards
                wds.shuffle(100),
                wds.split_by_worker,  # if multi-node
                wds.tarfile_to_samples(),
                # add wds.split_by_node here if you are using multiple nodes
                wds.shuffle(4000 // split_chunk_size),  # shuffles in the memory, leverage large RAM for more efficient loading
                wds.decode(wds.autodecode.basichandlers),  # TODO
                wds.to_tuple("sample.pyd"),  # extract the pyd from top level dict
                wds.map(post_process_cls.decode_zip),
                wds.map(post_process_cls.paired_post_process_chunk),  # create input-novelview paired samples
                # wds.map(post_process_cls._post_process_sample),
                # wds.detshuffle(1000),  # shuffles in the memory, leverage large RAM for more efficient loading
                wds.batched(
                    120 // split_chunk_size,
                    partial=True,
                    # collation_fn=collate
                )  # streaming more data at once, and rebatch later
            )

        loader_shard = wds.WebLoader(
            dataset,
            num_workers=num_workers,
            drop_last=False,
            batch_size=None,
            shuffle=False,
            persistent_workers=num_workers > 0).unbatched().shuffle(
                1000 // split_chunk_size).batched(batch_size).map(
                    post_process_cls.create_dict)

        if mv_input:
            loader_shard = loader_shard.map(post_process_cls.prepare_mv_input)

    else:  #
load single instance during test/eval assert batch_size == 1 dataset = wds.DataPipeline( wds.ResampledShards(all_shards), # url_shard # at this point we have an iterator over all the shards wds.shuffle(50), wds.split_by_worker, # if multi-node wds.tarfile_to_samples(), # add wds.split_by_node here if you are using multiple nodes wds.detshuffle( 100 ), # shuffles in the memory, leverage large RAM for more efficient loading wds.decode(wds.autodecode.basichandlers), # TODO wds.to_tuple("sample.pyd"), # extract the pyd from top level dict wds.map(post_process_cls.decode_zip), # wds.map(post_process_cls.paired_post_process), # create input-novelview paired samples wds.map(post_process_cls._post_process_batch_sample), # wds.detshuffle(1000), # shuffles in the memory, leverage large RAM for more efficient loading wds.batched( 2, partial=True, # collation_fn=collate ) # streaming more data at once, and rebatch later ) loader_shard = wds.WebLoader( dataset, num_workers=num_workers, drop_last=False, batch_size=None, shuffle=False, persistent_workers=num_workers > 0).unbatched().shuffle(200).batched(batch_size).map( post_process_cls.single_instance_sample_create_dict) # persistent_workers=num_workers > 0).unbatched().batched(batch_size).map(post_process_cls.create_dict) # 1000).batched(batch_size).map(post_process_cls.create_dict) # .map(collate) # .map(collate) # .batched(batch_size) # # .unbatched().shuffle(1000).batched(batch_size).map(post_process) # # https://github.com/webdataset/webdataset/issues/187 # return next(iter(loader_shard)) #return dataset return loader_shard class PostProcessForDiff: def __init__( self, reso, reso_encoder, imgnet_normalize, plucker_embedding, decode_encode_img_only, mv_latent_dir, ) -> None: self.plucker_embedding = plucker_embedding self.mv_latent_dir = mv_latent_dir self.decode_encode_img_only = decode_encode_img_only transformations = [ transforms.ToTensor(), # [0,1] range ] if imgnet_normalize: transformations.append( transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) # type: ignore ) else: transformations.append( transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))) # type: ignore self.normalize = transforms.Compose(transformations) self.reso_encoder = reso_encoder self.reso = reso self.instance_data_length = 40 # self.pair_per_instance = 1 # compat self.pair_per_instance = 2 # check whether improves IO # self.pair_per_instance = 3 # check whether improves IO # self.pair_per_instance = 4 # check whether improves IO self.camera = torch.load('eval_pose.pt', map_location='cpu').numpy() self.canonical_frame = self.camera[25:26] # 1, 25 # inverse this self.canonical_frame_pos = self.canonical_frame[:, :16].reshape(4, 4) def get_rays_kiui(self, c, opengl=True): h, w = self.reso_encoder, self.reso_encoder intrinsics, pose = c[16:], c[:16].reshape(4, 4) # cx, cy, fx, fy = intrinsics[2], intrinsics[5] fx = fy = 525 # pixel space cx = cy = 256 # rendering default K factor = self.reso / (cx * 2) # 128 / 512 fx = fx * factor fy = fy * factor x, y = torch.meshgrid( torch.arange(w, device=pose.device), torch.arange(h, device=pose.device), indexing="xy", ) x = x.flatten() y = y.flatten() cx = w * 0.5 cy = h * 0.5 # focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy)) camera_dirs = F.pad( torch.stack( [ (x - cx + 0.5) / fx, (y - cy + 0.5) / fy * (-1.0 if opengl else 1.0), ], dim=-1, ), (0, 1), value=(-1.0 if opengl else 1.0), ) # [hw, 3] rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3] rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3] 
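        # descriptive note: `camera_dirs` above are per-pixel viewing directions in
        # camera space (y and z are sign-flipped when opengl=True); multiplying by the
        # rotation part of `pose` maps them into world space, and the camera center is
        # broadcast as the origin of every ray before both are reshaped to (h, w, 3).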
rays_o = rays_o.view(h, w, 3) rays_d = safe_normalize(rays_d).view(h, w, 3) return rays_o, rays_d def gen_rays(self, c): # Generate rays intrinsics, c2w = c[16:], c[:16].reshape(4, 4) self.h = self.reso_encoder self.w = self.reso_encoder yy, xx = torch.meshgrid( torch.arange(self.h, dtype=torch.float32) + 0.5, torch.arange(self.w, dtype=torch.float32) + 0.5, indexing='ij') # normalize to 0-1 pixel range yy = yy / self.h xx = xx / self.w # K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1]).reshape(3, 3) cx, cy, fx, fy = intrinsics[2], intrinsics[5], intrinsics[ 0], intrinsics[4] # cx *= self.w # cy *= self.h # f_x = f_y = fx * h / res_raw c2w = torch.from_numpy(c2w).float() xx = (xx - cx) / fx yy = (yy - cy) / fy zz = torch.ones_like(xx) dirs = torch.stack((xx, yy, zz), dim=-1) # OpenCV convention dirs /= torch.norm(dirs, dim=-1, keepdim=True) dirs = dirs.reshape(-1, 3, 1) del xx, yy, zz # st() dirs = (c2w[None, :3, :3] @ dirs)[..., 0] origins = c2w[None, :3, 3].expand(self.h * self.w, -1).contiguous() origins = origins.view(self.h, self.w, 3) dirs = dirs.view(self.h, self.w, 3) return origins, dirs def normalize_camera(self, c): # assert c.shape[0] == self.chunk_size # 8 o r10 c = c[None] # api compat B = c.shape[0] camera_poses = c[:, :16].reshape(B, 4, 4) # 3x4 cam_radius = np.linalg.norm( self.canonical_frame_pos.reshape(4, 4)[:3, 3], axis=-1, keepdims=False) # since g-buffer adopts dynamic radius here. frame1_fixed_pos = np.eye(4) frame1_fixed_pos[2, -1] = -cam_radius transform = frame1_fixed_pos @ np.linalg.inv( self.canonical_frame_pos) # 4,4 # from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127 # transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]]) new_camera_poses = transform[None] @ camera_poses # [V, 4, 4] c = np.concatenate([new_camera_poses.reshape(B, 16), c[:, 16:]], axis=-1) return c[0] def _post_process_sample(self, data_sample): # raw_img, depth, c, bbox, caption, ins = data_sample raw_img, c, caption, ins = data_sample # c = self.normalize_camera(c) @ if relative pose. img = raw_img # 256x256 img = torch.from_numpy(img).permute(2, 0, 1) / 127.5 - 1 # load latent. # latent_path = Path(self.mv_latent_dir, ins, 'latent.npy') # ! a converged version, before adding augmentation # if random.random() > 0.5: # latent_path = Path(self.mv_latent_dir, ins, 'latent.npy') # else: # augmentation, double the dataset latent_path = Path( self.mv_latent_dir.replace('v=4-final', 'v=4-rotate'), ins, 'latent.npy') latent = np.load(latent_path) # return (img_to_encoder, img, c, caption, ins) return (latent, img, c, caption, ins) def rand_sample_idx(self): return random.randint(0, self.instance_data_length - 1) def rand_pair(self): return (self.rand_sample_idx() for _ in range(2)) def paired_post_process(self, sample): # repeat n times? 
all_inp_list = [] all_nv_list = [] caption, ins = sample[-2:] # expanded_return = [] for _ in range(self.pair_per_instance): cano_idx, nv_idx = self.rand_pair() cano_sample = self._post_process_sample(item[cano_idx] for item in sample[:-2]) nv_sample = self._post_process_sample(item[nv_idx] for item in sample[:-2]) all_inp_list.extend(cano_sample) all_nv_list.extend(nv_sample) return (*all_inp_list, *all_nv_list, caption, ins) # return [cano_sample, nv_sample, caption, ins] # return (*cano_sample, *nv_sample, caption, ins) # def single_sample_create_dict(self, sample, prefix=''): # # if len(sample) == 1: # # sample = sample[0] # # assert len(sample) == 6 # img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample # return { # # **sample, # f'{prefix}img_to_encoder': img_to_encoder, # f'{prefix}img': img, # f'{prefix}depth_mask': fg_mask_reso, # f'{prefix}depth': depth_reso, # f'{prefix}c': c, # f'{prefix}bbox': bbox, # } def single_sample_create_dict(self, sample, prefix=''): # if len(sample) == 1: # sample = sample[0] # assert len(sample) == 6 # img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample # img_to_encoder, img, c, caption, ins = sample # img, c, caption, ins = sample latent, img, c, caption, ins = sample # load latent return { # **sample, # 'img_to_encoder': img_to_encoder, 'latent': latent, 'img': img, 'c': c, 'caption': caption, 'ins': ins } def decode_zip(self, sample_pyd, shape=(256, 256)): if isinstance(sample_pyd, tuple): sample_pyd = sample_pyd[0] assert isinstance(sample_pyd, dict) raw_img = decompress_and_open_image_gzip( sample_pyd['raw_img'], is_img=True, decompress=True, decompress_fn=lz4.frame.decompress) caption = sample_pyd['caption'].decode('utf-8') ins = sample_pyd['ins'].decode('utf-8') c = decompress_array(sample_pyd['c'], (25, ), np.float32, decompress=True, decompress_fn=lz4.frame.decompress) # bbox = decompress_array( # sample_pyd['bbox'], # ( # 40, # 4, # ), # np.float32, # # decompress=False) # decompress=True, # decompress_fn=lz4.frame.decompress) # if self.decode_encode_img_only: # depth = np.zeros(shape=(40, *shape)) # save loading time # else: # depth = decompress_array(sample_pyd['depth'], (40, *shape), # np.float32, # decompress=True, # decompress_fn=lz4.frame.decompress) # return {'raw_img': raw_img, 'depth': depth, 'bbox': bbox, 'caption': caption, 'ins': ins, 'c': c} # return raw_img, depth, c, bbox, caption, ins # return raw_img, bbox, caption, ins # return bbox, caption, ins return raw_img, c, caption, ins # ! run single-instance pipeline first # return raw_img[0], depth[0], c[0], bbox[0], caption, ins def create_dict(self, sample): # sample = [item[0] for item in sample] # wds wrap items in [] # cano_sample_list = [[] for _ in range(6)] # nv_sample_list = [[] for _ in range(6)] # for idx in range(0, self.pair_per_instance): # cano_sample = sample[6*idx:6*(idx+1)] # nv_sample = sample[6*self.pair_per_instance+6*idx:6*self.pair_per_instance+6*(idx+1)] # for item_idx in range(6): # cano_sample_list[item_idx].append(cano_sample[item_idx]) # nv_sample_list[item_idx].append(nv_sample[item_idx]) # # ! 
cycle input/output view for more pairs # cano_sample_list[item_idx].append(nv_sample[item_idx]) # nv_sample_list[item_idx].append(cano_sample[item_idx]) cano_sample = self.single_sample_create_dict(sample, prefix='') # nv_sample = self.single_sample_create_dict((torch.cat(item_list) for item_list in nv_sample_list) , prefix='nv_') return cano_sample # return { # **cano_sample, # # **nv_sample, # 'caption': sample[-2], # 'ins': sample[-1] # } # test tar loading def load_wds_diff_ResampledShard(file_path, batch_size, num_workers, reso, reso_encoder, test=False, preprocess=None, imgnet_normalize=True, plucker_embedding=False, decode_encode_img_only=False, mv_latent_dir='', **kwargs): # return raw_img, depth, c, bbox, sample_pyd['ins.pyd'], sample_pyd['fname.pyd'] post_process_cls = PostProcessForDiff( reso, reso_encoder, imgnet_normalize=imgnet_normalize, plucker_embedding=plucker_embedding, decode_encode_img_only=decode_encode_img_only, mv_latent_dir=mv_latent_dir, ) if isinstance(file_path, list): # lst of shard urls all_shards = [] for url_path in file_path: all_shards.extend(wds.shardlists.expand_source(url_path)) logger.log('all_shards', all_shards) else: all_shards = file_path # to be expanded dataset = wds.DataPipeline( wds.ResampledShards(all_shards), # url_shard # at this point we have an iterator over all the shards wds.shuffle(100), wds.split_by_worker, # if multi-node wds.tarfile_to_samples(), # add wds.split_by_node here if you are using multiple nodes wds.shuffle( 20000 ), # shuffles in the memory, leverage large RAM for more efficient loading wds.decode(wds.autodecode.basichandlers), # TODO wds.to_tuple("sample.pyd"), # extract the pyd from top level dict wds.map(post_process_cls.decode_zip), # wds.map(post_process_cls.paired_post_process), # create input-novelview paired samples wds.map(post_process_cls._post_process_sample), # wds.detshuffle(1000), # shuffles in the memory, leverage large RAM for more efficient loading wds.batched( 100, partial=True, # collation_fn=collate ) # streaming more data at once, and rebatch later ) loader_shard = wds.WebLoader( dataset, num_workers=num_workers, drop_last=False, batch_size=None, shuffle=False, persistent_workers=num_workers > 0).unbatched().shuffle(2500).batched(batch_size).map( post_process_cls.create_dict) # persistent_workers=num_workers > 0).unbatched().batched(batch_size).map(post_process_cls.create_dict) # 1000).batched(batch_size).map(post_process_cls.create_dict) # .map(collate) # .map(collate) # .batched(batch_size) # # .unbatched().shuffle(1000).batched(batch_size).map(post_process) # # https://github.com/webdataset/webdataset/issues/187 # return next(iter(loader_shard)) #return dataset return loader_shard def load_wds_data( file_path="", reso=64, reso_encoder=224, batch_size=1, num_workers=6, plucker_embedding=False, decode_encode_img_only=False, load_wds_diff=False, load_wds_latent=False, load_instance=False, # for evaluation mv_input=False, split_chunk_input=False, duplicate_sample=True, mv_latent_dir='', append_depth=False, gs_cam_format=False, orthog_duplicate=False, **args): if load_wds_diff: # assert num_workers == 0 # on aliyun, worker=0 performs much much faster wds_loader = load_wds_diff_ResampledShard( file_path, batch_size=batch_size, num_workers=num_workers, reso=reso, reso_encoder=reso_encoder, plucker_embedding=plucker_embedding, decode_encode_img_only=decode_encode_img_only, mv_input=mv_input, split_chunk_input=split_chunk_input, append_depth=append_depth, mv_latent_dir=mv_latent_dir, 
gs_cam_format=gs_cam_format, orthog_duplicate=orthog_duplicate, ) elif load_wds_latent: # for diffusion training, cache latent wds_loader = load_wds_latent_ResampledShard( file_path, batch_size=batch_size, num_workers=num_workers, reso=reso, reso_encoder=reso_encoder, plucker_embedding=plucker_embedding, decode_encode_img_only=decode_encode_img_only, mv_input=mv_input, split_chunk_input=split_chunk_input, ) # elif load_instance: # wds_loader = load_wds_instance_ResampledShard( # file_path, # batch_size=batch_size, # num_workers=num_workers, # reso=reso, # reso_encoder=reso_encoder, # plucker_embedding=plucker_embedding, # decode_encode_img_only=decode_encode_img_only # ) else: wds_loader = load_wds_ResampledShard( file_path, batch_size=batch_size, num_workers=num_workers, reso=reso, reso_encoder=reso_encoder, plucker_embedding=plucker_embedding, decode_encode_img_only=decode_encode_img_only, load_instance=load_instance, mv_input=mv_input, split_chunk_input=split_chunk_input, duplicate_sample=duplicate_sample, append_depth=append_depth, gs_cam_format=gs_cam_format, orthog_duplicate=orthog_duplicate, ) while True: yield from wds_loader # yield from wds_loader class PostProcess_forlatent: def __init__( self, reso, reso_encoder, imgnet_normalize, plucker_embedding, decode_encode_img_only, ) -> None: self.plucker_embedding = plucker_embedding self.decode_encode_img_only = decode_encode_img_only transformations = [ transforms.ToTensor(), # [0,1] range ] if imgnet_normalize: transformations.append( transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) # type: ignore ) else: transformations.append( transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))) # type: ignore self.normalize = transforms.Compose(transformations) self.reso_encoder = reso_encoder self.reso = reso self.instance_data_length = 40 # self.pair_per_instance = 1 # compat self.pair_per_instance = 2 # check whether improves IO # self.pair_per_instance = 3 # check whether improves IO # self.pair_per_instance = 4 # check whether improves IO def _post_process_sample(self, data_sample): # raw_img, depth, c, bbox, caption, ins = data_sample raw_img, c, caption, ins = data_sample # bbox = (bbox*(self.reso/256)).astype(np.uint8) # normalize bbox to the reso range if raw_img.shape[-2] != self.reso_encoder: img_to_encoder = cv2.resize(raw_img, (self.reso_encoder, self.reso_encoder), interpolation=cv2.INTER_LANCZOS4) else: img_to_encoder = raw_img img_to_encoder = self.normalize(img_to_encoder) if self.plucker_embedding: rays_o, rays_d = self.gen_rays(c) rays_plucker = torch.cat( [torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1).permute(2, 0, 1) # [h, w, 6] -> 6,h,w img_to_encoder = torch.cat([img_to_encoder, rays_plucker], 0) img = cv2.resize(raw_img, (self.reso, self.reso), interpolation=cv2.INTER_LANCZOS4) img = torch.from_numpy(img).permute(2, 0, 1) / 127.5 - 1 return (img_to_encoder, img, c, caption, ins) def rand_sample_idx(self): return random.randint(0, self.instance_data_length - 1) def rand_pair(self): return (self.rand_sample_idx() for _ in range(2)) def paired_post_process(self, sample): # repeat n times? 
        all_inp_list = []
        all_nv_list = []
        caption, ins = sample[-2:]
        # expanded_return = []

        for _ in range(self.pair_per_instance):
            cano_idx, nv_idx = self.rand_pair()
            cano_sample = self._post_process_sample(item[cano_idx]
                                                    for item in sample[:-2])
            nv_sample = self._post_process_sample(item[nv_idx]
                                                  for item in sample[:-2])
            all_inp_list.extend(cano_sample)
            all_nv_list.extend(nv_sample)

        return (*all_inp_list, *all_nv_list, caption, ins)
        # return [cano_sample, nv_sample, caption, ins]
        # return (*cano_sample, *nv_sample, caption, ins)

    # def single_sample_create_dict(self, sample, prefix=''):
    #     # if len(sample) == 1:
    #     #     sample = sample[0]
    #     # assert len(sample) == 6
    #     img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample
    #     return {
    #         # **sample,
    #         f'{prefix}img_to_encoder': img_to_encoder,
    #         f'{prefix}img': img,
    #         f'{prefix}depth_mask': fg_mask_reso,
    #         f'{prefix}depth': depth_reso,
    #         f'{prefix}c': c,
    #         f'{prefix}bbox': bbox,
    #     }

    def single_sample_create_dict(self, sample, prefix=''):
        # if len(sample) == 1:
        #     sample = sample[0]
        # assert len(sample) == 6
        # img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample
        img_to_encoder, img, c, caption, ins = sample
        return {
            # **sample,
            'img_to_encoder': img_to_encoder,
            'img': img,
            'c': c,
            'caption': caption,
            'ins': ins
        }

    def decode_zip(self, sample_pyd, shape=(256, 256)):
        if isinstance(sample_pyd, tuple):
            sample_pyd = sample_pyd[0]
        assert isinstance(sample_pyd, dict)

        latent = sample_pyd['latent']
        caption = sample_pyd['caption'].decode('utf-8')
        c = sample_pyd['c']
        # img = sample_pyd['img']
        # st()

        return latent, caption, c

    def create_dict(self, sample):
        return {
            # **sample,
            'latent': sample[0],
            'caption': sample[1],
            'c': sample[2],
        }


# test tar loading
def load_wds_latent_ResampledShard(file_path,
                                   batch_size,
                                   num_workers,
                                   reso,
                                   reso_encoder,
                                   test=False,
                                   preprocess=None,
                                   imgnet_normalize=True,
                                   plucker_embedding=False,
                                   decode_encode_img_only=False,
                                   **kwargs):
    # return raw_img, depth, c, bbox, sample_pyd['ins.pyd'], sample_pyd['fname.pyd']
    post_process_cls = PostProcess_forlatent(
        reso,
        reso_encoder,
        imgnet_normalize=imgnet_normalize,
        plucker_embedding=plucker_embedding,
        decode_encode_img_only=decode_encode_img_only,
    )

    if isinstance(file_path, list):  # lst of shard urls
        all_shards = []
        for url_path in file_path:
            all_shards.extend(wds.shardlists.expand_source(url_path))
        logger.log('all_shards', all_shards)
    else:
        all_shards = file_path  # to be expanded

    dataset = wds.DataPipeline(
        wds.ResampledShards(all_shards),  # url_shard
        # at this point we have an iterator over all the shards
        wds.shuffle(50),
        wds.split_by_worker,  # if multi-node
        wds.tarfile_to_samples(),
        # add wds.split_by_node here if you are using multiple nodes
        wds.detshuffle(2500),  # shuffles in the memory, leverage large RAM for more efficient loading
        wds.decode(wds.autodecode.basichandlers),  # TODO
        wds.to_tuple("sample.pyd"),  # extract the pyd from top level dict
        wds.map(post_process_cls.decode_zip),
        # wds.map(post_process_cls._post_process_sample),
        # wds.detshuffle(1000),  # shuffles in the memory, leverage large RAM for more efficient loading
        wds.batched(
            150,
            partial=True,
            # collation_fn=collate
        )  # streaming more data at once, and rebatch later
    )

    loader_shard = wds.WebLoader(
        dataset,
        num_workers=num_workers,
        drop_last=False,
        batch_size=None,
        shuffle=False,
        persistent_workers=num_workers > 0).unbatched().shuffle(1000).batched(
            batch_size).map(post_process_cls.create_dict)

    # persistent_workers=num_workers > 0).unbatched().batched(batch_size).map(post_process_cls.create_dict)
    # 1000).batched(batch_size).map(post_process_cls.create_dict)
    # .map(collate)
    # .batched(batch_size)
    # .unbatched().shuffle(1000).batched(batch_size).map(post_process)
    # https://github.com/webdataset/webdataset/issues/187

    # return next(iter(loader_shard))
    # return dataset
    return loader_shard
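
# Minimal smoke-test sketch: it only assumes a hypothetical local folder
# `./example_inputs` containing `*-input.png` / `*-input.jpg` images, and iterates
# RealDataset once to sanity-check tensor shapes and the [-1, 1] value range of
# the returned 'img'. Adjust `example_dir` before running.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    example_dir = './example_inputs'  # hypothetical path; replace with a real folder
    if os.path.isdir(example_dir):
        smoke_dataset = RealDataset(example_dir, reso=256, reso_encoder=224)
        smoke_loader = DataLoader(smoke_dataset, batch_size=1, num_workers=0)
        for batch in smoke_loader:
            print('img', tuple(batch['img'].shape),
                  float(batch['img'].min()), float(batch['img'].max()))
            break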