# GaussianAnything-AIGC3D / datasets / g_buffer_objaverse.py
import os
from tqdm import tqdm
import kiui
from kiui.op import recenter
import kornia
import collections
import math
import time
import itertools
import pickle
from typing import Any
import lmdb
import cv2
import trimesh
cv2.setNumThreads(0)  # disable OpenCV-internal threading (avoids contention with DataLoader workers)
# import imageio
import imageio.v3 as imageio
import numpy as np
from PIL import Image
import Imath
import OpenEXR
from pdb import set_trace as st
from pathlib import Path
import torchvision
from torchvision.transforms import v2
from einops import rearrange, repeat
from functools import partial
import io
from scipy.stats import special_ortho_group
import gzip
import random
import torch
import torch as th
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.utils.data.distributed import DistributedSampler
import lz4.frame
from nsr.volumetric_rendering.ray_sampler import RaySampler
import point_cloud_utils as pcu
import torch.multiprocessing
# torch.multiprocessing.set_sharing_strategy('file_system')
from utils.general_utils import PILtoTorch, matrix_to_quaternion
from guided_diffusion import logger
import json
import webdataset as wds
from webdataset.shardlists import expand_source
# st()
from .shapenet import LMDBDataset, LMDBDataset_MV_Compressed, decompress_and_open_image_gzip, decompress_array
from kiui.op import safe_normalize
from utils.gs_utils.graphics_utils import getWorld2View2, getProjectionMatrix, getView2World
from nsr.camera_utils import generate_input_camera
def random_rotation_matrix():
# Generate a random rotation matrix in 3D
random_rotation_3d = special_ortho_group.rvs(3)
# Embed the 3x3 rotation matrix into a 4x4 matrix
rotation_matrix_4x4 = np.eye(4)
rotation_matrix_4x4[:3, :3] = random_rotation_3d
return rotation_matrix_4x4
def fov2focal(fov, pixels):
return pixels / (2 * math.tan(fov / 2))
def focal2fov(focal, pixels):
return 2 * math.atan(pixels / (2 * focal))
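# Standard pinhole conversions: focal = pixels / (2 * tan(fov / 2)) and the
# inverse; e.g. fov2focal(math.pi / 2, 512) == 256.0 and focal2fov(256.0, 512)
# recovers pi / 2.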
def resize_depth_mask(depth_to_resize, resolution):
depth_resized = cv2.resize(depth_to_resize, (resolution, resolution),
interpolation=cv2.INTER_LANCZOS4)
# interpolation=cv2.INTER_AREA)
return depth_resized, depth_resized > 0 # type: ignore
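# The Tensor variant below uses nearest interpolation so background zeros stay
# exact and the derived foreground mask (depth > 0) remains valid after resizing.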
def resize_depth_mask_Tensor(depth_to_resize, resolution):
if depth_to_resize.shape[-1] != resolution:
depth_resized = torch.nn.functional.interpolate(
input=depth_to_resize.unsqueeze(1),
size=(resolution, resolution),
# mode='bilinear',
mode='nearest',
# align_corners=False,
).squeeze(1)
else:
depth_resized = depth_to_resize
return depth_resized.float(), depth_resized > 0 # type: ignore
class PostProcess:
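    """Shared post-processing for chunked Objaverse G-buffer samples.

    Resizes RGB / depth / alpha to the encoder and renderer resolutions,
    optionally appends Plucker rays, depth, or unprojected xyz as extra
    encoder input channels, canonicalizes cameras (and point clouds) to
    frame 0, and converts poses to the 3DGS renderer format on demand.
    """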
def __init__(
self,
reso,
reso_encoder,
imgnet_normalize,
plucker_embedding,
decode_encode_img_only,
mv_input,
split_chunk_input,
duplicate_sample,
append_depth,
gs_cam_format,
orthog_duplicate,
frame_0_as_canonical,
pcd_path=None,
load_pcd=False,
split_chunk_size=8,
append_xyz=False,
) -> None:
self.load_pcd = load_pcd
if pcd_path is None: # hard-coded
pcd_path = '/cpfs01/user/lanyushi.p/data/FPS_PCD/pcd-V=6_256_again/fps-pcd/'
self.pcd_path = Path(pcd_path)
self.append_xyz = append_xyz
if append_xyz:
assert append_depth is False
self.frame_0_as_canonical = frame_0_as_canonical
self.gs_cam_format = gs_cam_format
self.append_depth = append_depth
self.plucker_embedding = plucker_embedding
self.decode_encode_img_only = decode_encode_img_only
self.duplicate_sample = duplicate_sample
self.orthog_duplicate = orthog_duplicate
self.zfar = 100.0
self.znear = 0.01
transformations = []
if not split_chunk_input:
transformations.append(transforms.ToTensor())
if imgnet_normalize:
transformations.append(
transforms.Normalize((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225)) # type: ignore
)
else:
transformations.append(
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))) # type: ignore
self.normalize = transforms.Compose(transformations)
self.reso_encoder = reso_encoder
self.reso = reso
self.instance_data_length = 40
# self.pair_per_instance = 1 # compat
self.mv_input = mv_input
self.split_chunk_input = split_chunk_input # 8
self.chunk_size = split_chunk_size if split_chunk_input else 40
# assert self.chunk_size in [8, 10]
self.V = self.chunk_size // 2 # 4 views as input
# else:
# assert self.chunk_size == 20
# self.V = 12 # 6 + 6 here
# st()
assert split_chunk_input
self.pair_per_instance = 1
# else:
# self.pair_per_instance = 4 if mv_input else 2 # check whether improves IO
self.ray_sampler = RaySampler() # load xyz
def gen_rays(self, c):
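        """Generate per-pixel rays for a single 25-dim camera vector.

        c[:16] is the row-major 4x4 cam2world pose and c[16:25] the flattened,
        normalized 3x3 intrinsics (cx, cy, fx, fy at flat indices 2, 5, 0, 4).
        Returns (origins, dirs), each of shape (reso_encoder, reso_encoder, 3).
        """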
# Generate rays
intrinsics, c2w = c[16:], c[:16].reshape(4, 4)
self.h = self.reso_encoder
self.w = self.reso_encoder
yy, xx = torch.meshgrid(
torch.arange(self.h, dtype=torch.float32) + 0.5,
torch.arange(self.w, dtype=torch.float32) + 0.5,
indexing='ij')
# normalize to 0-1 pixel range
yy = yy / self.h
xx = xx / self.w
# K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1]).reshape(3, 3)
cx, cy, fx, fy = intrinsics[2], intrinsics[5], intrinsics[
0], intrinsics[4]
# cx *= self.w
# cy *= self.h
# f_x = f_y = fx * h / res_raw
c2w = torch.from_numpy(c2w).float()
xx = (xx - cx) / fx
yy = (yy - cy) / fy
zz = torch.ones_like(xx)
dirs = torch.stack((xx, yy, zz), dim=-1) # OpenCV convention
dirs /= torch.norm(dirs, dim=-1, keepdim=True)
dirs = dirs.reshape(-1, 3, 1)
del xx, yy, zz
# st()
dirs = (c2w[None, :3, :3] @ dirs)[..., 0]
origins = c2w[None, :3, 3].expand(self.h * self.w, -1).contiguous()
origins = origins.view(self.h, self.w, 3)
dirs = dirs.view(self.h, self.w, 3)
return origins, dirs
def _post_process_batch_sample(self,
sample): # sample is an instance batch here
caption, ins = sample[-2:]
instance_samples = []
for instance_idx in range(sample[0].shape[0]):
instance_samples.append(
self._post_process_sample(item[instance_idx]
for item in sample[:-2]))
return (*instance_samples, caption, ins)
def _post_process_sample(self, data_sample):
# raw_img, depth, c, bbox, caption, ins = data_sample
# st()
raw_img, depth, c, bbox = data_sample
        bbox = (bbox * (self.reso / 256)).astype(
            np.uint8)  # rescale bbox from the 256 grid to the target reso
if raw_img.shape[-2] != self.reso_encoder:
img_to_encoder = cv2.resize(raw_img,
(self.reso_encoder, self.reso_encoder),
interpolation=cv2.INTER_LANCZOS4)
else:
img_to_encoder = raw_img
img_to_encoder = self.normalize(img_to_encoder)
if self.plucker_embedding:
rays_o, rays_d = self.gen_rays(c)
rays_plucker = torch.cat(
[torch.cross(rays_o, rays_d, dim=-1), rays_d],
dim=-1).permute(2, 0, 1) # [h, w, 6] -> 6,h,w
img_to_encoder = torch.cat([img_to_encoder, rays_plucker], 0)
img = cv2.resize(raw_img, (self.reso, self.reso),
interpolation=cv2.INTER_LANCZOS4)
img = torch.from_numpy(img).permute(2, 0, 1) / 127.5 - 1
if self.decode_encode_img_only:
depth_reso, fg_mask_reso = depth, depth
else:
depth_reso, fg_mask_reso = resize_depth_mask(depth, self.reso)
# return {
# # **sample,
# 'img_to_encoder': img_to_encoder,
# 'img': img,
# 'depth_mask': fg_mask_reso,
# # 'img_sr': img_sr,
# 'depth': depth_reso,
# 'c': c,
# 'bbox': bbox,
# 'caption': caption,
# 'ins': ins
# # ! no need to load img_sr for now
# }
# if len(data_sample) == 4:
return (img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox)
# else:
# return (img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox, data_sample[-2], data_sample[-1])
def canonicalize_pts(self, c, pcd, for_encoder=True, canonical_idx=0):
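        """Canonicalize a world-space point cloud w.r.t. frames 0 and V.

        Applies the rigid transform that moves camera 0 (resp. camera V) to a
        fixed pose at (0, 0, -radius), matching normalize_camera(for_encoder=True),
        so the pcd stays aligned with the canonicalized cameras. Returns (2, N, 3).
        """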
# pcd: sampled in world space
assert c.shape[0] == self.chunk_size
assert for_encoder
# st()
        B = c.shape[0]
        camera_poses = c[:, :16].reshape(B, 4, 4)  # B 4 4
cam_radius = np.linalg.norm(
c[[0, self.V]][:, :16].reshape(2, 4, 4)[:, :3, 3],
axis=-1,
keepdims=False) # since g-buffer adopts dynamic radius here.
frame1_fixed_pos = np.repeat(np.eye(4)[None], 2, axis=0)
frame1_fixed_pos[:, 2, -1] = -cam_radius
transform = frame1_fixed_pos @ np.linalg.inv(camera_poses[[0, self.V
]]) # B 4 4
transform = np.expand_dims(transform, axis=1) # B 1 4 4
# from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127
# transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]])
repeated_homo_pcd = np.repeat(np.concatenate(
[pcd, np.ones_like(pcd[..., 0:1])], -1)[None],
2,
axis=0)[..., None] # B N 4 1
new_pcd = (transform @ repeated_homo_pcd)[..., :3, 0] # 2 N 3
return new_pcd
def canonicalize_pts_v6(self, c, pcd, for_encoder=True, canonical_idx=0):
        raise NotImplementedError('deprecated; use canonicalize_pts instead')
# pcd: sampled in world space
assert c.shape[0] == self.chunk_size
assert for_encoder
encoder_canonical_idx = [0, 6, 12, 18]
        B = c.shape[0]
        camera_poses = c[:, :16].reshape(B, 4, 4)  # B 4 4
cam_radius = np.linalg.norm(
c[encoder_canonical_idx][:, :16].reshape(4, 4, 4)[:, :3, 3],
axis=-1,
keepdims=False) # since g-buffer adopts dynamic radius here.
frame1_fixed_pos = np.repeat(np.eye(4)[None], 4, axis=0)
frame1_fixed_pos[:, 2, -1] = -cam_radius
transform = frame1_fixed_pos @ np.linalg.inv(
camera_poses[encoder_canonical_idx]) # B 4 4
transform = np.expand_dims(transform, axis=1) # B 1 4 4
# from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127
# transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]])
repeated_homo_pcd = np.repeat(np.concatenate(
[pcd, np.ones_like(pcd[..., 0:1])], -1)[None],
4,
axis=0)[..., None] # B N 4 1
        new_pcd = (transform @ repeated_homo_pcd)[..., :3, 0]  # 4 N 3
return new_pcd
def normalize_camera(self, c, for_encoder=True, canonical_idx=0):
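        """Re-express chunk cameras relative to a canonical frame (cf. LGM).

        for_encoder=True: frames 0 and V are each moved to the fixed pose at
        (0, 0, -radius) and their transforms applied to their half of the chunk.
        for_encoder=False: a single canonical_idx is applied to all views.
        Only the pose part c[:, :16] changes; intrinsics pass through.
        """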
        assert c.shape[0] == self.chunk_size  # 8 or 10
        B = c.shape[0]
        camera_poses = c[:, :16].reshape(B, 4, 4)  # B 4 4
if for_encoder:
encoder_canonical_idx = [0, self.V]
# st()
cam_radius = np.linalg.norm(
c[encoder_canonical_idx][:, :16].reshape(2, 4, 4)[:, :3, 3],
axis=-1,
keepdims=False) # since g-buffer adopts dynamic radius here.
frame1_fixed_pos = np.repeat(np.eye(4)[None], 2, axis=0)
frame1_fixed_pos[:, 2, -1] = -cam_radius
transform = frame1_fixed_pos @ np.linalg.inv(
camera_poses[encoder_canonical_idx])
# from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127
# transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]])
new_camera_poses = np.repeat(
transform, self.V, axis=0
) @ camera_poses # [V, 4, 4]. np.repeat() is th.repeat_interleave()
else:
cam_radius = np.linalg.norm(
c[canonical_idx][:16].reshape(4, 4)[:3, 3],
axis=-1,
keepdims=False) # since g-buffer adopts dynamic radius here.
frame1_fixed_pos = np.eye(4)
frame1_fixed_pos[2, -1] = -cam_radius
transform = frame1_fixed_pos @ np.linalg.inv(
camera_poses[canonical_idx]) # 4,4
# from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127
# transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]])
new_camera_poses = np.repeat(transform[None],
self.chunk_size,
axis=0) @ camera_poses # [V, 4, 4]
c = np.concatenate([new_camera_poses.reshape(B, 16), c[:, 16:]],
axis=-1)
return c
def normalize_camera_v6(self, c, for_encoder=True, canonical_idx=0):
        B = c.shape[0]
        camera_poses = c[:, :16].reshape(B, 4, 4)  # B 4 4
if for_encoder:
assert c.shape[0] == 24
encoder_canonical_idx = [0, 6, 12, 18]
cam_radius = np.linalg.norm(
c[encoder_canonical_idx][:, :16].reshape(4, 4, 4)[:, :3, 3],
axis=-1,
keepdims=False) # since g-buffer adopts dynamic radius here.
frame1_fixed_pos = np.repeat(np.eye(4)[None], 4, axis=0)
frame1_fixed_pos[:, 2, -1] = -cam_radius
transform = frame1_fixed_pos @ np.linalg.inv(
camera_poses[encoder_canonical_idx])
# from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127
# transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]])
new_camera_poses = np.repeat(transform, 6,
axis=0) @ camera_poses # [V, 4, 4]
else:
assert c.shape[0] == 12
cam_radius = np.linalg.norm(
c[canonical_idx][:16].reshape(4, 4)[:3, 3],
axis=-1,
keepdims=False) # since g-buffer adopts dynamic radius here.
frame1_fixed_pos = np.eye(4)
frame1_fixed_pos[2, -1] = -cam_radius
transform = frame1_fixed_pos @ np.linalg.inv(
camera_poses[canonical_idx]) # 4,4
# from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127
# transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]])
new_camera_poses = np.repeat(transform[None], 12,
axis=0) @ camera_poses # [V, 4, 4]
c = np.concatenate([new_camera_poses.reshape(B, 16), c[:, 16:]],
axis=-1)
return c
def get_plucker_ray(self, c):
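        """Per-view Plucker ray embedding: channels (cross(o, d), d), shape (V, 6, H, W)."""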
rays_plucker = []
for idx in range(c.shape[0]):
rays_o, rays_d = self.gen_rays(c[idx])
rays_plucker.append(
torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d],
dim=-1).permute(2, 0, 1)) # [h, w, 6] -> 6,h,w
rays_plucker = torch.stack(rays_plucker, 0)
return rays_plucker
def _unproj_depth_given_c(self, c, depth):
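        """Back-project depth maps to world-space xyz, returned as (B, 3, H, W).

        Background pixels (depth 0) land on the ray origin unless pre-masked by
        the caller; values are clipped to the [-0.45, 0.45] cube and the exact
        boundary values (clipped outliers) are zeroed out.
        """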
# get xyz hxw for each pixel, like MCC
# img_size = self.reso
img_size = depth.shape[-1]
B = c.shape[0]
cam2world_matrix = c[:, :16].reshape(B, 4, 4)
intrinsics = c[:, 16:25].reshape(B, 3, 3)
ray_origins, ray_directions = self.ray_sampler( # shape:
cam2world_matrix, intrinsics, img_size)[:2]
depth = depth.reshape(B, -1).unsqueeze(-1)
xyz = ray_origins + depth * ray_directions # BV HW 3, already in the world space
xyz = xyz.reshape(B, img_size, img_size, 3).permute(0, 3, 1,
2) # B 3 H W
xyz = xyz.clip(
-0.45, 0.45) # g-buffer saves depth with anti-alias = True .....
xyz = torch.where(xyz.abs() == 0.45, 0, xyz) # no boundary here? Yes.
return xyz
def _post_process_sample_batch(self, data_sample):
# raw_img, depth, c, bbox, caption, ins = data_sample
alpha = None
if len(data_sample) == 4:
raw_img, depth, c, bbox = data_sample
else:
raw_img, depth, c, alpha, bbox = data_sample # put c to position 2
if isinstance(depth, tuple):
self.append_normal = True
depth, normal = depth
else:
self.append_normal = False
normal = None
# if raw_img.shape[-1] == 4:
# depth_reso, _ = resize_depth_mask_Tensor(
# torch.from_numpy(depth), self.reso)
# raw_img, fg_mask_reso = raw_img[..., :3], raw_img[..., -1]
# # st() # ! check has 1 dim in alpha?
# else:
if not isinstance(depth, torch.Tensor):
depth = torch.from_numpy(depth).float()
else:
depth = depth.float()
depth_reso, fg_mask_reso = resize_depth_mask_Tensor(depth, self.reso)
if alpha is None:
alpha = fg_mask_reso
else:
# ! resize first
# st()
alpha = torch.from_numpy(alpha / 255.0).float()
            if alpha.shape[-1] != self.reso:  # resize alpha with bilinear interpolation
alpha = torch.nn.functional.interpolate(
input=alpha.unsqueeze(1),
size=(self.reso, self.reso),
mode='bilinear',
align_corners=False,
).squeeze(1)
if self.reso < 256:
            bbox = (bbox * (self.reso / 256)).astype(
                np.uint8)  # rescale bbox from the 256 grid to the target reso
else: # 3dgs
bbox = bbox.astype(np.uint8)
# st() # ! shall compat with 320 input
# assert raw_img.shape[-2] == self.reso_encoder
# img_to_encoder = cv2.resize(
# raw_img, (self.reso_encoder, self.reso_encoder),
# interpolation=cv2.INTER_LANCZOS4)
# else:
# img_to_encoder = raw_img
raw_img = torch.from_numpy(raw_img).permute(0, 3, 1,
2) / 255.0 # [0,1]
if normal is not None:
normal = torch.from_numpy(normal).permute(0,3,1,2)
# if raw_img.shape[-1] != self.reso:
if raw_img.shape[1] != self.reso_encoder:
img_to_encoder = torch.nn.functional.interpolate(
input=raw_img,
size=(self.reso_encoder, self.reso_encoder),
mode='bilinear',
align_corners=False,)
img_to_encoder = self.normalize(img_to_encoder)
if normal is not None:
normal_for_encoder = torch.nn.functional.interpolate(
input=normal,
size=(self.reso_encoder, self.reso_encoder),
# mode='bilinear',
mode='nearest',
# align_corners=False,
)
else:
img_to_encoder = self.normalize(raw_img)
normal_for_encoder = normal
if raw_img.shape[-1] != self.reso:
img = torch.nn.functional.interpolate(
input=raw_img,
size=(self.reso, self.reso),
mode='bilinear',
align_corners=False,
) # [-1,1] range
img = img * 2 - 1 # as gt
if normal is not None:
normal = torch.nn.functional.interpolate(
input=normal,
size=(self.reso, self.reso),
# mode='bilinear',
mode='nearest',
# align_corners=False,
)
else:
img = raw_img * 2 - 1
# fg_mask_reso = depth[..., -1:] # ! use
pad_v6_fn = lambda x: torch.concat([x, x[:4]], 0) if isinstance(
x, torch.Tensor) else np.concatenate([x, x[:4]], 0)
# ! processing encoder input image.
# ! normalize camera feats
if self.frame_0_as_canonical: # 4 views as input per batch
# if self.chunk_size in [8, 10]:
if True:
# encoder_canonical_idx = [0, 4]
# encoder_canonical_idx = [0, self.chunk_size//2]
encoder_canonical_idx = [0, self.V]
c_for_encoder = self.normalize_camera(c, for_encoder=True)
c_for_render = self.normalize_camera(
c,
for_encoder=False,
canonical_idx=encoder_canonical_idx[0]
) # allocated to nv_c, frame0 (in 8 views) as the canonical
c_for_render_nv = self.normalize_camera(
c,
for_encoder=False,
canonical_idx=encoder_canonical_idx[1]
) # allocated to nv_c, frame0 (in 8 views) as the canonical
c_for_render = np.concatenate([c_for_render, c_for_render_nv],
axis=-1) # for compat
# st()
else:
assert self.chunk_size == 20
c_for_encoder = self.normalize_camera_v6(c,
for_encoder=True) #
paired_c_0 = np.concatenate([c[0:6], c[12:18]])
paired_c_1 = np.concatenate([c[6:12], c[18:24]])
def process_paired_camera(paired_c):
c_for_render = self.normalize_camera_v6(
paired_c, for_encoder=False, canonical_idx=0
) # allocated to nv_c, frame0 (in 8 views) as the canonical
c_for_render_nv = self.normalize_camera_v6(
paired_c, for_encoder=False, canonical_idx=6
) # allocated to nv_c, frame0 (in 8 views) as the canonical
c_for_render = np.concatenate(
[c_for_render, c_for_render_nv], axis=-1) # for compat
return c_for_render
paired_c_for_render_0 = process_paired_camera(paired_c_0)
paired_c_for_render_1 = process_paired_camera(paired_c_1)
c_for_render = np.empty(shape=(24, 50))
c_for_render[list(range(6)) +
list(range(12, 18))] = paired_c_for_render_0
c_for_render[list(range(6, 12)) +
list(range(18, 24))] = paired_c_for_render_1
else: # use g-buffer canonical c
c_for_encoder, c_for_render = c, c
if self.append_normal and normal is not None:
img_to_encoder = torch.cat([img_to_encoder, normal_for_encoder],
# img_to_encoder = torch.cat([img_to_encoder, normal],
1) # concat in C dim
if self.plucker_embedding:
# rays_plucker = self.get_plucker_ray(c)
rays_plucker = self.get_plucker_ray(c_for_encoder)
img_to_encoder = torch.cat([img_to_encoder, rays_plucker],
1) # concat in C dim
# torchvision.utils.save_image(raw_img, 'tmp/inp.png', normalize=True, value_range=(0,1), nrow=1, padding=0)
# torchvision.utils.save_image(rays_plucker[:,:3], 'tmp/plucker.png', normalize=True, value_range=(-1,1), nrow=1, padding=0)
# torchvision.utils.save_image(depth_reso.unsqueeze(1), 'tmp/depth.png', normalize=True, nrow=1, padding=0)
c = torch.from_numpy(c_for_render).to(torch.float32)
if self.append_depth:
normalized_depth = torch.from_numpy(depth_reso).clone().unsqueeze(
1) # min=0
# normalized_depth -= torch.min(normalized_depth) # always 0 here
# normalized_depth /= torch.max(normalized_depth)
# normalized_depth = normalized_depth.unsqueeze(1) * 2 - 1 # normalize to [-1,1]
# st()
img_to_encoder = torch.cat([img_to_encoder, normalized_depth],
1) # concat in C dim
elif self.append_xyz:
depth_for_unproj = depth.clone()
depth_for_unproj[depth_for_unproj ==
0] = 1e10 # so that rays_o will not appear in the final pcd.
            xyz = self._unproj_depth_given_c(c.float(), depth_for_unproj)
# pcu.save_mesh_v(f'unproj_xyz_before_Nearest.ply', xyz[0:9].float().detach().permute(0,2,3,1).reshape(-1,3).cpu().numpy(),)
if xyz.shape[-1] != self.reso_encoder:
xyz = torch.nn.functional.interpolate(
input=xyz, # [-1,1]
# size=(self.reso, self.reso),
size=(self.reso_encoder, self.reso_encoder),
mode='nearest',
)
# pcu.save_mesh_v(f'unproj_xyz_afterNearest.ply', xyz[0:9].float().detach().permute(0,2,3,1).reshape(-1,3).cpu().numpy(),)
# st()
img_to_encoder = torch.cat([img_to_encoder, xyz], 1)
return (img_to_encoder, img, alpha, depth_reso, c,
torch.from_numpy(bbox))
def rand_sample_idx(self):
return random.randint(0, self.instance_data_length - 1)
def rand_pair(self):
return (self.rand_sample_idx() for _ in range(2))
def paired_post_process(self, sample):
# repeat n times?
all_inp_list = []
all_nv_list = []
caption, ins = sample[-2:]
# expanded_return = []
for _ in range(self.pair_per_instance):
cano_idx, nv_idx = self.rand_pair()
cano_sample = self._post_process_sample(item[cano_idx]
for item in sample[:-2])
nv_sample = self._post_process_sample(item[nv_idx]
for item in sample[:-2])
all_inp_list.extend(cano_sample)
all_nv_list.extend(nv_sample)
return (*all_inp_list, *all_nv_list, caption, ins)
# return [cano_sample, nv_sample, caption, ins]
# return (*cano_sample, *nv_sample, caption, ins)
def get_source_cw2wT(self, source_cameras_view_to_world):
return matrix_to_quaternion(
source_cameras_view_to_world[:3, :3].transpose(0, 1))
def c_to_3dgs_format(self, pose):
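        """Convert a 25-dim camera vector to the dict consumed by the 3DGS renderer.

        Builds w2c R (stored transposed, glm convention) and T, world->view and
        full projection transforms, the camera center and tan(fov), and keeps
        the original pose / c2w / w2c / intrinsics for reference.
        """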
# TODO, switch to torch version (batched later)
        c2w = pose[:16].reshape(4, 4)  # 4x4 cam2world
# ! load cam
w2c = np.linalg.inv(c2w)
R = np.transpose(
w2c[:3, :3]) # R is stored transposed due to 'glm' in CUDA code
T = w2c[:3, 3]
fx = pose[16]
FovX = focal2fov(fx, 1)
FovY = focal2fov(fx, 1)
tanfovx = math.tan(FovX * 0.5)
tanfovy = math.tan(FovY * 0.5)
assert tanfovx == tanfovy
trans = np.array([0.0, 0.0, 0.0])
scale = 1.0
view_world_transform = torch.tensor(getView2World(R, T, trans,
scale)).transpose(
0, 1)
world_view_transform = torch.tensor(getWorld2View2(R, T, trans,
scale)).transpose(
0, 1)
projection_matrix = getProjectionMatrix(znear=self.znear,
zfar=self.zfar,
fovX=FovX,
fovY=FovY).transpose(0, 1)
full_proj_transform = (world_view_transform.unsqueeze(0).bmm(
projection_matrix.unsqueeze(0))).squeeze(0)
camera_center = world_view_transform.inverse()[3, :3]
# ! check pytorch3d camera system alignment.
# item.update(viewpoint_cam=[viewpoint_cam])
c = {}
#
c["source_cv2wT_quat"] = self.get_source_cw2wT(view_world_transform)
c.update(
# projection_matrix=projection_matrix, # K
R=torch.from_numpy(R),
T=torch.from_numpy(T),
cam_view=world_view_transform, # world_view_transform
cam_view_proj=full_proj_transform, # full_proj_transform
cam_pos=camera_center,
tanfov=tanfovx, # TODO, fix in the renderer
orig_pose=torch.from_numpy(pose),
orig_c2w=torch.from_numpy(c2w),
orig_w2c=torch.from_numpy(w2c),
orig_intrin=torch.from_numpy(pose[16:]).reshape(3,3),
# tanfovy=tanfovy,
)
return c # dict for gs rendering
def paired_post_process_chunk(self, sample):
# st()
# sample_npz, ins, caption = sample_pyd # three items
# sample = *(sample[0][k] for k in ['raw_img', 'depth', 'c', 'bbox']), sample[-1], sample[-2]
# repeat n times?
all_inp_list = []
all_nv_list = []
auxiliary_sample = list(sample[-2:])
# caption, ins = sample[-2:]
ins = sample[-1]
assert sample[0].shape[0] == self.chunk_size # random chunks
# expanded_return = []
if self.load_pcd:
# fps_pcd = pcu.load_mesh_v(
# # str(self.pcd_path / ins / 'fps-24576.ply')) # N, 3
# str(self.pcd_path / ins / 'fps-4096.ply')) # N, 3
# # 'fps-4096.ply')) # N, 3
fps_pcd = trimesh.load(str(self.pcd_path / ins / 'fps-4096.ply')).vertices
auxiliary_sample += [fps_pcd]
assert self.duplicate_sample
# st()
if self.duplicate_sample:
# ! shuffle before process, since frame_0_as_canonical fixed c.
if self.chunk_size in [20, 18, 16, 12]:
shuffle_sample = sample[:-2] # no order shuffle required
else:
shuffle_sample = []
# indices = torch.randperm(self.chunk_size)
indices = np.random.permutation(self.chunk_size)
for _, item in enumerate(sample[:-2]):
shuffle_sample.append(item[indices]) # random shuffle
processed_sample = self._post_process_sample_batch(shuffle_sample)
            # ! process pcd if frame_0 alignment
if self.load_pcd:
if self.frame_0_as_canonical:
# ! normalize camera feats
# normalized camera feats as in paper (transform the first pose to a fixed position)
# if self.chunk_size == 20:
# auxiliary_sample[-1] = self.canonicalize_pts_v6(
# c=shuffle_sample[2],
# pcd=auxiliary_sample[-1],
# for_encoder=True) # B N 3
# else:
auxiliary_sample[-1] = self.canonicalize_pts(
c=shuffle_sample[2],
pcd=auxiliary_sample[-1],
for_encoder=True) # B N 3
else:
                    auxiliary_sample[-1] = np.repeat(
                        auxiliary_sample[-1][None], 2,
                        axis=0)  # share the same camera system, just repeat
assert not self.orthog_duplicate
# if self.chunk_size == 8:
all_inp_list.extend(item[:self.V] for item in processed_sample)
all_nv_list.extend(item[self.V:] for item in processed_sample)
# elif self.chunk_size == 20: # V=6
# # indices_v6 = [np.random.permutation(self.chunk_size)[:12] for _ in range(2)] # random sample 6 views from chunks
# all_inp_list.extend(item[:12] for item in processed_sample)
# # indices_v6 = np.concatenate([np.arange(12, 20), np.arange(0,4)])
# all_nv_list.extend(
# item[12:] for item in
# processed_sample) # already repeated inside batch fn
# else:
# raise NotImplementedError(self.chunk_size)
# else:
# all_inp_list.extend(item[:8] for item in processed_sample)
# all_nv_list.extend(item[8:] for item in processed_sample)
# st()
return (*all_inp_list, *all_nv_list, *auxiliary_sample)
else:
processed_sample = self._post_process_sample_batch( # avoid shuffle shorten processing time
item[:4] for item in sample[:-2])
all_inp_list.extend(item for item in processed_sample)
all_nv_list.extend(item
for item in processed_sample) # ! placeholder
# return (*all_inp_list, *all_nv_list, caption, ins)
return (*all_inp_list, *all_nv_list, *auxiliary_sample)
# randomly shuffle 8 views, avoid overfitting
def single_sample_create_dict_noBatch(self, sample, prefix=''):
# if len(sample) == 1:
# sample = sample[0]
# assert len(sample) == 6
img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample
if self.gs_cam_format:
# TODO, can optimize later after model converges
B, V, _ = c.shape # B 4 25
c = rearrange(c, 'B V C -> (B V) C').cpu().numpy()
# c = c.cpu().numpy()
all_gs_c = [self.c_to_3dgs_format(pose) for pose in c]
# st()
# all_gs_c = self.c_to_3dgs_format(c.cpu().numpy())
c = {
k:
rearrange(torch.stack([gs_c[k] for gs_c in all_gs_c]),
'(B V) ... -> B V ...',
B=B,
V=V)
# torch.stack([gs_c[k] for gs_c in all_gs_c])
if isinstance(all_gs_c[0][k], torch.Tensor) else all_gs_c[0][k]
for k in all_gs_c[0].keys()
}
# c = collate_gs_c
return {
# **sample,
f'{prefix}img_to_encoder': img_to_encoder,
f'{prefix}img': img,
f'{prefix}depth_mask': fg_mask_reso,
f'{prefix}depth': depth_reso,
f'{prefix}c': c,
f'{prefix}bbox': bbox,
}
def single_sample_create_dict(self, sample, prefix=''):
# if len(sample) == 1:
# sample = sample[0]
# assert len(sample) == 6
img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample
if self.gs_cam_format:
# TODO, can optimize later after model converges
B, V, _ = c.shape # B 4 25
c = rearrange(c, 'B V C -> (B V) C').cpu().numpy()
all_gs_c = [self.c_to_3dgs_format(pose) for pose in c]
c = {
k:
rearrange(torch.stack([gs_c[k] for gs_c in all_gs_c]),
'(B V) ... -> B V ...',
B=B,
V=V)
if isinstance(all_gs_c[0][k], torch.Tensor) else all_gs_c[0][k]
for k in all_gs_c[0].keys()
}
# c = collate_gs_c
return {
# **sample,
f'{prefix}img_to_encoder': img_to_encoder,
f'{prefix}img': img,
f'{prefix}depth_mask': fg_mask_reso,
f'{prefix}depth': depth_reso,
f'{prefix}c': c,
f'{prefix}bbox': bbox,
}
    def single_instance_sample_create_dict(self, sample, prefix=''):
assert len(sample) == 42
inp_sample_list = [[] for _ in range(6)]
for item in sample[:40]:
for item_idx in range(6):
inp_sample_list[item_idx].append(item[0][item_idx])
inp_sample = self.single_sample_create_dict(
(torch.stack(item_list) for item_list in inp_sample_list),
prefix='')
return {
**inp_sample, #
'caption': sample[-2],
'ins': sample[-1]
}
def decode_gzip(self, sample_pyd, shape=(256, 256)):
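        """Decode one webdataset sample: untile the views packed along the width
        and recover metric depth from the per-view (d_near, d_far) jpeg encoding;
        the alpha mask is appended to raw_img as a 4th channel.
        """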
# sample_npz, ins, caption = sample_pyd # three items
# c, bbox, depth, ins, caption, raw_img = sample_pyd[:5], sample_pyd[5:]
# wds.to_tuple('raw_img.jpeg', 'depth.jpeg',
# 'd_near.npy',
# 'd_far.npy',
# "c.npy", 'bbox.npy', 'ins.txt', 'caption.txt'),
        raw_img, depth, alpha_mask, d_near, d_far, c, bbox, ins, caption = sample_pyd
raw_img = rearrange(raw_img, 'h (b w) c -> b h w c', b=self.chunk_size)
depth = rearrange(depth, 'h (b w) c -> b h w c', b=self.chunk_size)
alpha_mask = rearrange(
alpha_mask, 'h (b w) c -> b h w c', b=self.chunk_size) / 255.0
d_far = d_far.reshape(self.chunk_size, 1, 1, 1)
d_near = d_near.reshape(self.chunk_size, 1, 1, 1)
# d = 1 / ( (d_normalized / 255) * (far-near) + near)
depth = 1 / ((depth / 255) * (d_far - d_near) + d_near)
depth = depth[..., 0] # decoded from jpeg
# depth = decompress_array(depth['depth'], (self.chunk_size, *shape),
# np.float32,
# decompress=True,
# decompress_fn=lz4.frame.decompress)
# return raw_img, depth, d_near, d_far, c, bbox, caption, ins
raw_img = np.concatenate([raw_img, alpha_mask[..., 0:1]], -1)
return raw_img, depth, c, bbox, caption, ins
def decode_zip(
self,
sample_pyd,
):
shape = (self.reso_encoder, self.reso_encoder)
if isinstance(sample_pyd, tuple):
sample_pyd = sample_pyd[0]
assert isinstance(sample_pyd, dict)
raw_img = decompress_and_open_image_gzip(
sample_pyd['raw_img'],
is_img=True,
decompress=True,
decompress_fn=lz4.frame.decompress)
caption = sample_pyd['caption'].decode('utf-8')
ins = sample_pyd['ins'].decode('utf-8')
c = decompress_array(sample_pyd['c'], (
self.chunk_size,
25,
),
np.float32,
decompress=True,
decompress_fn=lz4.frame.decompress)
bbox = decompress_array(
sample_pyd['bbox'],
(
self.chunk_size,
4,
),
np.float32,
# decompress=False)
decompress=True,
decompress_fn=lz4.frame.decompress)
if self.decode_encode_img_only:
depth = np.zeros(shape=(self.chunk_size,
*shape)) # save loading time
else:
depth = decompress_array(sample_pyd['depth'],
(self.chunk_size, *shape),
np.float32,
decompress=True,
decompress_fn=lz4.frame.decompress)
# return {'raw_img': raw_img, 'depth': depth, 'bbox': bbox, 'caption': caption, 'ins': ins, 'c': c}
# return raw_img, depth, c, bbox, caption, ins
# return raw_img, bbox, caption, ins
# return bbox, caption, ins
return raw_img, depth, c, bbox, caption, ins
# ! run single-instance pipeline first
# return raw_img[0], depth[0], c[0], bbox[0], caption, ins
def create_dict_nobatch(self, sample):
# sample = [item[0] for item in sample] # wds wrap items in []
sample_length = 6
# if self.load_pcd:
# sample_length += 1
cano_sample_list = [[] for _ in range(sample_length)]
nv_sample_list = [[] for _ in range(sample_length)]
# st()
# bs = (len(sample)-2) // 6
for idx in range(0, self.pair_per_instance):
cano_sample = sample[sample_length * idx:sample_length * (idx + 1)]
nv_sample = sample[sample_length * self.pair_per_instance +
sample_length * idx:sample_length *
self.pair_per_instance + sample_length *
(idx + 1)]
for item_idx in range(sample_length):
if self.frame_0_as_canonical:
# ! cycle input/output view for more pairs
                    if item_idx == 4:  # 'c' is 50-dim here (cano + nv poses); split the halves
cano_sample_list[item_idx].append(
cano_sample[item_idx][..., :25])
nv_sample_list[item_idx].append(
nv_sample[item_idx][..., :25])
cano_sample_list[item_idx].append(
nv_sample[item_idx][..., 25:])
nv_sample_list[item_idx].append(
cano_sample[item_idx][..., 25:])
else:
cano_sample_list[item_idx].append(
cano_sample[item_idx])
nv_sample_list[item_idx].append(nv_sample[item_idx])
cano_sample_list[item_idx].append(nv_sample[item_idx])
nv_sample_list[item_idx].append(cano_sample[item_idx])
else:
cano_sample_list[item_idx].append(cano_sample[item_idx])
nv_sample_list[item_idx].append(nv_sample[item_idx])
cano_sample_list[item_idx].append(nv_sample[item_idx])
nv_sample_list[item_idx].append(cano_sample[item_idx])
cano_sample = self.single_sample_create_dict_noBatch(
(torch.stack(item_list, 0) for item_list in cano_sample_list),
prefix=''
) # torch.Size([5, 10, 256, 256]). Since no batch dim here for now.
nv_sample = self.single_sample_create_dict_noBatch(
(torch.stack(item_list, 0) for item_list in nv_sample_list),
prefix='nv_')
ret_dict = {
**cano_sample,
**nv_sample,
}
if not self.load_pcd:
ret_dict.update({'caption': sample[-2], 'ins': sample[-1]})
else:
# if self.frame_0_as_canonical:
# # fps_pcd = rearrange( sample[-1], 'B V ... -> (B V) ...') # ! wrong order.
# # if self.chunk_size == 8:
# fps_pcd = rearrange(
# sample[-1], 'B V ... -> (V B) ...') # mimic torch.repeat
# # else:
# # fps_pcd = rearrange( sample[-1], 'B V ... -> (B V) ...') # ugly code to match the input format...
# else:
# fps_pcd = sample[-1].repeat(
# 2, 1,
# 1) # mimic torch.cat(), from torch.Size([3, 4096, 3])
# ! TODO, check fps_pcd order
ret_dict.update({
'caption': sample[-3],
'ins': sample[-2],
'fps_pcd': sample[-1]
})
return ret_dict
def create_dict(self, sample):
# sample = [item[0] for item in sample] # wds wrap items in []
# st()
sample_length = 6
# if self.load_pcd:
# sample_length += 1
cano_sample_list = [[] for _ in range(sample_length)]
nv_sample_list = [[] for _ in range(sample_length)]
# st()
# bs = (len(sample)-2) // 6
for idx in range(0, self.pair_per_instance):
cano_sample = sample[sample_length * idx:sample_length * (idx + 1)]
nv_sample = sample[sample_length * self.pair_per_instance +
sample_length * idx:sample_length *
self.pair_per_instance + sample_length *
(idx + 1)]
for item_idx in range(sample_length):
if self.frame_0_as_canonical:
# ! cycle input/output view for more pairs
                    if item_idx == 4:  # 'c' is 50-dim here (cano + nv poses); split the halves
cano_sample_list[item_idx].append(
cano_sample[item_idx][..., :25])
nv_sample_list[item_idx].append(
nv_sample[item_idx][..., :25])
cano_sample_list[item_idx].append(
nv_sample[item_idx][..., 25:])
nv_sample_list[item_idx].append(
cano_sample[item_idx][..., 25:])
else:
cano_sample_list[item_idx].append(
cano_sample[item_idx])
nv_sample_list[item_idx].append(nv_sample[item_idx])
cano_sample_list[item_idx].append(nv_sample[item_idx])
nv_sample_list[item_idx].append(cano_sample[item_idx])
else:
cano_sample_list[item_idx].append(cano_sample[item_idx])
nv_sample_list[item_idx].append(nv_sample[item_idx])
cano_sample_list[item_idx].append(nv_sample[item_idx])
nv_sample_list[item_idx].append(cano_sample[item_idx])
# if self.split_chunk_input:
# cano_sample = self.single_sample_create_dict(
# (torch.cat(item_list, 0) for item_list in cano_sample_list),
# prefix='')
# nv_sample = self.single_sample_create_dict(
# (torch.cat(item_list, 0) for item_list in nv_sample_list),
# prefix='nv_')
# else:
# st()
cano_sample = self.single_sample_create_dict(
(torch.cat(item_list, 0) for item_list in cano_sample_list),
prefix='') # torch.Size([4, 4, 10, 256, 256])
nv_sample = self.single_sample_create_dict(
(torch.cat(item_list, 0) for item_list in nv_sample_list),
prefix='nv_')
ret_dict = {
**cano_sample,
**nv_sample,
}
if not self.load_pcd:
ret_dict.update({'caption': sample[-2], 'ins': sample[-1]})
else:
if self.frame_0_as_canonical:
# fps_pcd = rearrange( sample[-1], 'B V ... -> (B V) ...') # ! wrong order.
# if self.chunk_size == 8:
fps_pcd = rearrange(
sample[-1], 'B V ... -> (V B) ...') # mimic torch.repeat
# else:
# fps_pcd = rearrange( sample[-1], 'B V ... -> (B V) ...') # ugly code to match the input format...
else:
fps_pcd = sample[-1].repeat(
2, 1,
1) # mimic torch.cat(), from torch.Size([3, 4096, 3])
ret_dict.update({
'caption': sample[-3],
'ins': sample[-2],
'fps_pcd': fps_pcd
})
return ret_dict
def prepare_mv_input(self, sample):
# sample = [item[0] for item in sample] # wds wrap items in []
bs = len(sample['caption']) # number of instances
chunk_size = sample['img'].shape[0] // bs
assert self.split_chunk_input
for k, v in sample.items():
if isinstance(v, torch.Tensor) and k != 'fps_pcd':
sample[k] = rearrange(v, "b f c ... -> (b f) c ...",
f=self.V).contiguous()
# # ! shift nv
# else:
# for k, v in sample.items():
# if k not in ['ins', 'caption']:
# rolled_idx = torch.LongTensor(
# list(
# itertools.chain.from_iterable(
# list(range(i, sample['img'].shape[0], bs))
# for i in range(bs))))
# v = torch.index_select(v, dim=0, index=rolled_idx)
# sample[k] = v
# # img = sample['img']
# # gt = sample['nv_img']
# # torchvision.utils.save_image(img[0], 'inp.jpg', normalize=True)
# # torchvision.utils.save_image(gt[0], 'nv.jpg', normalize=True)
# for k, v in sample.items():
# if 'nv' in k:
# rolled_idx = torch.LongTensor(
# list(
# itertools.chain.from_iterable(
# list(
# np.roll(
# np.arange(i * chunk_size, (i + 1) *
# chunk_size), 4)
# for i in range(bs)))))
# v = torch.index_select(v, dim=0, index=rolled_idx)
# sample[k] = v
# torchvision.utils.save_image(sample['nv_img'], 'nv.png', normalize=True)
# torchvision.utils.save_image(sample['img'], 'inp.png', normalize=True)
return sample
def load_dataset(
file_path="",
reso=64,
reso_encoder=224,
batch_size=1,
# shuffle=True,
num_workers=6,
load_depth=False,
preprocess=None,
imgnet_normalize=True,
dataset_size=-1,
trainer_name='input_rec',
use_lmdb=False,
use_wds=False,
use_chunk=False,
use_lmdb_compressed=False,
infi_sampler=True):
# st()
# dataset_cls = {
# 'input_rec': MultiViewDataset,
# 'nv': NovelViewDataset,
# }[trainer_name]
# st()
if use_wds:
return load_wds_data(file_path, reso, reso_encoder, batch_size,
num_workers)
if use_lmdb:
logger.log('using LMDB dataset')
# dataset_cls = LMDBDataset_MV # 2.5-3iter/s, but unstable, drops to 1 later.
if use_lmdb_compressed:
if 'nv' in trainer_name:
dataset_cls = Objv_LMDBDataset_NV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later.
else:
dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later.
else:
if 'nv' in trainer_name:
dataset_cls = Objv_LMDBDataset_NV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later.
else:
dataset_cls = Objv_LMDBDataset_MV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later.
# dataset = dataset_cls(file_path)
elif use_chunk:
dataset_cls = ChunkObjaverseDataset
else:
if 'nv' in trainer_name:
dataset_cls = NovelViewObjverseDataset
else:
dataset_cls = MultiViewObjverseDataset # 1.5-2iter/s
dataset = dataset_cls(file_path,
reso,
reso_encoder,
test=False,
preprocess=preprocess,
load_depth=load_depth,
imgnet_normalize=imgnet_normalize,
dataset_size=dataset_size)
logger.log('dataset_cls: {}, dataset size: {}'.format(
trainer_name, len(dataset)))
if use_chunk:
        def chunk_collate_fn(sample):
            default_collate_sample = torch.utils.data.default_collate(
                sample[0])
            return default_collate_sample
collate_fn = chunk_collate_fn
else:
collate_fn = None
loader = DataLoader(dataset,
batch_size=batch_size,
num_workers=num_workers,
drop_last=False,
pin_memory=True,
persistent_workers=num_workers > 0,
shuffle=use_chunk,
collate_fn=collate_fn)
return loader
def chunk_collate_fn(sample):
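    """default-collate chunked samples, then flatten the (outer batch, chunk
    [, view]) dims into a single batch dim; 3DGS camera dicts ('c', 'nv_c')
    are merged per-key first and 'tanfov' is reduced to a python scalar.
    """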
sample = torch.utils.data.default_collate(sample)
# ! change from stack to cat
# sample = self.post_process.prepare_mv_input(sample)
bs = len(sample['caption']) # number of instances
# chunk_size = sample['img'].shape[0] // bs
def merge_internal_batch(sample, merge_b_only=False):
for k, v in sample.items():
if isinstance(v, torch.Tensor):
if v.ndim > 1:
if k == 'fps_pcd' or merge_b_only:
sample[k] = rearrange(
v,
"b1 b2 ... -> (b1 b2) ...").float().contiguous()
else:
sample[k] = rearrange(
v, "b1 b2 f c ... -> (b1 b2 f) c ...").float(
).contiguous()
elif k == 'tanfov':
sample[k] = v[0].float().item() # tanfov.
if isinstance(sample['c'], dict): # 3dgs
merge_internal_batch(sample['c'], merge_b_only=True)
merge_internal_batch(sample['nv_c'], merge_b_only=True)
merge_internal_batch(sample)
return sample
def chunk_ddpm_collate_fn(sample):
sample = torch.utils.data.default_collate(sample)
# ! change from stack to cat
# sample = self.post_process.prepare_mv_input(sample)
# bs = len(sample['caption']) # number of instances
# chunk_size = sample['img'].shape[0] // bs
def merge_internal_batch(sample, merge_b_only=False):
for k, v in sample.items():
if isinstance(v, torch.Tensor):
if v.ndim > 1:
# if k in ['c', 'latent']:
sample[k] = rearrange(
v,
"b1 b2 ... -> (b1 b2) ...").float().contiguous()
# else: # img
# sample[k] = rearrange(
# v, "b1 b2 f ... -> (b1 b2 f) ...").float(
# ).contiguous()
            else:  # caption & ins
                sample[k] = [v[i][0] for i in range(len(v))]
merge_internal_batch(sample)
# if 'caption' in sample:
# sample['caption'] = sample['caption'][0] + sample['caption'][1]
return sample
def load_data_cls(
file_path="",
reso=64,
reso_encoder=224,
batch_size=1,
# shuffle=True,
num_workers=6,
load_depth=False,
preprocess=None,
imgnet_normalize=True,
dataset_size=-1,
trainer_name='input_rec',
use_lmdb=False,
use_wds=False,
use_chunk=False,
use_lmdb_compressed=False,
# plucker_embedding=False,
# frame_0_as_canonical=False,
infi_sampler=True,
load_latent=False,
return_dataset=False,
load_caption_dataset=False,
load_mv_dataset=False,
**kwargs):
# st()
# dataset_cls = {
# 'input_rec': MultiViewDataset,
# 'nv': NovelViewDataset,
# }[trainer_name]
# st()
# if use_lmdb:
# logger.log('using LMDB dataset')
# # dataset_cls = LMDBDataset_MV # 2.5-3iter/s, but unstable, drops to 1 later.
# if 'nv' in trainer_name:
# dataset_cls = Objv_LMDBDataset_NV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later.
# else:
# dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later.
# # dataset = dataset_cls(file_path)
collate_fn = None
if use_lmdb:
logger.log('using LMDB dataset')
# dataset_cls = LMDBDataset_MV # 2.5-3iter/s, but unstable, drops to 1 later.
if use_lmdb_compressed:
if 'nv' in trainer_name:
dataset_cls = Objv_LMDBDataset_NV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later.
else:
dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later.
else:
if 'nv' in trainer_name:
dataset_cls = Objv_LMDBDataset_NV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later.
else:
dataset_cls = Objv_LMDBDataset_MV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later.
elif use_chunk:
if load_latent:
# if 'gs_cam_format' in kwargs:
if kwargs['gs_cam_format']:
if load_caption_dataset:
dataset_cls = ChunkObjaverseDatasetDDPMgsT23D
collate_fn = chunk_ddpm_collate_fn
else:
if load_mv_dataset:
# dataset_cls = ChunkObjaverseDatasetDDPMgsMV23D # ! if multi-view
dataset_cls = ChunkObjaverseDatasetDDPMgsMV23DSynthetic # ! if multi-view
# collate_fn = chunk_ddpm_collate_fn
collate_fn = None
else:
dataset_cls = ChunkObjaverseDatasetDDPMgsI23D
collate_fn = None
else:
dataset_cls = ChunkObjaverseDatasetDDPM
collate_fn = chunk_ddpm_collate_fn
else:
dataset_cls = ChunkObjaverseDataset
collate_fn = chunk_collate_fn
else:
if 'nv' in trainer_name:
dataset_cls = NovelViewObjverseDataset # 1.5-2iter/s
else:
dataset_cls = MultiViewObjverseDataset
dataset = dataset_cls(file_path,
reso,
reso_encoder,
test=False,
preprocess=preprocess,
load_depth=load_depth,
imgnet_normalize=imgnet_normalize,
dataset_size=dataset_size,
**kwargs
# plucker_embedding=plucker_embedding
)
logger.log('dataset_cls: {}, dataset size: {}'.format(
trainer_name, len(dataset)))
# st()
return dataset
def load_data(
file_path="",
reso=64,
reso_encoder=224,
batch_size=1,
# shuffle=True,
num_workers=6,
load_depth=False,
preprocess=None,
imgnet_normalize=True,
dataset_size=-1,
trainer_name='input_rec',
use_lmdb=False,
use_wds=False,
use_chunk=False,
use_lmdb_compressed=False,
# plucker_embedding=False,
# frame_0_as_canonical=False,
infi_sampler=True,
load_latent=False,
return_dataset=False,
load_caption_dataset=False,
load_mv_dataset=False,
**kwargs):
# st()
# dataset_cls = {
# 'input_rec': MultiViewDataset,
# 'nv': NovelViewDataset,
# }[trainer_name]
# st()
# if use_lmdb:
# logger.log('using LMDB dataset')
# # dataset_cls = LMDBDataset_MV # 2.5-3iter/s, but unstable, drops to 1 later.
# if 'nv' in trainer_name:
# dataset_cls = Objv_LMDBDataset_NV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later.
# else:
# dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later.
# # dataset = dataset_cls(file_path)
collate_fn = None
if use_lmdb:
logger.log('using LMDB dataset')
# dataset_cls = LMDBDataset_MV # 2.5-3iter/s, but unstable, drops to 1 later.
if use_lmdb_compressed:
if 'nv' in trainer_name:
dataset_cls = Objv_LMDBDataset_NV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later.
else:
dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later.
else:
if 'nv' in trainer_name:
dataset_cls = Objv_LMDBDataset_NV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later.
else:
dataset_cls = Objv_LMDBDataset_MV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later.
elif use_chunk:
# st()
if load_latent:
if kwargs['gs_cam_format']:
if load_caption_dataset:
dataset_cls = ChunkObjaverseDatasetDDPMgsT23D
# collate_fn = chunk_ddpm_collate_fn
collate_fn = None
else:
if load_mv_dataset:
# dataset_cls = ChunkObjaverseDatasetDDPMgsMV23D
dataset_cls = ChunkObjaverseDatasetDDPMgsMV23DSynthetic # ! if multi-view
# collate_fn = chunk_ddpm_collate_fn
collate_fn = None
else:
# dataset_cls = ChunkObjaverseDatasetDDPMgsI23D # load i23d
# collate_fn = None
# load mv dataset for i23d
dataset_cls = ChunkObjaverseDatasetDDPMgsI23D_loadMV
collate_fn = chunk_ddpm_collate_fn
else:
dataset_cls = ChunkObjaverseDatasetDDPM
collate_fn = chunk_ddpm_collate_fn
else:
dataset_cls = ChunkObjaverseDataset
collate_fn = chunk_collate_fn
else:
if 'nv' in trainer_name:
dataset_cls = NovelViewObjverseDataset # 1.5-2iter/s
else:
dataset_cls = MultiViewObjverseDataset
dataset = dataset_cls(file_path,
reso,
reso_encoder,
test=False,
preprocess=preprocess,
load_depth=load_depth,
imgnet_normalize=imgnet_normalize,
dataset_size=dataset_size,
**kwargs
# plucker_embedding=plucker_embedding
)
logger.log('dataset_cls: {}, dataset size: {}'.format(
trainer_name, len(dataset)))
# st()
if return_dataset:
return dataset
assert infi_sampler
if infi_sampler:
train_sampler = DistributedSampler(dataset=dataset,
shuffle=True,
drop_last=True)
loader = DataLoader(
dataset,
batch_size=batch_size,
num_workers=num_workers,
drop_last=True,
pin_memory=True,
persistent_workers=num_workers > 0,
sampler=train_sampler,
collate_fn=collate_fn,
# prefetch_factor=3 if num_workers>0 else None,
)
while True:
yield from loader
# else:
# # loader = DataLoader(dataset,
# # batch_size=batch_size,
# # num_workers=num_workers,
# # drop_last=False,
# # pin_memory=True,
# # persistent_workers=num_workers > 0,
# # shuffle=False)
# st()
# return dataset
def load_eval_data(
file_path="",
reso=64,
reso_encoder=224,
batch_size=1,
num_workers=1,
load_depth=False,
preprocess=None,
imgnet_normalize=True,
interval=1,
use_lmdb=False,
plucker_embedding=False,
load_real=False,
load_mv_real=False,
load_gso=False,
four_view_for_latent=False,
shuffle_across_cls=False,
load_extra_36_view=False,
gs_cam_format=False,
single_view_for_i23d=False,
use_chunk=False,
**kwargs,
):
collate_fn = None
if use_lmdb:
logger.log('using LMDB dataset')
dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later.
dataset = dataset_cls(file_path,
reso,
reso_encoder,
test=True,
preprocess=preprocess,
load_depth=load_depth,
imgnet_normalize=imgnet_normalize,
interval=interval)
elif use_chunk:
dataset = ChunkObjaverseDataset(
file_path,
reso,
reso_encoder,
test=False,
preprocess=preprocess,
load_depth=load_depth,
imgnet_normalize=imgnet_normalize,
# dataset_size=dataset_size,
gs_cam_format=gs_cam_format,
plucker_embedding=plucker_embedding,
wds_split_all=2,
# frame_0_as_canonical=frame_0_as_canonical,
**kwargs)
collate_fn = chunk_collate_fn
elif load_real:
if load_mv_real:
dataset_cls = RealMVDataset
elif load_gso:
# st()
dataset_cls = RealDataset_GSO
else: # single-view i23d
dataset_cls = RealDataset
dataset = dataset_cls(file_path,
reso,
reso_encoder,
preprocess=preprocess,
load_depth=load_depth,
test=True,
imgnet_normalize=imgnet_normalize,
interval=interval,
plucker_embedding=plucker_embedding)
else:
dataset = MultiViewObjverseDataset(
file_path,
reso,
reso_encoder,
preprocess=preprocess,
load_depth=load_depth,
test=True,
imgnet_normalize=imgnet_normalize,
interval=interval,
plucker_embedding=plucker_embedding,
four_view_for_latent=four_view_for_latent,
load_extra_36_view=load_extra_36_view,
shuffle_across_cls=shuffle_across_cls,
gs_cam_format=gs_cam_format,
single_view_for_i23d=single_view_for_i23d,
**kwargs)
print('eval dataset size: {}'.format(len(dataset)))
# train_sampler = DistributedSampler(dataset=dataset)
loader = DataLoader(
dataset,
batch_size=batch_size,
num_workers=num_workers,
drop_last=False,
shuffle=False,
collate_fn=collate_fn,
)
# sampler=train_sampler)
# return loader
return iter(loader)
def load_data_for_lmdb(
file_path="",
reso=64,
reso_encoder=224,
batch_size=1,
# shuffle=True,
num_workers=6,
load_depth=False,
preprocess=None,
imgnet_normalize=True,
dataset_size=-1,
trainer_name='input_rec',
shuffle_across_cls=False,
four_view_for_latent=False,
wds_split=1):
# st()
# dataset_cls = {
# 'input_rec': MultiViewDataset,
# 'nv': NovelViewDataset,
# }[trainer_name]
# if 'nv' in trainer_name:
# dataset_cls = NovelViewDataset
# else:
# dataset_cls = MultiViewDataset
# st()
# dataset_cls = MultiViewObjverseDatasetforLMDB
dataset_cls = MultiViewObjverseDatasetforLMDB_nocaption
dataset = dataset_cls(file_path,
reso,
reso_encoder,
test=False,
preprocess=preprocess,
load_depth=load_depth,
imgnet_normalize=imgnet_normalize,
dataset_size=dataset_size,
shuffle_across_cls=shuffle_across_cls,
wds_split=wds_split,
four_view_for_latent=four_view_for_latent)
logger.log('dataset_cls: {}, dataset size: {}'.format(
trainer_name, len(dataset)))
# train_sampler = DistributedSampler(dataset=dataset, shuffle=True, drop_last=True)
loader = DataLoader(
dataset,
shuffle=False,
batch_size=batch_size,
num_workers=num_workers,
drop_last=False,
# prefetch_factor=2,
# prefetch_factor=3,
pin_memory=True,
persistent_workers=num_workers > 0,
)
# sampler=train_sampler)
# while True:
# yield from loader
return loader, dataset.dataset_name, len(dataset)
def load_lmdb_for_lmdb(
file_path="",
reso=64,
reso_encoder=224,
batch_size=1,
# shuffle=True,
num_workers=6,
load_depth=False,
preprocess=None,
imgnet_normalize=True,
dataset_size=-1,
trainer_name='input_rec'):
# st()
# dataset_cls = {
# 'input_rec': MultiViewDataset,
# 'nv': NovelViewDataset,
# }[trainer_name]
# if 'nv' in trainer_name:
# dataset_cls = NovelViewDataset
# else:
# dataset_cls = MultiViewDataset
# st()
dataset_cls = Objv_LMDBDataset_MV_Compressed_for_lmdb
dataset = dataset_cls(file_path,
reso,
reso_encoder,
test=False,
preprocess=preprocess,
load_depth=load_depth,
imgnet_normalize=imgnet_normalize,
dataset_size=dataset_size)
logger.log('dataset_cls: {}, dataset size: {}'.format(
trainer_name, len(dataset)))
# train_sampler = DistributedSampler(dataset=dataset, shuffle=True, drop_last=True)
loader = DataLoader(
dataset,
shuffle=False,
batch_size=batch_size,
num_workers=num_workers,
drop_last=False,
# prefetch_factor=2,
# prefetch_factor=3,
pin_memory=True,
persistent_workers=True,
)
# sampler=train_sampler)
# while True:
# yield from loader
return loader, len(dataset)
def load_memory_data(
file_path="",
reso=64,
reso_encoder=224,
batch_size=1,
num_workers=1,
# load_depth=True,
preprocess=None,
imgnet_normalize=True,
use_chunk=True,
**kwargs):
# load a single-instance into the memory to speed up training IO
# dataset = MultiViewObjverseDataset(file_path,
collate_fn = None
if use_chunk:
dataset_cls = ChunkObjaverseDataset
collate_fn = chunk_collate_fn
else:
dataset_cls = NovelViewObjverseDataset
dataset = dataset_cls(file_path,
reso,
reso_encoder,
preprocess=preprocess,
load_depth=True,
test=False,
overfitting=True,
imgnet_normalize=imgnet_normalize,
overfitting_bs=batch_size,
**kwargs)
logger.log('!!!!!!! memory dataset size: {} !!!!!!'.format(len(dataset)))
# train_sampler = DistributedSampler(dataset=dataset)
loader = DataLoader(
dataset,
batch_size=len(dataset),
num_workers=num_workers,
drop_last=False,
shuffle=False,
collate_fn = collate_fn
)
all_data: dict = next(
iter(loader)
) # torchvision.utils.save_image(all_data['img'], 'gt.jpg', normalize=True, value_range=(-1,1))
# st()
if kwargs.get('gs_cam_format', False): # gs rendering pipeline
# ! load V=4 images for training in a batch.
while True:
# st()
# indices = torch.randperm(len(dataset))[:4]
indices = torch.randperm(
len(dataset) * 2)[:batch_size] # all instances
# indices2 = torch.randperm(len(dataset))[:] # all instances
batch_c = collections.defaultdict(dict)
V = all_data['c']['source_cv2wT_quat'].shape[1]
for k in ['c', 'nv_c']:
for k_c, v_c in all_data[k].items():
if k_c == 'tanfov':
continue
try:
batch_c[k][
k_c] = torch.index_select( # ! chunk data reading pipeline
v_c,
dim=0,
index=indices
).reshape(batch_size, V, *v_c.shape[2:]).float(
) if isinstance(
v_c,
torch.Tensor) else v_c # float
except Exception as e:
st()
print(e)
# ! read chunk not required, already float
batch_c['c']['tanfov'] = all_data['c']['tanfov']
batch_c['nv_c']['tanfov'] = all_data['nv_c']['tanfov']
indices_range = torch.arange(indices[0]*V, (indices[0]+1)*V)
batch_data = {}
for k, v in all_data.items():
if k not in ['c', 'nv_c']:
try:
if k == 'fps_pcd':
batch_data[k] = torch.index_select(
v, dim=0, index=indices).float() if isinstance(
v, torch.Tensor) else v # float
else:
batch_data[k] = torch.index_select(
v, dim=0, index=indices_range).float() if isinstance(
v, torch.Tensor) else v # float
                    except Exception as e:
                        print(e)
                        st()
memory_batch_data = {
**batch_data,
**batch_c,
}
yield memory_batch_data
else:
while True:
start_idx = np.random.randint(0, len(dataset) - batch_size + 1)
yield {
k: v[start_idx:start_idx + batch_size]
for k, v in all_data.items()
}
def read_dnormal(normald_path, cond_pos, h=None, w=None):
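    """Read a packed 4-channel map (RGB = world normal, A = distance), zero out
    depth closer than cam_distance - sqrt(3)/2 (i.e. in front of the unit-cube
    scene), optionally resize with nearest interpolation, and return torch tensors.
    """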
cond_cam_dis = np.linalg.norm(cond_pos, 2)
near = 0.867 #sqrt(3) * 0.5
near_distance = cond_cam_dis - near
normald = cv2.imread(normald_path, cv2.IMREAD_UNCHANGED).astype(np.float32)
normal, depth = normald[..., :3], normald[..., 3:]
depth[depth < near_distance] = 0
if h is not None:
assert w is not None
if depth.shape[1] != h:
depth = cv2.resize(depth, (h, w), interpolation=cv2.INTER_NEAREST
) # 512,512, 1 -> self.reso, self.reso
# depth = cv2.resize(depth, (h, w), interpolation=cv2.INTER_LANCZOS4
# ) # ! may fail if nearest. dirty data.
# st()
else:
depth = depth[..., 0]
if normal.shape[1] != h:
normal = cv2.resize(normal, (h, w),
interpolation=cv2.INTER_NEAREST
) # 512,512, 1 -> self.reso, self.reso
else:
depth = depth[..., 0]
return torch.from_numpy(depth).float(), torch.from_numpy(normal).float()
def get_intri(target_im=None, h=None, w=None, normalize=False):
if target_im is None:
assert (h is not None and w is not None)
else:
h, w = target_im.shape[:2]
fx = fy = 1422.222
res_raw = 1024
f_x = f_y = fx * h / res_raw
    K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1])
    if normalize:  # center is [0.5, 0.5], eg3d renderer tradition
        K[:6] /= h  # scale f_x, c_x, f_y, c_y (flat layout); keep the trailing 1
    K = K.reshape(3, 3)
return K
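# For reference: get_intri(h=512, w=512, normalize=True) gives
# fx = fy = 1422.222 * 512 / 1024 / 512 ~= 1.389 and cx = cy = 0.5 -- the 9
# values stored at c[16:25] of the 25-dim camera vectors used in this file.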
def convert_pose(C2W):
# https://github.com/modelscope/richdreamer/blob/c3d9a77fa15fc42dbae12c2d41d64aaec14efd37/dataset/gobjaverse/depth_warp_example.py#L402
flip_yz = np.eye(4)
flip_yz[1, 1] = -1
flip_yz[2, 2] = -1
C2W = np.matmul(C2W, flip_yz)
return torch.from_numpy(C2W)
def read_camera_matrix_single(json_file):
with open(json_file, 'r', encoding='utf8') as reader:
json_content = json.load(reader)
'''
# NOTE that different from unity2blender experiments.
camera_matrix = np.eye(4)
camera_matrix[:3, 0] = np.array(json_content['x'])
camera_matrix[:3, 1] = -np.array(json_content['y'])
camera_matrix[:3, 2] = -np.array(json_content['z'])
camera_matrix[:3, 3] = np.array(json_content['origin'])
'''
camera_matrix = np.eye(4) # blender-based
camera_matrix[:3, 0] = np.array(json_content['x'])
camera_matrix[:3, 1] = np.array(json_content['y'])
camera_matrix[:3, 2] = np.array(json_content['z'])
camera_matrix[:3, 3] = np.array(json_content['origin'])
# print(camera_matrix)
# '''
# return convert_pose(camera_matrix)
return camera_matrix
def unity2blender(normal):
normal_clone = normal.copy()
normal_clone[..., 0] = -normal[..., -1]
normal_clone[..., 1] = -normal[..., 0]
normal_clone[..., 2] = normal[..., 1]
return normal_clone
def unity2blender_fix(normal): # up blue, left green, front (towards inside) red
normal_clone = normal.copy()
# normal_clone[..., 0] = -normal[..., 2]
# normal_clone[..., 1] = -normal[..., 0]
normal_clone[..., 0] = -normal[..., 0] # swap r and g
normal_clone[..., 1] = -normal[..., 2]
normal_clone[..., 2] = normal[..., 1]
return normal_clone
def unity2blender_th(normal):
assert normal.shape[1] == 3 # B 3 H W...
normal_clone = normal.clone()
normal_clone[:, 0, ...] = -normal[:, -1, ...]
normal_clone[:, 1, ...] = -normal[:, 0, ...]
normal_clone[:, 2, ...] = normal[:, 1, ...]
return normal_clone
def blender2midas(img):
'''Blender: rub
midas: lub
'''
img[..., 0] = -img[..., 0]
img[..., 1] = -img[..., 1]
img[..., -1] = -img[..., -1]
return img
def current_milli_time():
return round(time.time() * 1000)
# modified from ShapeNet class
class MultiViewObjverseDataset(Dataset):
def __init__(
self,
file_path,
reso,
reso_encoder,
preprocess=None,
classes=False,
load_depth=False,
test=False,
scene_scale=1,
overfitting=False,
imgnet_normalize=True,
dataset_size=-1,
overfitting_bs=-1,
interval=1,
plucker_embedding=False,
shuffle_across_cls=False,
wds_split=1, # 4 splits to accelerate preprocessing
four_view_for_latent=False,
single_view_for_i23d=False,
load_extra_36_view=False,
gs_cam_format=False,
frame_0_as_canonical=False,
**kwargs):
self.load_extra_36_view = load_extra_36_view
# st()
self.gs_cam_format = gs_cam_format
self.frame_0_as_canonical = frame_0_as_canonical
        self.four_view_for_latent = four_view_for_latent  # export a fixed view subset (see cur_all_fname below) for latent/PCD extraction
self.single_view_for_i23d = single_view_for_i23d
self.file_path = file_path
self.overfitting = overfitting
self.scene_scale = scene_scale
self.reso = reso
self.reso_encoder = reso_encoder
self.classes = False
self.load_depth = load_depth
self.preprocess = preprocess
self.plucker_embedding = plucker_embedding
self.intrinsics = get_intri(h=self.reso, w=self.reso,
normalize=True).reshape(9)
assert not self.classes, "Not support class condition now."
dataset_name = Path(self.file_path).stem.split('_')[0]
self.dataset_name = dataset_name
self.zfar = 100.0
self.znear = 0.01
# if test:
# self.ins_list = sorted(os.listdir(self.file_path))[0:1] # the first 1 instance for evaluation reference.
# else:
# ! TODO, read from list?
def load_single_cls_instances(file_path):
ins_list = [] # the first 1 instance for evaluation reference.
# '''
            # for dict_dir in os.listdir(file_path)[:]:
for dict_dir in os.listdir(file_path):
# for dict_dir in os.listdir(file_path)[:2]:
for ins_dir in os.listdir(os.path.join(file_path, dict_dir)):
# self.ins_list.append(os.path.join(self.file_path, dict_dir, ins_dir,))
# /nas/shared/V2V/yslan/logs/nips23/Reconstruction/final/objav/vae/MV/170K/infer-latents/189w/v=6-rotate/latent_dir
# st() # check latent whether saved
# root = '/nas/shared/V2V/yslan/logs/nips23/Reconstruction/final/objav/vae/MV/170K/infer-latents/189w/v=6-rotate/latent_dir'
# if os.path.exists(os.path.join(root,file_path.split('/')[-1], dict_dir, ins_dir, 'latent.npy') ):
# continue
# pcd_root = '/nas/shared/V2V/yslan/logs/nips23/Reconstruction/pcd-V=8_24576_polish'
# pcd_root = '/nas/shared/V2V/yslan/logs/nips23/Reconstruction/pcd-V=10_4096_polish'
# if os.path.exists(
# os.path.join(pcd_root, 'fps-pcd',
# file_path.split('/')[-1], dict_dir,
# ins_dir, 'fps-4096.ply')):
# continue
# ! split=8 has some missing instances
# root = '/cpfs01/user/lanyushi.p/data/chunk-jpeg-normal/bs_16_fixsave3/170K/384/'
# if os.path.exists(os.path.join(root,file_path.split('/')[-1], dict_dir, ins_dir,) ):
# continue
# else:
# ins_list.append(
# os.path.join(file_path, dict_dir, ins_dir,
# 'campos_512_v4'))
                    # filter out incomplete renders: require campos_512_v2, view 00025, and all 40 views
if not os.path.exists(os.path.join(file_path, dict_dir, ins_dir, 'campos_512_v2')):
continue
if not os.path.exists(os.path.join(file_path, dict_dir, ins_dir, 'campos_512_v2', '00025', '00025.png')):
continue
if len(os.listdir(os.path.join(file_path, dict_dir, ins_dir, 'campos_512_v2'))) != 40:
continue
ins_list.append(
os.path.join(file_path, dict_dir, ins_dir,
'campos_512_v2'))
# '''
            # check pcd performance
# ins_list.append(
# os.path.join(file_path, '0', '10634',
# 'campos_512_v4'))
return ins_list
# st()
self.ins_list = []
# for subset in ['Animals', 'Transportations_tar', 'Furnitures']:
# for subset in ['Furnitures']:
# selected subset for training
# if False:
if True:
            for subset in [  # ~170K instances in total.
# 'Animals',
# 'BuildingsOutdoor',
# 'daily-used',
# 'Furnitures',
# 'Food',
# 'Plants',
# 'Electronics',
# 'Transportations_tar',
# 'Human-Shape',
'gobjaverse_alignment_unzip',
]: # selected subset for training
# if os.path.exists(f'{self.file_path}/{subset}.txt'):
# dataset_list = f'{self.file_path}/{subset}_filtered.txt'
dataset_list = f'{self.file_path}/{subset}_filtered_more.txt'
                # fall back to scanning the directory tree when the filtered list is absent
                if os.path.exists(dataset_list):
with open(dataset_list, 'r') as f:
self.ins_list += [os.path.join(self.file_path, item.strip()) for item in f.readlines()]
else:
self.ins_list += load_single_cls_instances(
os.path.join(self.file_path, subset))
# st()
# current_time = int(current_milli_time()
# ) # randomly shuffle given current time
# random.seed(current_time)
# random.shuffle(self.ins_list)
else: # preprocess single class
self.ins_list = load_single_cls_instances(self.file_path)
self.ins_list = sorted(self.ins_list)
if overfitting:
self.ins_list = self.ins_list[:1]
self.rgb_list = []
self.frame0_pose_list = []
self.pose_list = []
self.depth_list = []
self.data_ins_list = []
self.instance_data_length = -1
# self.pcd_path = Path('/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/Reconstruction/pcd-V=6/fps-pcd')
self.pcd_path = Path(
'/nas/shared/V2V/yslan/logs/nips23/Reconstruction/pcd-V=6/fps-pcd')
with open(
'/nas/shared/public/yslan/data/text_captions_cap3d.json') as f:
# '/nas/shared/V2V/yslan/aigc3d/text_captions_cap3d.json') as f:
self.caption_data = json.load(f)
self.shuffle_across_cls = shuffle_across_cls
# for ins in self.ins_list[47000:]:
if four_view_for_latent: # also saving dense pcd
            # earlier latent-dumping runs used wds_split_all in {1, 2, 4, 5, 6, 7}
            self.wds_split_all = 8
# ins_list_to_process = self.ins_list
all_ins_size = len(self.ins_list)
ratio_size = all_ins_size // self.wds_split_all + 1
# ratio_size = int(all_ins_size / self.wds_split_all) + 1
ins_list_to_process = self.ins_list[ratio_size *
(wds_split):ratio_size *
(wds_split + 1)]
else: # ! create shards dataset
# self.wds_split_all = 4
self.wds_split_all = 8
# self.wds_split_all = 1
all_ins_size = len(self.ins_list)
random.seed(0)
random.shuffle(self.ins_list) # avoid same category appears in the same shard
ratio_size = all_ins_size // self.wds_split_all + 1
ins_list_to_process = self.ins_list[ratio_size * # 1 - 8
(wds_split - 1):ratio_size *
wds_split]
# uniform_sample = False
uniform_sample = True
# st()
for ins in tqdm(ins_list_to_process):
# ins = os.path.join(
# # self.file_path, ins , 'campos_512_v4'
# self.file_path, ins ,
# # 'compos_512_v4'
# )
# cur_rgb_path = os.path.join(self.file_path, ins, 'compos_512_v4')
# cur_pose_path = os.path.join(self.file_path, ins, 'pose')
# st()
# ][:27])
if self.four_view_for_latent:
# cur_all_fname = [t.split('.')[0] for t in os.listdir(ins)
# ] # use full set for training
# cur_all_fname = [f'{idx:05d}' for idx in [0, 12, 30, 36]
# cur_all_fname = [f'{idx:05d}' for idx in [6,12,18,24]
# cur_all_fname = [f'{idx:05d}' for idx in [7,16,24,25]
# cur_all_fname = [f'{idx:05d}' for idx in [25,26,0,9,18,27,33,39]]
cur_all_fname = [
f'{idx:05d}'
for idx in [25, 26, 6, 12, 18, 24, 27, 31, 35, 39] # ! for extracting PCD
]
# cur_all_fname = [f'{idx:05d}' for idx in [25,26,0,9,18,27,30,33,36,39]] # more down side for better bottom coverage.
# cur_all_fname = [f'{idx:05d}' for idx in [25,0, 7,15]]
# cur_all_fname = [f'{idx:05d}' for idx in [4,12,20,25,26]
# cur_all_fname = [f'{idx:05d}' for idx in [6,12,18,24,25,26]
                # cur_all_fname = [f'{idx:05d}' for idx in [6,12,18,24,25,26, 39, 33, 27]
# cur_all_fname = [
# f'{idx:05d}' for idx in [25, 26, 27, 30, 33, 36]
# ] # for pcd unprojection
# cur_all_fname = [
# f'{idx:05d}' for idx in [25, 26, 27, 30] # ! for infer latents
# ] #
# cur_all_fname = [
# f'{idx:05d}' for idx in [25, 27, 29, 31, 33, 35, 37
# ] # ! for infer latents
# ] #
# cur_all_fname = [
# f'{idx:05d}' for idx in [25, 27, 31, 35
# ] # ! for infer latents
# ] #
# cur_all_fname += [f'{idx:05d}' for idx in range(40) if idx not in [0,12,30,36]] # ! four views for inference
elif self.single_view_for_i23d:
# cur_all_fname = [f'{idx:05d}'
# for idx in [16]] # 20 is also fine
cur_all_fname = [f'{idx:05d}'
for idx in [2]] # ! furniture side view
else:
cur_all_fname = [t.split('.')[0] for t in os.listdir(ins)
] # use full set for training
if shuffle_across_cls:
if uniform_sample:
cur_all_fname = sorted(cur_all_fname)
# 0-24, 25 views
# 25,26, 2 views
# 27-39, 13 views
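                    # Each 8-view chunk below takes one polar view (alternating
                    # 25/26), four orbit views at a random phase (stride 7 over
                    # 0-24), and three views from the 27-39 ring (stride 4).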
uniform_all_fname = []
                    # ! chunk layout when bs=8 (or 9)
for idx in range(6):
if idx % 2 == 0:
chunk_all_fname = [25]
else:
chunk_all_fname = [26]
# chunk_all_fname = [25] # no bottom view required as input
# start_1 = np.random.randint(0,5) # for first 24 views
# chunk_all_fname += [start_1+uniform_idx for uniform_idx in range(0,25,5)]
start_1 = np.random.randint(0,4) # for first 24 views, v=8
chunk_all_fname += [start_1+uniform_idx for uniform_idx in range(0,25,7)] # [0-21]
start_2 = np.random.randint(0,5) + 27 # for first 24 views
chunk_all_fname += [start_2, start_2 + 4, start_2 + 8]
assert len(chunk_all_fname) == 8, len(chunk_all_fname)
uniform_all_fname += [cur_all_fname[fname] for fname in chunk_all_fname]
# ! if bs=6
# for idx in range(8):
# if idx % 2 == 0:
# chunk_all_fname = [
# 25
# ] # no bottom view required as input
# else:
# chunk_all_fname = [
# 26
# ] # no bottom view required as input
# start_1 = np.random.randint(
# 0, 7) # for first 24 views
# # chunk_all_fname += [start_1+uniform_idx for uniform_idx in range(0,25,5)]
# chunk_all_fname += [
# start_1 + uniform_idx
# for uniform_idx in range(0, 25, 9)
# ] # 0 9 18
# start_2 = np.random.randint(
# 0, 7) + 27 # for first 24 views
# # chunk_all_fname += [start_2, start_2 + 4, start_2 + 8]
# chunk_all_fname += [start_2,
# start_2 + 6] # 2 frames
# assert len(chunk_all_fname) == 6
# uniform_all_fname += [
# cur_all_fname[fname]
# for fname in chunk_all_fname
# ]
cur_all_fname = uniform_all_fname
else:
current_time = int(current_milli_time(
)) # randomly shuffle given current time
random.seed(current_time)
random.shuffle(cur_all_fname)
else:
cur_all_fname = sorted(cur_all_fname)
# ! skip the check
# if self.instance_data_length == -1:
# self.instance_data_length = len(cur_all_fname)
# else:
# try: # data missing?
# assert len(cur_all_fname) == self.instance_data_length
# except:
# # with open('error_log.txt', 'a') as f:
# # f.write(str(e) + '\n')
# with open('missing_ins_new2.txt', 'a') as f:
# f.write(str(Path(ins.parent)) +
# '\n') # remove the "campos_512_v4"
# continue
# if test: # use middle image as the novel view model input
# mid_index = len(cur_all_fname) // 3 * 2
# cur_all_fname.insert(0, cur_all_fname[mid_index])
self.frame0_pose_list += ([
os.path.join(ins, fname, fname + '.json')
for fname in [cur_all_fname[0]]
] * len(cur_all_fname))
self.pose_list += ([
os.path.join(ins, fname, fname + '.json')
for fname in cur_all_fname
])
self.rgb_list += ([
os.path.join(ins, fname, fname + '.png')
for fname in cur_all_fname
])
self.depth_list += ([
os.path.join(ins, fname, fname + '_nd.exr')
for fname in cur_all_fname
])
self.data_ins_list += ([ins] * len(cur_all_fname))
# check
        # ! set up normalization
transformations = [
transforms.ToTensor(), # [0,1] range
]
if imgnet_normalize:
transformations.append(
transforms.Normalize((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225)) # type: ignore
)
else:
transformations.append(
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))) # type: ignore
# st()
self.normalize = transforms.Compose(transformations)
def get_source_cw2wT(self, source_cameras_view_to_world):
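        # quaternion of the transposed 3x3 rotation block; the GS splatting
        # renderer consumes source-camera rotations in this quaternion form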
return matrix_to_quaternion(
source_cameras_view_to_world[:3, :3].transpose(0, 1))
def c_to_3dgs_format(self, pose):
# TODO, switch to torch version (batched later)
        c2w = pose[:16].reshape(4, 4)  # flattened 4x4 camera-to-world
# ! load cam
w2c = np.linalg.inv(c2w)
R = np.transpose(
w2c[:3, :3]) # R is stored transposed due to 'glm' in CUDA code
T = w2c[:3, 3]
        fx = pose[16]  # normalized focal (see get_intri), so pixels=1 below
        FovX = focal2fov(fx, 1)
        FovY = focal2fov(fx, 1)
tanfovx = math.tan(FovX * 0.5)
tanfovy = math.tan(FovY * 0.5)
assert tanfovx == tanfovy
trans = np.array([0.0, 0.0, 0.0])
scale = 1.0
world_view_transform = torch.tensor(getWorld2View2(R, T, trans,
scale)).transpose(
0, 1)
projection_matrix = getProjectionMatrix(znear=self.znear,
zfar=self.zfar,
fovX=FovX,
fovY=FovY).transpose(0, 1)
full_proj_transform = (world_view_transform.unsqueeze(0).bmm(
projection_matrix.unsqueeze(0))).squeeze(0)
camera_center = world_view_transform.inverse()[3, :3]
view_world_transform = torch.tensor(getView2World(R, T, trans,
scale)).transpose(
0, 1)
# item.update(viewpoint_cam=[viewpoint_cam])
c = {}
c["source_cv2wT_quat"] = self.get_source_cw2wT(view_world_transform)
c.update(
# projection_matrix=projection_matrix, # K
cam_view=world_view_transform, # world_view_transform
cam_view_proj=full_proj_transform, # full_proj_transform
cam_pos=camera_center,
tanfov=tanfovx, # TODO, fix in the renderer
# orig_c2w=c2w,
# orig_w2c=w2c,
orig_pose=torch.from_numpy(pose),
orig_c2w=torch.from_numpy(c2w),
orig_w2c=torch.from_numpy(w2c),
# tanfovy=tanfovy,
)
return c # dict for gs rendering
def __len__(self):
return len(self.rgb_list)
    def load_bbox(self, mask):
        # mask: (H, W); returns [top, left, bottom, right] of the nonzero region
        nonzero_value = torch.nonzero(mask)
        bottom, right = nonzero_value.max(dim=0)[0]
        top, left = nonzero_value.min(dim=0)[0]
        bbox = torch.tensor([top, left, bottom, right], dtype=torch.float32)
        return bbox
def __getitem__(self, idx):
# try:
data = self._read_data(idx)
return data
# except Exception as e:
# # with open('error_log_pcd.txt', 'a') as f:
# with open('error_log_pcd.txt', 'a') as f:
# f.write(str(e) + '\n')
# with open('error_idx_pcd.txt', 'a') as f:
# f.write(str(self.data_ins_list[idx]) + '\n')
# print(e, flush=True)
# return {}
def gen_rays(self, c2w):
# Generate rays
self.h = self.reso_encoder
self.w = self.reso_encoder
yy, xx = torch.meshgrid(
torch.arange(self.h, dtype=torch.float32) + 0.5,
torch.arange(self.w, dtype=torch.float32) + 0.5,
indexing='ij')
# normalize to 0-1 pixel range
yy = yy / self.h
xx = xx / self.w
# K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1]).reshape(3, 3)
cx, cy, fx, fy = self.intrinsics[2], self.intrinsics[
5], self.intrinsics[0], self.intrinsics[4]
# cx *= self.w
# cy *= self.h
# f_x = f_y = fx * h / res_raw
c2w = torch.from_numpy(c2w).float()
xx = (xx - cx) / fx
yy = (yy - cy) / fy
zz = torch.ones_like(xx)
dirs = torch.stack((xx, yy, zz), dim=-1) # OpenCV convention
dirs /= torch.norm(dirs, dim=-1, keepdim=True)
dirs = dirs.reshape(-1, 3, 1)
del xx, yy, zz
# st()
dirs = (c2w[None, :3, :3] @ dirs)[..., 0]
origins = c2w[None, :3, 3].expand(self.h * self.w, -1).contiguous()
origins = origins.view(self.h, self.w, 3)
dirs = dirs.view(self.h, self.w, 3)
return origins, dirs
    def normalize_camera(self, c, c_frame0):
        # assert c.shape[0] == self.chunk_size # 8 or 10
        B = c.shape[0]
        camera_poses = c[:, :16].reshape(B, 4, 4)  # flattened 4x4 c2w per view
canonical_camera_poses = c_frame0[:, :16].reshape(B, 4, 4)
# if for_encoder:
# encoder_canonical_idx = [0, self.V]
# st()
cam_radius = np.linalg.norm(
c_frame0[:, :16].reshape(1, 4, 4)[:, :3, 3],
axis=-1,
keepdims=False) # since g-buffer adopts dynamic radius here.
frame1_fixed_pos = np.repeat(np.eye(4)[None], 1, axis=0)
frame1_fixed_pos[:, 2, -1] = -cam_radius
transform = frame1_fixed_pos @ np.linalg.inv(canonical_camera_poses)
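        # 'transform' maps frame 0's camera onto a fixed pose at (0, 0, -radius)
        # with identity rotation; applying it to every pose below expresses all
        # cameras relative to frame 0 while keeping the per-scene camera radius.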
# from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127
# transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]])
new_camera_poses = np.repeat(
transform, 1, axis=0
) @ camera_poses # [V, 4, 4]. np.repeat() is th.repeat_interleave()
# else:
# cam_radius = np.linalg.norm(
# c[canonical_idx][:16].reshape(4, 4)[:3, 3],
# axis=-1,
# keepdims=False
# ) # since g-buffer adopts dynamic radius here.
# frame1_fixed_pos = np.eye(4)
# frame1_fixed_pos[2, -1] = -cam_radius
# transform = frame1_fixed_pos @ np.linalg.inv(
# camera_poses[canonical_idx]) # 4,4
# # from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127
# # transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]])
# new_camera_poses = np.repeat(
# transform[None], self.chunk_size,
# axis=0) @ camera_poses # [V, 4, 4]
# st()
c = np.concatenate([new_camera_poses.reshape(B, 16), c[:, 16:]],
axis=-1)
# st()
return c
def _read_data(
self,
idx,
):
rgb_fname = self.rgb_list[idx]
pose_fname = self.pose_list[idx]
raw_img = imageio.imread(rgb_fname)
# ! RGBD
alpha_mask = raw_img[..., -1:] / 255
raw_img = alpha_mask * raw_img[..., :3] + (
1 - alpha_mask) * np.ones_like(raw_img[..., :3]) * 255
raw_img = raw_img.astype(
np.uint8) # otherwise, float64 won't call ToTensor()
# return raw_img
# st()
if self.preprocess is None:
img_to_encoder = cv2.resize(raw_img,
(self.reso_encoder, self.reso_encoder),
interpolation=cv2.INTER_LANCZOS4)
# interpolation=cv2.INTER_AREA)
img_to_encoder = img_to_encoder[
..., :3] #[3, reso_encoder, reso_encoder]
img_to_encoder = self.normalize(img_to_encoder)
else:
img_to_encoder = self.preprocess(Image.open(rgb_fname)) # clip
# return img_to_encoder
img = cv2.resize(raw_img, (self.reso, self.reso),
interpolation=cv2.INTER_LANCZOS4)
# interpolation=cv2.INTER_AREA)
# img_sr = cv2.resize(raw_img, (512, 512), interpolation=cv2.INTER_AREA)
# img_sr = cv2.resize(raw_img, (256, 256), interpolation=cv2.INTER_AREA) # just as refinement, since eg3d uses 64->128 final resolution
# img_sr = cv2.resize(raw_img, (128, 128), interpolation=cv2.INTER_AREA) # just as refinement, since eg3d uses 64->128 final resolution
# img_sr = cv2.resize(
# raw_img, (128, 128), interpolation=cv2.INTER_LANCZOS4
# ) # just as refinement, since eg3d uses 64->128 final resolution
# img = torch.from_numpy(img)[..., :3].permute(
# 2, 0, 1) / 255.0 #[3, reso, reso]
img = torch.from_numpy(img)[..., :3].permute(
2, 0, 1
) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range
# img_sr = torch.from_numpy(img_sr)[..., :3].permute(
# 2, 0, 1
# ) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range
        c2w = read_camera_matrix_single(pose_fname)  # (4, 4) camera-to-world
# c = np.concatenate([c2w, self.intrinsics], axis=0).reshape(25) # 25, no '1' dim needed.
# return c2w
# if self.load_depth:
# depth, depth_mask, depth_mask_sr = read_dnormal(self.depth_list[idx],
# try:
depth, normal = read_dnormal(self.depth_list[idx], c2w[:3, 3:],
self.reso, self.reso)
# ! frame0 alignment
# if self.frame_0_as_canonical:
# return depth
# except:
# # print(self.depth_list[idx])
# raise NotImplementedError(self.depth_list[idx])
# if depth
        try:
            bbox = self.load_bbox(depth > 0)
        except Exception:
            print(rgb_fname, flush=True)
            with open('error_log.txt', 'a') as f:
                f.write(rgb_fname + '\n')
            bbox = self.load_bbox(torch.ones_like(depth))
# plucker
# ! normalize camera
c = np.concatenate([c2w.reshape(16), self.intrinsics],
axis=0).reshape(25).astype(
np.float32) # 25, no '1' dim needed.
if self.frame_0_as_canonical: # 4 views as input per batch
frame0_pose_name = self.frame0_pose_list[idx]
            c2w_frame0 = read_camera_matrix_single(
                frame0_pose_name)  # (4, 4) camera-to-world
c = self.normalize_camera(c[None], c2w_frame0[None])[0]
c2w = c[:16].reshape(4, 4) # !
# st()
# pass
rays_o, rays_d = self.gen_rays(c2w)
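        # Plücker embedding: each ray is encoded as (o x d, d) in R^6, a line
        # parameterization independent of the point chosen along the ray.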
rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d],
dim=-1) # [h, w, 6]
img_to_encoder = torch.cat(
[img_to_encoder, rays_plucker.permute(2, 0, 1)],
0).float() # concat in C dim
# ! add depth as input
depth, normal = read_dnormal(self.depth_list[idx], c2w[:3, 3:],
self.reso_encoder, self.reso_encoder)
normalized_depth = depth.unsqueeze(0) # min=0
img_to_encoder = torch.cat([img_to_encoder, normalized_depth],
0) # concat in C dim
if self.gs_cam_format:
c = self.c_to_3dgs_format(c)
else:
c = torch.from_numpy(c)
ret_dict = {
# 'rgb_fname': rgb_fname,
'img_to_encoder': img_to_encoder,
'img': img,
'c': c,
# 'img_sr': img_sr,
# 'ins_name': self.data_ins_list[idx]
}
# ins = str(
# (Path(self.data_ins_list[idx]).relative_to(self.file_path)).parent)
pcd_ins = Path(self.data_ins_list[idx]).relative_to(
Path(self.file_path).parent).parent
# load pcd
# fps_pcd = pcu.load_mesh_v(
# str(self.pcd_path / pcd_ins / 'fps-10000.ply'))
ins = str( # for compat
(Path(self.data_ins_list[idx]).relative_to(self.file_path)).parent)
# if self.shuffle_across_cls:
caption = self.caption_data['/'.join(ins.split('/')[1:])]
# else:
# caption = self.caption_data[ins]
ret_dict.update({
'depth': depth,
'normal': normal,
'alpha_mask': alpha_mask,
'depth_mask': depth > 0,
# 'depth_mask_sr': depth_mask_sr,
'bbox': bbox,
'caption': caption,
'rays_plucker': rays_plucker, # cam embedding used in lgm
'ins': ins, # placeholder
# 'fps_pcd': fps_pcd,
})
return ret_dict
# class MultiViewObjverseDatasetChunk(MultiViewObjverseDataset):
# def __init__(self,
# file_path,
# reso,
# reso_encoder,
# preprocess=None,
# classes=False,
# load_depth=False,
# test=False,
# scene_scale=1,
# overfitting=False,
# imgnet_normalize=True,
# dataset_size=-1,
# overfitting_bs=-1,
# interval=1,
# plucker_embedding=False,
# shuffle_across_cls=False,
# wds_split=1,
# four_view_for_latent=False,
# single_view_for_i23d=False,
# load_extra_36_view=False,
# gs_cam_format=False,
# **kwargs):
# super().__init__(file_path, reso, reso_encoder, preprocess, classes,
# load_depth, test, scene_scale, overfitting,
# imgnet_normalize, dataset_size, overfitting_bs,
# interval, plucker_embedding, shuffle_across_cls,
# wds_split, four_view_for_latent, single_view_for_i23d,
# load_extra_36_view, gs_cam_format, **kwargs)
# # load 40 views at a time, for inferring latents.
# TODO merge all the useful APIs together
class ChunkObjaverseDataset(Dataset):
def __init__(
self,
file_path,
reso,
reso_encoder,
preprocess=None,
classes=False,
load_depth=False,
test=False,
scene_scale=1,
overfitting=False,
imgnet_normalize=True,
dataset_size=-1,
overfitting_bs=-1,
interval=1,
plucker_embedding=False,
shuffle_across_cls=False,
wds_split=1, # 4 splits to accelerate preprocessing
four_view_for_latent=False,
single_view_for_i23d=False,
load_extra_36_view=False,
gs_cam_format=False,
frame_0_as_canonical=True,
split_chunk_size=10,
mv_input=True,
append_depth=False,
append_xyz=False,
wds_split_all=1,
pcd_path=None,
load_pcd=False,
read_normal=False,
load_raw=False,
load_instance_only=False,
mv_latent_dir='',
perturb_pcd_scale=0.0,
# shards_folder_num=4,
# eval=False,
**kwargs):
super().__init__()
# st()
self.mv_latent_dir = mv_latent_dir
self.load_raw = load_raw
self.load_instance_only = load_instance_only
self.read_normal = read_normal
self.file_path = file_path
self.chunk_size = split_chunk_size
self.gs_cam_format = gs_cam_format
self.frame_0_as_canonical = frame_0_as_canonical
        self.four_view_for_latent = four_view_for_latent  # export a fixed view subset for reconstruction
self.overfitting = overfitting
self.scene_scale = scene_scale
self.reso = reso
self.reso_encoder = reso_encoder
self.classes = False
self.load_depth = load_depth
self.preprocess = preprocess
self.plucker_embedding = plucker_embedding
self.intrinsics = get_intri(h=self.reso, w=self.reso,
normalize=True).reshape(9)
self.perturb_pcd_scale = perturb_pcd_scale
assert not self.classes, "Not support class condition now."
dataset_name = Path(self.file_path).stem.split('_')[0]
self.dataset_name = dataset_name
self.ray_sampler = RaySampler()
self.zfar = 100.0
self.znear = 0.01
# ! load all chunk paths
self.chunk_list = []
# if dataset_size != -1: # predefined instance
# self.chunk_list = self.fetch_chunk_list(os.path.join(self.file_path, 'debug'))
# else:
# # for shard_idx in range(1, 5): # shard_dir 1-4 by default
# for shard_idx in os.listdir(self.file_path):
# self.chunk_list += self.fetch_chunk_list(os.path.join(self.file_path, shard_idx))
def load_single_cls_instances(file_path):
ins_list = [] # the first 1 instance for evaluation reference.
for dict_dir in os.listdir(file_path)[:]: # ! for debugging
for ins_dir in os.listdir(os.path.join(file_path, dict_dir)):
ins_list.append(
os.path.join(file_path, dict_dir, ins_dir,
'campos_512_v4'))
return ins_list
# st()
if self.load_raw:
with open(
# '/nas/shared/V2V/yslan/aigc3d/text_captions_cap3d.json') as f:
# '/nas/shared/public/yslan//data/text_captions_cap3d.json') as f:
'./dataset/text_captions_3dtopia.json') as f:
self.caption_data = json.load(f)
# with open
# # '/nas/shared/V2V/yslan/aigc3d/text_captions_cap3d.json') as f:
# '/nas/shared/public/yslan//data/text_captions_cap3d.json') as f:
# # '/cpfs01/shared/public/yhluo/Projects/threed/3D-Enhancer/develop/text_captions_3dtopia.json') as f:
# self.old_caption_data = json.load(f)
            for subset in [  # ~176K instances in total.
'Animals',
# 'daily-used',
# 'BuildingsOutdoor',
# 'Furnitures',
# 'Food',
# 'Plants',
# 'Electronics',
# 'Transportations_tar',
# 'Human-Shape',
]: # selected subset for training
# self.chunk_list += load_single_cls_instances(
# os.path.join(self.file_path, subset))
with open(f'shell_scripts/raw_img_list/{subset}.txt', 'r') as f:
self.chunk_list += [os.path.join(subset, item.strip()) for item in f.readlines()]
# st() # save to local
# with open('/cpfs01/user/lanyushi.p/Repo/diffusion-3d/shell_scripts/shards_list/chunk_list.txt', 'w') as f:
# f.writelines(self.chunk_list)
# load raw g-objv dataset
# self.img_ext = 'png' # ln3diff
# for k, v in dataset_json.items(): # directly load from folders instead
# self.chunk_list.extend(v)
else:
            # ! directly load from json
with open(f'{self.file_path}/dataset.json', 'r') as f:
dataset_json = json.load(f)
# dataset_json = {'Animals': ['Animals/0/10017/1']}
if self.chunk_size == 12:
self.img_ext = 'png' # ln3diff
for k, v in dataset_json.items():
self.chunk_list.extend(v)
else:
# extract latent
assert self.chunk_size in [16,18, 20]
self.img_ext = 'jpg' # more views
for k, v in dataset_json.items():
# if k != 'BuildingsOutdoor': # cannot be handled by gs
self.chunk_list.extend(v)
# filter
# st()
# root = '/nas/shared/V2V/yslan/logs/nips23/Reconstruction/final/objav/vae/gs/infer-latents/768/8x8/animals-gs-latent/latent_dir'
# root = '/nas/shared/V2V/yslan/logs/nips23/Reconstruction/final/objav/vae/gs/infer-latents/768/8x8/animals-gs-latent-dim=10-fullset/latent_dir'
# filtered_chunk_list = []
# for v in self.chunk_list:
# if os.path.exists(os.path.join(root, v[:-2], 'gaussians.npy') ):
# continue
# filtered_chunk_list.append(v)
# self.chunk_list = filtered_chunk_list
dataset_size = len(self.chunk_list)
self.chunk_list = sorted(self.chunk_list)
# self.chunk_list, self.eval_list = self.chunk_list[:int(dataset_size*0.95)], self.chunk_list[int(dataset_size*0.95):]
# self.chunk_list = self.eval_list
# self.wds_split_all = wds_split_all # for
# self.wds_split_all = 1
# self.wds_split_all = 7
# self.wds_split_all = 4
self.wds_split_all = 1
# ! filter
# st()
        if wds_split_all != 1:
            # ! retrieve the right wds split; use the wds_split_all argument here,
            # since self.wds_split_all is pinned to 1 above
            all_ins_size = len(self.chunk_list)
            ratio_size = all_ins_size // wds_split_all + 1
            # ratio_size = int(all_ins_size / wds_split_all) + 1
            print('ratio_size: ', ratio_size, 'all_ins_size: ', all_ins_size)
            self.chunk_list = self.chunk_list[ratio_size *
                                              (wds_split):ratio_size *
                                              (wds_split + 1)]
# st()
# load images from raw
self.rgb_list = []
if self.load_instance_only:
for ins in tqdm(self.chunk_list):
ins_name = str(Path(ins).parent)
# cur_all_fname = [f'{t:05d}' for t in range(40)] # load all instances for now
self.rgb_list += ([
os.path.join(self.file_path, ins, fname + '.png')
for fname in [f'{t}' for t in range(2)]
# for fname in [f'{t:05d}' for t in range(2)]
]) # synthetic mv data
# index mapping of mvi data to objv single-view data
self.mvi_objv_mapping = {
'0': '00000',
'1': '00012',
}
# load gt mv data
self.gt_chunk_list = []
self.gt_mv_file_path = '/cpfs01/user/lanyushi.p/data/chunk-jpeg-normal/bs_16_fixsave3/170K/512/'
assert self.chunk_size in [16,18, 20]
with open(f'{self.gt_mv_file_path}/dataset.json', 'r') as f:
dataset_json = json.load(f)
# dataset_json = {'Animals': dataset_json['Animals'] } #
self.img_ext = 'jpg' # more views
for k, v in dataset_json.items():
# if k != 'BuildingsOutdoor': # cannot be handled by gs
self.gt_chunk_list.extend(v)
elif self.load_raw:
for ins in tqdm(self.chunk_list):
#
# st()
# ins = ins[len('/nas/shared/V2V/yslan/aigc3d/unzip4/'):]
# ins_name = str(Path(ins).relative_to(self.file_path).parent)
ins_name = str(Path(ins).parent)
# latent_path = os.path.join(self.mv_latent_dir, ins_name, 'latent.npz')
# if not os.path.exists(latent_path):
# continue
cur_all_fname = [f'{t:05d}' for t in range(40)] # load all instances for now
self.rgb_list += ([
os.path.join(self.file_path, ins, fname, fname + '.png')
for fname in cur_all_fname
])
self.post_process = PostProcess(
reso,
reso_encoder,
imgnet_normalize=imgnet_normalize,
plucker_embedding=plucker_embedding,
decode_encode_img_only=False,
mv_input=mv_input,
split_chunk_input=split_chunk_size,
duplicate_sample=True,
append_depth=append_depth,
append_xyz=append_xyz,
gs_cam_format=gs_cam_format,
orthog_duplicate=False,
frame_0_as_canonical=frame_0_as_canonical,
pcd_path=pcd_path,
load_pcd=load_pcd,
split_chunk_size=split_chunk_size,
)
self.kernel = torch.tensor([[0, 1, 0], [1, 1, 1], [0, 1, 0]])
        # self.no_bottom = True # avoid loading bottom view
def fetch_chunk_list(self, file_path):
if os.path.isdir(file_path):
chunks = [
os.path.join(file_path, fname)
for fname in os.listdir(file_path) if fname.isdigit()
]
return chunks
else:
return []
def _pre_process_chunk(self):
# e.g., remove bottom view
pass
def read_chunk(self, chunk_path):
# equivalent to decode_zip() in wds
# reshape chunk
raw_img = imageio.imread(
os.path.join(chunk_path, f'raw_img.{self.img_ext}'))
h, bw, c = raw_img.shape
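        # Views are tiled horizontally into one strip (bw == chunk_size * w);
        # the reshape + transpose below recovers a (V, h, w, c) stack.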
raw_img = raw_img.reshape(h, self.chunk_size, -1, c).transpose(
(1, 0, 2, 3))
c = np.load(os.path.join(chunk_path, 'c.npy'))
with open(os.path.join(chunk_path, 'caption.txt'),
'r',
encoding="utf-8") as f:
caption = f.read()
with open(os.path.join(chunk_path, 'ins.txt'), 'r',
encoding="utf-8") as f:
ins = f.read()
bbox = np.load(os.path.join(chunk_path, 'bbox.npy'))
if self.chunk_size > 16:
depth_alpha = imageio.imread(
os.path.join(chunk_path, 'depth_alpha.jpg')) # 2h 10w
depth_alpha = depth_alpha.reshape(h * 2, self.chunk_size,
-1).transpose((1, 0, 2))
depth, alpha = np.split(depth_alpha, 2, axis=1)
d_near_far = np.load(os.path.join(chunk_path, 'd_near_far.npy'))
d_near = d_near_far[0].reshape(self.chunk_size, 1, 1)
d_far = d_near_far[1].reshape(self.chunk_size, 1, 1)
# d = 1 / ( (d_normalized / 255) * (far-near) + near)
depth = 1 / ((depth / 255) * (d_far - d_near) + d_near)
depth[depth > 2.9] = 0.0 # background as 0, follow old tradition
# ! filter anti-alias artifacts
erode_mask = kornia.morphology.erosion(
torch.from_numpy(alpha == 255).float().unsqueeze(1),
self.kernel) # B 1 H W
depth = (torch.from_numpy(depth).unsqueeze(1) * erode_mask).squeeze(
1) # shrink anti-alias bug
else:
# load separate alpha and depth map
alpha = imageio.imread(
os.path.join(chunk_path, f'alpha.{self.img_ext}'))
alpha = alpha.reshape(h, self.chunk_size, h).transpose(
(1, 0, 2))
depth = np.load(os.path.join(chunk_path, 'depth.npz'))['depth']
# depth = depth * (alpha==255) # mask out background
# depth = np.stack([depth, alpha], -1) # rgba
# if self.no_bottom:
# raw_img
# pass
if self.read_normal:
normal = imageio.imread(os.path.join(
chunk_path, 'normal.png')).astype(np.float32) / 255.0
normal = (normal * 2 - 1).reshape(h, self.chunk_size, -1,
3).transpose((1, 0, 2, 3))
# fix g-buffer normal rendering coordinate
# normal = unity2blender(normal) # ! still wrong
normal = unity2blender_fix(normal) # !
            depth = (depth, normal)  # pack normal alongside depth; consumers unpack this tuple
return raw_img, depth, c, alpha, bbox, caption, ins
def __len__(self):
return len(self.chunk_list)
def __getitem__(self, index) -> Any:
sample = self.read_chunk(
os.path.join(self.file_path, self.chunk_list[index]))
sample = self.post_process.paired_post_process_chunk(sample)
sample = self.post_process.create_dict_nobatch(sample)
# aug pcd
# st()
if self.perturb_pcd_scale > 0:
if random.random() > 0.5:
t = np.random.rand(sample['fps_pcd'].shape[0], 1, 1) * self.perturb_pcd_scale
sample['fps_pcd'] = sample['fps_pcd'] + t * np.random.randn(*sample['fps_pcd'].shape) # type: ignore
sample['fps_pcd'] = np.clip(sample['fps_pcd'], -0.45, 0.45) # truncate noisy augmentation
return sample
class ChunkObjaverseDatasetDDPM(ChunkObjaverseDataset):
def __init__(
self,
file_path,
reso,
reso_encoder,
preprocess=None,
classes=False,
load_depth=False,
test=False,
scene_scale=1,
overfitting=False,
imgnet_normalize=True,
dataset_size=-1,
overfitting_bs=-1,
interval=1,
plucker_embedding=False,
shuffle_across_cls=False,
wds_split=1, # 4 splits to accelerate preprocessing
four_view_for_latent=False,
single_view_for_i23d=False,
load_extra_36_view=False,
gs_cam_format=False,
frame_0_as_canonical=True,
split_chunk_size=10,
mv_input=True,
append_depth=False,
append_xyz=False,
pcd_path=None,
load_pcd=False,
read_normal=False,
mv_latent_dir='',
load_raw=False,
# shards_folder_num=4,
# eval=False,
**kwargs):
super().__init__(
file_path,
reso,
reso_encoder,
preprocess=None,
classes=False,
load_depth=False,
test=False,
scene_scale=1,
overfitting=False,
imgnet_normalize=True,
dataset_size=-1,
overfitting_bs=-1,
interval=1,
plucker_embedding=False,
shuffle_across_cls=False,
wds_split=1, # 4 splits to accelerate preprocessing
four_view_for_latent=False,
single_view_for_i23d=False,
load_extra_36_view=False,
gs_cam_format=False,
frame_0_as_canonical=True,
split_chunk_size=split_chunk_size,
mv_input=True,
append_depth=False,
append_xyz=False,
pcd_path=None,
load_pcd=False,
read_normal=False,
load_raw=load_raw,
mv_latent_dir=mv_latent_dir,
# shards_folder_num=4,
# eval=False,
**kwargs)
self.n_cond_frames = 6
self.perspective_transformer = v2.RandomPerspective(distortion_scale=0.4, p=0.15, fill=1,
interpolation=torchvision.transforms.InterpolationMode.NEAREST)
self.mv_resize_cls = torchvision.transforms.Resize(320, interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
max_size=None, antialias=True)
# ! read img c, caption.
def get_plucker_ray(self, c):
rays_plucker = []
for idx in range(c.shape[0]):
rays_o, rays_d = self.gen_rays(c[idx])
rays_plucker.append(
torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d],
dim=-1).permute(2, 0, 1)) # [h, w, 6] -> 6,h,w
rays_plucker = torch.stack(rays_plucker, 0)
return rays_plucker
def read_chunk(self, chunk_path):
# equivalent to decode_zip() in wds
# reshape chunk
raw_img = imageio.imread(
os.path.join(chunk_path, f'raw_img.{self.img_ext}')).astype(np.float32)
h, bw, c = raw_img.shape
raw_img = raw_img.reshape(h, self.chunk_size, -1, c).transpose(
(1, 0, 2, 3))
c = np.load(os.path.join(chunk_path, 'c.npy')).astype(np.float32)
with open(os.path.join(chunk_path, 'caption.txt'),
'r',
encoding="utf-8") as f:
caption = f.read()
with open(os.path.join(chunk_path, 'ins.txt'), 'r',
encoding="utf-8") as f:
ins = f.read()
return raw_img, c, caption, ins
def _load_latent(self, ins):
# if 'adv' in self.mv_latent_dir: # new latent codes saved have 3 augmentations
# idx = random.choice([0,1,2])
# latent = np.load(os.path.join(self.mv_latent_dir, ins, f'latent-{idx}.npy')) # pre-calculated VAE latent
# else:
latent = np.load(os.path.join(self.mv_latent_dir, ins, 'latent.npy')) # pre-calculated VAE latent
latent = repeat(latent, 'C H W -> B C H W', B=2)
# return {'latent': latent}
return latent
def normalize_camera(self, c, c_frame0):
        # assert c.shape[0] == self.chunk_size # 8 or 10
        B = c.shape[0]
        camera_poses = c[:, :16].reshape(B, 4, 4)  # flattened 4x4 c2w per view
        canonical_camera_poses = c_frame0[:, :16].reshape(1, 4, 4)
inverse_canonical_pose = np.linalg.inv(canonical_camera_poses)
inverse_canonical_pose = np.repeat(inverse_canonical_pose, B, 0)
cam_radius = np.linalg.norm(
c_frame0[:, :16].reshape(1, 4, 4)[:, :3, 3],
axis=-1,
keepdims=False) # since g-buffer adopts dynamic radius here.
frame1_fixed_pos = np.repeat(np.eye(4)[None], 1, axis=0)
frame1_fixed_pos[:, 2, -1] = -cam_radius
transform = frame1_fixed_pos @ inverse_canonical_pose
new_camera_poses = np.repeat(
transform, 1, axis=0
) @ camera_poses # [V, 4, 4]. np.repeat() is th.repeat_interleave()
c = np.concatenate([new_camera_poses.reshape(B, 16), c[:, 16:]],
axis=-1)
return c
# @autocast
# def plucker_embedding(self, c):
# rays_o, rays_d = self.gen_rays(c)
# rays_plucker = torch.cat(
# [torch.cross(rays_o, rays_d, dim=-1), rays_d],
# dim=-1).permute(2, 0, 1) # [h, w, 6] -> 6,h,w
# return rays_plucker
def __getitem__(self, index) -> Any:
raw_img, c, caption, ins = self.read_chunk(
os.path.join(self.file_path, self.chunk_list[index]))
# sample = self.post_process.paired_post_process_chunk(sample)
# ! random zoom in (scale augmentation)
# for i in range(img.shape[0]):
# for v in range(img.shape[1]):
# if random.random() > 0.8:
# rand_bg_scale = random.randint(60,99) / 100
# st()
# img[i,v] = recenter(img[i,v], np.ones_like(img[i,v]), border_ratio=rand_bg_scale)
# ! process
raw_img = torch.from_numpy(raw_img).permute(0, 3, 1, 2) / 255.0 # [0,1]
if raw_img.shape[-1] != self.reso:
raw_img = torch.nn.functional.interpolate(
input=raw_img,
size=(self.reso, self.reso),
mode='bilinear',
align_corners=False,
)
img = raw_img * 2 - 1 # as gt
        # ! load latent; _load_latent() already repeats it to B=2,
        # matching the two sub-batches built below
        latent = self._load_latent(ins)
# ! shuffle
indices = np.random.permutation(self.chunk_size)
img = img[indices]
c = c[indices]
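        # Each chunk packs two view-sets of the same object; split into B=2
        # sub-batches below and keep the first n_cond_frames views of each.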
img = self.perspective_transformer(img) # create 3D inconsistency
# ! split along V and repeat other stuffs accordingly
img = rearrange(img, '(B V) ... -> B V ...', B=2)[:, :self.n_cond_frames]
c = rearrange(c, '(B V) ... -> B V ...', B=2)[:, :self.n_cond_frames] # 2 6 25
# rand perspective aug
caption = [caption, caption]
ins = [ins, ins]
# load plucker coord
# st()
# plucker_c = self.get_plucker_ray(rearrange(c[:, 1:1+self.n_cond_frames], "b t ... -> (b t) ..."))
# plucker_c = rearrange(c, '(B V) ... -> B V ...', B=2) # 2 6 25
# use view-space camera tradition
c[0] = self.normalize_camera(c[0], c[0,0:1])
c[1] = self.normalize_camera(c[1], c[1,0:1])
# https://github.com/TencentARC/InstantMesh/blob/7fe95627cf819748f7830b2b278f302a9d798d17/src/model.py#L70
# c = np.concatenate([c[..., :12], c[..., 16:17], c[..., 20:21], c[..., 18:19], c[..., 21:22]], axis=-1)
# c = c + np.random.randn(*c.shape) * 0.04 - 0.02
# ! to dict
# sample = self.post_process.create_dict_nobatch(sample)
ret_dict = {
'caption': caption,
'ins': ins,
'c': c,
'img': img, # fix inp img range to [-1,1]
'latent': latent,
# **latent
}
# st()
return ret_dict
class ChunkObjaverseDatasetDDPMgs(ChunkObjaverseDatasetDDPM):
def __init__(
self,
file_path,
reso,
reso_encoder,
preprocess=None,
classes=False,
load_depth=False,
test=False,
scene_scale=1,
overfitting=False,
imgnet_normalize=True,
dataset_size=-1,
overfitting_bs=-1,
interval=1,
plucker_embedding=False,
shuffle_across_cls=False,
wds_split=1, # 4 splits to accelerate preprocessing
four_view_for_latent=False,
single_view_for_i23d=False,
load_extra_36_view=False,
gs_cam_format=False,
frame_0_as_canonical=True,
split_chunk_size=10,
mv_input=True,
append_depth=False,
append_xyz=False,
pcd_path=None,
load_pcd=False,
read_normal=False,
mv_latent_dir='',
load_raw=False,
# shards_folder_num=4,
# eval=False,
**kwargs):
super().__init__(
file_path,
reso,
reso_encoder,
preprocess=preprocess,
classes=classes,
load_depth=load_depth,
test=test,
scene_scale=scene_scale,
overfitting=overfitting,
imgnet_normalize=imgnet_normalize,
dataset_size=dataset_size,
overfitting_bs=overfitting_bs,
interval=interval,
plucker_embedding=plucker_embedding,
shuffle_across_cls=shuffle_across_cls,
wds_split=wds_split, # 4 splits to accelerate preprocessing
four_view_for_latent=four_view_for_latent,
single_view_for_i23d=single_view_for_i23d,
load_extra_36_view=load_extra_36_view,
gs_cam_format=gs_cam_format,
frame_0_as_canonical=frame_0_as_canonical,
split_chunk_size=split_chunk_size,
mv_input=mv_input,
append_depth=append_depth,
append_xyz=append_xyz,
pcd_path=pcd_path,
load_pcd=load_pcd,
read_normal=read_normal,
mv_latent_dir=mv_latent_dir,
load_raw=load_raw,
# shards_folder_num=4,
# eval=False,
**kwargs)
self.avoid_loading_first = False
# self.feat_scale_factor = torch.Tensor([0.99227685, 1.014337 , 0.20842505, 0.98727155, 0.3305389 ,
# 0.38729668, 1.0155401 , 0.9728264 , 1.0009694 , 0.97328585,
# 0.2881106 , 0.1652732 , 0.3482468 , 0.9971449 , 0.99895126,
# 0.18491288]).float().reshape(1,1,-1)
# stat for normalization
# self.xyz_mean = torch.Tensor([-0.00053714, 0.08095618, -0.01914407] ).reshape(1, 3).float()
        # self.xyz_std = np.array([0.14593576, 0.15753542, 0.18873914] ).reshape(1,3).astype(np.float32)
self.xyz_std = 0.164 # a global scaler
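        # a single isotropic std (close to the per-axis stats commented above);
        # dividing by it whitens the fps xyz coordinates to roughly unit std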
self.kl_mean = np.array([ 0.0184, 0.0024, 0.0926, 0.0517, 0.1781, 0.7137, -0.0355, 0.0267,
0.0183, 0.0164, -0.5090, 0.2406, 0.2733, -0.0256, -0.0285, 0.0761]).reshape(1,16).astype(np.float32)
self.kl_std = np.array([1.0018, 1.0309, 1.3001, 1.0160, 0.8182, 0.8023, 1.0591, 0.9789, 0.9966,
0.9448, 0.8908, 1.4595, 0.7957, 0.9871, 1.0236, 1.2923]).reshape(1,16).astype(np.float32)
def normalize_pcd_act(self, x):
return x / self.xyz_std
def normalize_kl_feat(self, latent):
# return latent / self.feat_scale_factor
return (latent-self.kl_mean) / self.kl_std
def _load_latent(self, ins, rand_pick_one=False, pick_both=False):
if 'adv' in self.mv_latent_dir: # new latent codes saved have 3 augmentations
idx = random.choice([0,1,2])
# idx = random.choice([0])
latent = np.load(os.path.join(self.mv_latent_dir, ins, f'latent-{idx}.npz')) # pre-calculated VAE latent
else:
latent = np.load(os.path.join(self.mv_latent_dir, ins, 'latent.npz')) # pre-calculated VAE latent
latent, fps_xyz = latent['latent_normalized'], latent['query_pcd_xyz'] # 2,768,16; 2,768,3
if not pick_both:
if rand_pick_one:
rand_idx = random.randint(0,1)
else:
rand_idx = 0
latent, fps_xyz = latent[rand_idx:rand_idx+1], fps_xyz[rand_idx:rand_idx+1]
# per-channel normalize to std=1 & concat
# latent_pcd = np.concatenate([self.normalize_kl_feat(latent), self.normalize_pcd_act(fps_xyz)], -1)
# latent_pcd = np.concatenate([latent, self.normalize_pcd_act(fps_xyz)], -1)
# return latent_pcd, fps_xyz
return latent, fps_xyz
def __getitem__(self, index) -> Any:
raw_img, c, caption, ins = self.read_chunk(
os.path.join(self.file_path, self.chunk_list[index]))
# sample = self.post_process.paired_post_process_chunk(sample)
# ! random zoom in (scale augmentation)
# for i in range(img.shape[0]):
# for v in range(img.shape[1]):
# if random.random() > 0.8:
# rand_bg_scale = random.randint(60,99) / 100
# st()
# img[i,v] = recenter(img[i,v], np.ones_like(img[i,v]), border_ratio=rand_bg_scale)
# ! process
raw_img = torch.from_numpy(raw_img).permute(0, 3, 1, 2) / 255.0 # [0,1]
if raw_img.shape[-1] != self.reso:
raw_img = torch.nn.functional.interpolate(
input=raw_img,
size=(self.reso, self.reso),
mode='bilinear',
align_corners=False,
)
img = raw_img * 2 - 1 # as gt
# ! load latent
# latent, _ = self._load_latent(ins)
latent, fps_xyz = self._load_latent(ins, pick_both=True) # analyzing xyz/latent disentangled diffusion
# latent, fps_xyz = latent[0], fps_xyz[0] # remove batch dim here
# fps_xyz = fps_xyz / self.scaling_factor # for xyz training
normalized_fps_xyz = self.normalize_pcd_act(fps_xyz)
        if self.avoid_loading_first:  # for training the mv model: drop view 0 of each half-chunk
            index = list(range(1, 6)) + list(range(7, 12))  # assumes chunk_size == 12
img = img[index]
c = c[index]
# ! shuffle
indices = np.random.permutation(img.shape[0])
img = img[indices]
c = c[indices]
img = self.perspective_transformer(img) # create 3D inconsistency
# ! split along V and repeat other stuffs accordingly
img = rearrange(img, '(B V) ... -> B V ...', B=2)[:, :self.n_cond_frames]
c = rearrange(c, '(B V) ... -> B V ...', B=2)[:, :self.n_cond_frames] # 2 6 25
# rand perspective aug
caption = [caption, caption]
ins = [ins, ins]
# load plucker coord
# st()
# plucker_c = self.get_plucker_ray(rearrange(c[:, 1:1+self.n_cond_frames], "b t ... -> (b t) ..."))
# plucker_c = rearrange(c, '(B V) ... -> B V ...', B=2) # 2 6 25
# use view-space camera tradition
c[0] = self.normalize_camera(c[0], c[0,0:1])
c[1] = self.normalize_camera(c[1], c[1,0:1])
# ! to dict
# sample = self.post_process.create_dict_nobatch(sample)
ret_dict = {
'caption': caption,
'ins': ins,
'c': c,
'img': img, # fix inp img range to [-1,1]
'latent': latent,
'normalized-fps-xyz': normalized_fps_xyz
# **latent
}
# st()
return ret_dict
class ChunkObjaverseDatasetDDPMgsT23D(ChunkObjaverseDatasetDDPMgs):
def __init__(
self,
file_path,
reso,
reso_encoder,
preprocess=None,
classes=False,
load_depth=False,
test=False,
scene_scale=1,
overfitting=False,
imgnet_normalize=True,
dataset_size=-1,
overfitting_bs=-1,
interval=1,
plucker_embedding=False,
shuffle_across_cls=False,
wds_split=1, # 4 splits to accelerate preprocessing
four_view_for_latent=False,
single_view_for_i23d=False,
load_extra_36_view=False,
gs_cam_format=False,
frame_0_as_canonical=True,
split_chunk_size=10,
mv_input=True,
append_depth=False,
append_xyz=False,
pcd_path=None,
load_pcd=False,
read_normal=False,
mv_latent_dir='',
# shards_folder_num=4,
# eval=False,
**kwargs):
super().__init__(
file_path,
reso,
reso_encoder,
preprocess=preprocess,
classes=classes,
load_depth=load_depth,
test=test,
scene_scale=scene_scale,
overfitting=overfitting,
imgnet_normalize=imgnet_normalize,
dataset_size=dataset_size,
overfitting_bs=overfitting_bs,
interval=interval,
plucker_embedding=plucker_embedding,
shuffle_across_cls=shuffle_across_cls,
wds_split=wds_split, # 4 splits to accelerate preprocessing
four_view_for_latent=four_view_for_latent,
single_view_for_i23d=single_view_for_i23d,
load_extra_36_view=load_extra_36_view,
gs_cam_format=gs_cam_format,
frame_0_as_canonical=frame_0_as_canonical,
split_chunk_size=split_chunk_size,
mv_input=mv_input,
append_depth=append_depth,
append_xyz=append_xyz,
pcd_path=pcd_path,
load_pcd=load_pcd,
read_normal=read_normal,
mv_latent_dir=mv_latent_dir,
load_raw=True,
# shards_folder_num=4,
# eval=False,
**kwargs)
# def __len__(self):
# return 40
def __len__(self):
return len(self.rgb_list)
def __getitem__(self, index) -> Any:
rgb_path = self.rgb_list[index]
ins = str(Path(rgb_path).relative_to(self.file_path).parent.parent.parent)
# load caption
caption = self.caption_data['/'.join(ins.split('/')[1:])]
# chunk_path = os.path.join(self.file_path, self.chunk_list[index])
# # load caption
# with open(os.path.join(chunk_path, 'caption.txt'),
# 'r',
# encoding="utf-8") as f:
# caption = f.read()
# # load latent
# with open(os.path.join(chunk_path, 'ins.txt'), 'r',
# encoding="utf-8") as f:
# ins = f.read()
latent, fps_xyz = self._load_latent(ins, True) # analyzing xyz/latent disentangled diffusion
latent, fps_xyz = latent[0], fps_xyz[0] # remove batch dim here
# fps_xyz = fps_xyz / self.scaling_factor # for xyz training
normalized_fps_xyz = self.normalize_pcd_act(fps_xyz)
# ! to dict
ret_dict = {
# 'caption': caption,
'latent': latent,
# 'img': img,
'fps-xyz': fps_xyz,
'normalized-fps-xyz': normalized_fps_xyz,
'caption': caption
}
return ret_dict
class ChunkObjaverseDatasetDDPMgsI23D(ChunkObjaverseDatasetDDPMgs):
def __init__(
self,
file_path,
reso,
reso_encoder,
preprocess=None,
classes=False,
load_depth=False,
test=False,
scene_scale=1,
overfitting=False,
imgnet_normalize=True,
dataset_size=-1,
overfitting_bs=-1,
interval=1,
plucker_embedding=False,
shuffle_across_cls=False,
wds_split=1, # 4 splits to accelerate preprocessing
four_view_for_latent=False,
single_view_for_i23d=False,
load_extra_36_view=False,
gs_cam_format=False,
frame_0_as_canonical=True,
split_chunk_size=10,
mv_input=True,
append_depth=False,
append_xyz=False,
pcd_path=None,
load_pcd=False,
read_normal=False,
mv_latent_dir='',
# shards_folder_num=4,
# eval=False,
**kwargs):
super().__init__(
file_path,
reso,
reso_encoder,
preprocess=preprocess,
classes=classes,
load_depth=load_depth,
test=test,
scene_scale=scene_scale,
overfitting=overfitting,
imgnet_normalize=imgnet_normalize,
dataset_size=dataset_size,
overfitting_bs=overfitting_bs,
interval=interval,
plucker_embedding=plucker_embedding,
shuffle_across_cls=shuffle_across_cls,
wds_split=wds_split, # 4 splits to accelerate preprocessing
four_view_for_latent=four_view_for_latent,
single_view_for_i23d=single_view_for_i23d,
load_extra_36_view=load_extra_36_view,
gs_cam_format=gs_cam_format,
frame_0_as_canonical=frame_0_as_canonical,
split_chunk_size=split_chunk_size,
mv_input=mv_input,
append_depth=append_depth,
append_xyz=append_xyz,
pcd_path=pcd_path,
load_pcd=load_pcd,
read_normal=read_normal,
mv_latent_dir=mv_latent_dir,
load_raw=True,
# shards_folder_num=4,
# eval=False,
**kwargs)
assert self.load_raw
self.scaling_factor = np.array([0.14593576, 0.15753542, 0.18873914])
def __len__(self):
return len(self.rgb_list)
# def __len__(self):
# return 40
def __getitem__(self, index) -> Any:
rgb_path = self.rgb_list[index]
ins = str(Path(rgb_path).relative_to(self.file_path).parent.parent.parent)
raw_img = imageio.imread(rgb_path).astype(np.float32)
alpha_mask = raw_img[..., -1:] / 255
raw_img = alpha_mask * raw_img[..., :3] + (
1 - alpha_mask) * np.ones_like(raw_img[..., :3]) * 255
        raw_img = cv2.resize(raw_img, (self.reso, self.reso), interpolation=cv2.INTER_CUBIC)
        raw_img = torch.from_numpy(raw_img).permute(2,0,1).clip(0,255)  # [0,255]
        img = raw_img / 127.5 - 1  # [-1,1]
# with open(os.path.join(chunk_path, 'caption.txt'),
# 'r',
# encoding="utf-8") as f:
# caption = f.read()
# latent = self._load_latent(ins, True)[0]
latent, fps_xyz = self._load_latent(ins, True) # analyzing xyz/latent disentangled diffusion
latent, fps_xyz = latent[0], fps_xyz[0]
# fps_xyz = fps_xyz / self.scaling_factor # for xyz training
normalized_fps_xyz = self.normalize_pcd_act(fps_xyz)
# load caption
caption = self.caption_data['/'.join(ins.split('/')[1:])]
# ! to dict
ret_dict = {
# 'caption': caption,
'latent': latent,
            'img': img.numpy(),  # return numpy; Tensors in workers can trigger 'too many open files'
'fps-xyz': fps_xyz,
'normalized-fps-xyz': normalized_fps_xyz,
'caption': caption
}
return ret_dict
class ChunkObjaverseDatasetDDPMgsMV23D(ChunkObjaverseDatasetDDPMgs):
def __init__(
self,
file_path,
reso,
reso_encoder,
preprocess=None,
classes=False,
load_depth=False,
test=False,
scene_scale=1,
overfitting=False,
imgnet_normalize=True,
dataset_size=-1,
overfitting_bs=-1,
interval=1,
plucker_embedding=False,
shuffle_across_cls=False,
wds_split=1, # 4 splits to accelerate preprocessing
four_view_for_latent=False,
single_view_for_i23d=False,
load_extra_36_view=False,
gs_cam_format=False,
frame_0_as_canonical=True,
split_chunk_size=10,
mv_input=True,
append_depth=False,
append_xyz=False,
pcd_path=None,
load_pcd=False,
read_normal=False,
mv_latent_dir='',
# shards_folder_num=4,
# eval=False,
**kwargs):
super().__init__(
file_path,
reso,
reso_encoder,
preprocess=preprocess,
classes=classes,
load_depth=load_depth,
test=test,
scene_scale=scene_scale,
overfitting=overfitting,
imgnet_normalize=imgnet_normalize,
dataset_size=dataset_size,
overfitting_bs=overfitting_bs,
interval=interval,
plucker_embedding=plucker_embedding,
shuffle_across_cls=shuffle_across_cls,
wds_split=wds_split, # 4 splits to accelerate preprocessing
four_view_for_latent=four_view_for_latent,
single_view_for_i23d=single_view_for_i23d,
load_extra_36_view=load_extra_36_view,
gs_cam_format=gs_cam_format,
frame_0_as_canonical=frame_0_as_canonical,
split_chunk_size=split_chunk_size,
mv_input=mv_input,
append_depth=append_depth,
append_xyz=append_xyz,
pcd_path=pcd_path,
load_pcd=load_pcd,
read_normal=read_normal,
mv_latent_dir=mv_latent_dir,
load_raw=False,
# shards_folder_num=4,
# eval=False,
**kwargs)
assert not self.load_raw
# self.scaling_factor = np.array([0.14593576, 0.15753542, 0.18873914])
        self.n_cond_frames = 4  # an easy version for now.
self.avoid_loading_first = True
def __getitem__(self, index) -> Any:
raw_img, c, caption, ins = self.read_chunk(
os.path.join(self.file_path, self.chunk_list[index]))
# ! process
raw_img = torch.from_numpy(raw_img).permute(0, 3, 1, 2) / 255.0 # [0,1]
if raw_img.shape[-1] != self.reso:
raw_img = torch.nn.functional.interpolate(
input=raw_img,
size=(self.reso, self.reso),
mode='bilinear',
align_corners=False,
)
img = raw_img * 2 - 1 # as gt
# ! load latent
# latent, _ = self._load_latent(ins)
latent, fps_xyz = self._load_latent(ins, pick_both=True) # analyzing xyz/latent disentangled diffusion
# latent, fps_xyz = latent[0], fps_xyz[0] # remove batch dim here
# fps_xyz = fps_xyz / self.scaling_factor # for xyz training
normalized_fps_xyz = self.normalize_pcd_act(fps_xyz)
if self.avoid_loading_first: # for training mv model
index = list(range(1,self.chunk_size//2)) + list(range(self.chunk_size//2+1, self.chunk_size))
img = img[index]
c = c[index]
# ! shuffle
indices = np.random.permutation(img.shape[0])
img = img[indices]
c = c[indices]
aug_img = self.perspective_transformer(img) # create 3D inconsistency
# ! split along V and repeat other stuffs accordingly
img = rearrange(img, '(B V) ... -> B V ...', B=2)[:, 0:1] # only return first view (randomly sampled)
aug_img = rearrange(aug_img, '(B V) ... -> B V ...', B=2)[:, 1:self.n_cond_frames+1]
c = rearrange(c, '(B V) ... -> B V ...', B=2)[:, 1:self.n_cond_frames+1] # 2 6 25
# use view-space camera tradition
c[0] = self.normalize_camera(c[0], c[0,0:1])
c[1] = self.normalize_camera(c[1], c[1,0:1])
caption = [caption, caption]
ins = [ins, ins]
# ! to dict
# sample = self.post_process.create_dict_nobatch(sample)
ret_dict = {
'caption': caption,
'ins': ins,
'c': c,
'img': img, # fix inp img range to [-1,1]
'mv_img': aug_img,
'latent': latent,
'normalized-fps-xyz': normalized_fps_xyz
# **latent
}
# st()
return ret_dict
class ChunkObjaverseDatasetDDPMgsMV23DSynthetic(ChunkObjaverseDatasetDDPMgs):
def __init__(
self,
file_path,
reso,
reso_encoder,
preprocess=None,
classes=False,
load_depth=False,
test=False,
scene_scale=1,
overfitting=False,
imgnet_normalize=True,
dataset_size=-1,
overfitting_bs=-1,
interval=1,
plucker_embedding=False,
shuffle_across_cls=False,
wds_split=1, # 4 splits to accelerate preprocessing
four_view_for_latent=False,
single_view_for_i23d=False,
load_extra_36_view=False,
gs_cam_format=False,
frame_0_as_canonical=True,
split_chunk_size=10,
mv_input=True,
append_depth=False,
append_xyz=False,
pcd_path=None,
load_pcd=False,
read_normal=False,
mv_latent_dir='',
# shards_folder_num=4,
# eval=False,
**kwargs):
super().__init__(
file_path,
reso,
reso_encoder,
preprocess=preprocess,
classes=classes,
load_depth=load_depth,
test=test,
scene_scale=scene_scale,
overfitting=overfitting,
imgnet_normalize=imgnet_normalize,
dataset_size=dataset_size,
overfitting_bs=overfitting_bs,
interval=interval,
plucker_embedding=plucker_embedding,
shuffle_across_cls=shuffle_across_cls,
wds_split=wds_split, # 4 splits to accelerate preprocessing
four_view_for_latent=four_view_for_latent,
single_view_for_i23d=single_view_for_i23d,
load_extra_36_view=load_extra_36_view,
gs_cam_format=gs_cam_format,
frame_0_as_canonical=frame_0_as_canonical,
split_chunk_size=split_chunk_size,
mv_input=mv_input,
append_depth=append_depth,
append_xyz=append_xyz,
pcd_path=pcd_path,
load_pcd=load_pcd,
read_normal=read_normal,
mv_latent_dir=mv_latent_dir,
load_raw=True,
load_instance_only=True,
# shards_folder_num=4,
# eval=False,
**kwargs)
# assert not self.load_raw
# self.scaling_factor = np.array([0.14593576, 0.15753542, 0.18873914])
        self.n_cond_frames = 6  # an easy version for now.
self.avoid_loading_first = True
self.indices = np.array([0,1,2,3,4,5])
self.img_root_dir = '/cpfs01/user/lanyushi.p/data/unzip4_img'
azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float)
elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float)
zero123pp_pose, _ = generate_input_camera(1.8, [[elevations[i], azimuths[i]] for i in range(6)], fov=30)
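        # Zero123++-style six-view ring: azimuths 30..330 deg in 60-deg steps with
        # alternating +20/-10 deg elevation, camera radius 1.8, fov 30 deg.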
K = torch.Tensor([1.3889, 0.0000, 0.5000, 0.0000, 1.3889, 0.5000, 0.0000, 0.0000, 0.0039]).to(zero123pp_pose) # keeps the same
zero123pp_pose = torch.cat([zero123pp_pose.reshape(6,-1), K.unsqueeze(0).repeat(6,1)], dim=-1)
eval_camera = zero123pp_pose[self.indices].float().cpu().numpy() # for normalization
self.eval_camera = self.normalize_camera(eval_camera, eval_camera[0:1]) # the first img is not used.
# self.load_synthetic_only = False
self.load_synthetic_only = True
def __len__(self):
return len(self.rgb_list)
def _getitem_synthetic(self, index) -> Any:
rgb_fname = Path(self.rgb_list[index])
# ins = self.mvi_objv_mapping(rgb_fname.parent.parent.stem)
# ins = str(Path(rgb_fname).parent.parent.stem)
ins = str((Path(rgb_fname).relative_to(self.file_path)).parent.parent)
mv_img = imageio.imread(rgb_fname)
# st()
        mv_img = rearrange(mv_img, '(n h) (m w) c -> (n m) h w c', n=3, m=2)[self.indices] # (6, 320, 320, 3)
mv_img = np.stack([recenter(img, np.ones_like(img), border_ratio=0.1) for img in mv_img], axis=0)
        mv_img = rearrange(mv_img, 'b h w c -> b c h w') # to channel-first (torch) convention
mv_img = torch.from_numpy(mv_img) / 127.5 - 1
# ! load single-view image here
img_idx = self.mvi_objv_mapping[rgb_fname.stem]
img_path = os.path.join(self.img_root_dir, rgb_fname.parent.relative_to(self.file_path), img_idx, f'{img_idx}.png')
raw_img = imageio.imread(img_path).astype(np.float32)
alpha_mask = raw_img[..., -1:] / 255
raw_img = alpha_mask * raw_img[..., :3] + (
1 - alpha_mask) * np.ones_like(raw_img[..., :3]) * 255
raw_img = cv2.resize(raw_img, (self.reso, self.reso), interpolation=cv2.INTER_CUBIC)
        raw_img = torch.from_numpy(raw_img).permute(2,0,1).clip(0,255) # [0, 255]
img = raw_img / 127.5 - 1
latent, fps_xyz = self._load_latent(ins, pick_both=False) # analyzing xyz/latent disentangled diffusion
latent, fps_xyz = latent[0], fps_xyz[0]
normalized_fps_xyz = self.normalize_pcd_act(fps_xyz) # for stage-1
        # use view-space camera convention
# ins = [ins, ins]
# st()
caption = self.caption_data['/'.join(ins.split('/')[1:])]
# ! to dict
# sample = self.post_process.create_dict_nobatch(sample)
ret_dict = {
'caption': caption,
# 'ins': ins,
'c': self.eval_camera,
'img': img, # fix inp img range to [-1,1]
'mv_img': mv_img,
'latent': latent,
'normalized-fps-xyz': normalized_fps_xyz,
'fps-xyz': fps_xyz,
}
return ret_dict
def _getitem_gt(self, index) -> Any:
raw_img, c, caption, ins = self.read_chunk(
os.path.join(self.gt_mv_file_path, self.gt_chunk_list[index]))
# ! process
raw_img = torch.from_numpy(raw_img).permute(0, 3, 1, 2) / 255.0 # [0,1]
if raw_img.shape[-1] != self.reso:
raw_img = torch.nn.functional.interpolate(
input=raw_img,
size=(self.reso, self.reso),
mode='bilinear',
align_corners=False,
)
img = raw_img * 2 - 1 # as gt
# ! load latent
# latent, _ = self._load_latent(ins)
latent, fps_xyz = self._load_latent(ins, pick_both=True) # analyzing xyz/latent disentangled diffusion
# latent, fps_xyz = latent[0], fps_xyz[0] # remove batch dim here
# fps_xyz = fps_xyz / self.scaling_factor # for xyz training
normalized_fps_xyz = self.normalize_pcd_act(fps_xyz)
if self.avoid_loading_first: # for training mv model
            index = list(range(1,self.chunk_size//2)) + list(range(self.chunk_size//2+1, self.chunk_size)) # skip the first view of each half-chunk
img = img[index]
c = c[index]
# ! shuffle
indices = np.random.permutation(img.shape[0])
img = img[indices]
c = c[indices]
# st()
aug_img = self.mv_resize_cls(img)
aug_img = self.perspective_transformer(aug_img) # create 3D inconsistency
# ! split along V and repeat other stuffs accordingly
img = rearrange(img, '(B V) ... -> B V ...', B=2)[:, 0:1] # only return first view (randomly sampled)
aug_img = rearrange(aug_img, '(B V) ... -> B V ...', B=2)[:, 1:self.n_cond_frames+1]
c = rearrange(c, '(B V) ... -> B V ...', B=2)[:, 1:self.n_cond_frames+1] # 2 6 25
        # use view-space camera convention
c[0] = self.normalize_camera(c[0], c[0,0:1])
c[1] = self.normalize_camera(c[1], c[1,0:1])
caption = [caption, caption]
ins = [ins, ins]
# ! to dict
# sample = self.post_process.create_dict_nobatch(sample)
ret_dict = {
'caption': caption,
'ins': ins,
'c': c,
'img': img, # fix inp img range to [-1,1]
'mv_img': aug_img,
'latent': latent,
'normalized-fps-xyz': normalized_fps_xyz,
'fps-xyz': fps_xyz,
}
return ret_dict
def __getitem__(self, index) -> Any:
# load synthetic version
try:
synthetic_mv = self._getitem_synthetic(index)
        except Exception as e:
            # fall back to a random earlier sample if this render is missing/corrupt
            # logger.log(Path(self.rgb_list[index]), 'missing')
            synthetic_mv = self._getitem_synthetic(random.randint(0, len(self.rgb_list)//2))
if self.load_synthetic_only:
return synthetic_mv
else:
# load gt mv chunk
gt_chunk_index = random.randint(0, len(self.gt_chunk_list)-1)
gt_mv = self._getitem_gt(gt_chunk_index)
# merge them together along batch dim
merged_mv = {}
for k, v in synthetic_mv.items(): # merge, synthetic - gt order
if k not in ['caption', 'ins']:
if k == 'img':
merged_mv[k] = np.concatenate([v[None], gt_mv[k][:, 0]], axis=0).astype(np.float32)
else:
merged_mv[k] = np.concatenate([v[None], gt_mv[k]], axis=0).astype(np.float32)
else:
merged_mv[k] = [v] + gt_mv[k] # list
return merged_mv
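

# Minimal usage sketch for the synthetic MV dataset above (assumption: the
# paths below are placeholders, not the actual training config). With
# load_synthetic_only=True each item is one zero123++ 3x2 grid render plus
# its single-view condition image and cached latent:
def _demo_synthetic_mv_dataset():
    dataset = ChunkObjaverseDatasetDDPMgsMV23DSynthetic(
        './zero123pp_renders', reso=320, reso_encoder=320,
        mv_latent_dir='./latents')
    sample = dataset[0]
    return sample['mv_img'].shape  # (6, 3, 320, 320), range [-1, 1]
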
class ChunkObjaverseDatasetDDPMgsI23D_loadMV(ChunkObjaverseDatasetDDPMgs):
def __init__(
self,
file_path,
reso,
reso_encoder,
preprocess=None,
classes=False,
load_depth=False,
test=False,
scene_scale=1,
overfitting=False,
imgnet_normalize=True,
dataset_size=-1,
overfitting_bs=-1,
interval=1,
plucker_embedding=False,
shuffle_across_cls=False,
wds_split=1, # 4 splits to accelerate preprocessing
four_view_for_latent=False,
single_view_for_i23d=False,
load_extra_36_view=False,
gs_cam_format=False,
frame_0_as_canonical=True,
split_chunk_size=10,
mv_input=True,
append_depth=False,
append_xyz=False,
pcd_path=None,
load_pcd=False,
read_normal=False,
mv_latent_dir='',
canonicalize_pcd=False,
# shards_folder_num=4,
# eval=False,
**kwargs):
super().__init__(
file_path,
reso,
reso_encoder,
preprocess=preprocess,
classes=classes,
load_depth=load_depth,
test=test,
scene_scale=scene_scale,
overfitting=overfitting,
imgnet_normalize=imgnet_normalize,
dataset_size=dataset_size,
overfitting_bs=overfitting_bs,
interval=interval,
plucker_embedding=plucker_embedding,
shuffle_across_cls=shuffle_across_cls,
wds_split=wds_split, # 4 splits to accelerate preprocessing
four_view_for_latent=four_view_for_latent,
single_view_for_i23d=single_view_for_i23d,
load_extra_36_view=load_extra_36_view,
gs_cam_format=gs_cam_format,
frame_0_as_canonical=frame_0_as_canonical,
split_chunk_size=split_chunk_size,
mv_input=mv_input,
append_depth=append_depth,
append_xyz=append_xyz,
pcd_path=pcd_path,
load_pcd=load_pcd,
read_normal=read_normal,
mv_latent_dir=mv_latent_dir,
load_raw=False,
# shards_folder_num=4,
# eval=False,
**kwargs)
assert not self.load_raw
# self.scaling_factor = np.array([0.14593576, 0.15753542, 0.18873914])
        self.n_cond_frames = 5 # an easy version for now.
self.avoid_loading_first = True
# self.canonicalize_pcd = canonicalize_pcd
# self.canonicalize_pcd = True
self.canonicalize_pcd = False
def canonicalize_xyz(self, c, pcd):
B = c.shape[0]
camera_poses_rot = c[:, :16].reshape(B, 4, 4)[:, :3, :3]
R_inv = np.transpose(camera_poses_rot, (0,2,1)) # w2c rotation
new_pcd = (R_inv @ np.transpose(pcd, (0,2,1))) # B 3 3 @ B 3 N
new_pcd = np.transpose(new_pcd, (0,2,1))
return new_pcd
def __getitem__(self, index) -> Any:
raw_img, c, caption, ins = self.read_chunk(
os.path.join(self.file_path, self.chunk_list[index]))
# ! process
raw_img = torch.from_numpy(raw_img).permute(0, 3, 1, 2) / 255.0 # [0,1]
if raw_img.shape[-1] != self.reso:
raw_img = torch.nn.functional.interpolate(
input=raw_img,
size=(self.reso, self.reso),
mode='bilinear',
align_corners=False,
)
img = raw_img * 2 - 1 # as gt
# ! load latent
# latent, _ = self._load_latent(ins)
if self.avoid_loading_first: # for training mv model
            index = list(range(1,self.chunk_size//2)) + list(range(self.chunk_size//2+1, self.chunk_size)) # skip the first view of each half-chunk
img = img[index]
c = c[index]
# ! shuffle
indices = np.random.permutation(img.shape[0])[:self.n_cond_frames*2]
img = img[indices]
c = c[indices]
latent, fps_xyz = self._load_latent(ins, pick_both=True) # analyzing xyz/latent disentangled diffusion
# latent, fps_xyz = latent[0], fps_xyz[0] # remove batch dim here
fps_xyz = np.repeat(fps_xyz, self.n_cond_frames, 0)
latent = np.repeat(latent, self.n_cond_frames, 0)
normalized_fps_xyz = self.normalize_pcd_act(fps_xyz)
if self.canonicalize_pcd:
normalized_fps_xyz = self.canonicalize_xyz(c, normalized_fps_xyz)
# repeat
caption = [caption] * self.n_cond_frames * 2
ins = [ins] * self.n_cond_frames * 2
ret_dict = {
'caption': caption,
'ins': ins,
'c': c,
'img': img, # fix inp img range to [-1,1]
'latent': latent,
'normalized-fps-xyz': normalized_fps_xyz,
'fps-xyz': fps_xyz,
# **latent
}
return ret_dict
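

# Worked sketch of canonicalize_xyz() above (assumption: toy inputs).
# Multiplying by the transpose (inverse) of each camera's rotation expresses
# the world-space points in that view's camera frame:
def _demo_canonicalize_xyz():
    c = np.concatenate([np.tile(np.eye(4).reshape(16), (2, 1)),
                        np.zeros((2, 9))], axis=-1)  # B=2 poses in the 25-dim format
    pcd = np.random.randn(2, 4096, 3)                # B N 3 point clouds
    R = c[:, :16].reshape(2, 4, 4)[:, :3, :3]
    R_inv = np.transpose(R, (0, 2, 1))               # w2c rotation
    return np.transpose(R_inv @ np.transpose(pcd, (0, 2, 1)), (0, 2, 1))  # B N 3
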
class RealDataset(Dataset):
def __init__(
self,
file_path,
reso,
reso_encoder,
preprocess=None,
classes=False,
load_depth=False,
test=False,
scene_scale=1,
overfitting=False,
imgnet_normalize=True,
dataset_size=-1,
overfitting_bs=-1,
interval=1,
plucker_embedding=False,
shuffle_across_cls=False,
wds_split=1, # 4 splits to accelerate preprocessing
) -> None:
super().__init__()
self.file_path = file_path
self.overfitting = overfitting
self.scene_scale = scene_scale
self.reso = reso
self.reso_encoder = reso_encoder
self.classes = False
self.load_depth = load_depth
self.preprocess = preprocess
self.plucker_embedding = plucker_embedding
self.rgb_list = []
        all_fname = [
            t for t in os.listdir(self.file_path)
            if t.split('.')[-1] in ['png', 'jpg']
        ]
        all_fname = [name for name in all_fname if '-input' in name]
self.rgb_list += ([
os.path.join(self.file_path, fname) for fname in all_fname
])
# st()
# if len(self.rgb_list) == 1:
# # placeholder
# self.rgb_list = self.rgb_list * 40
        # ! setup normalization
transformations = [
transforms.ToTensor(), # [0,1] range
]
assert imgnet_normalize
if imgnet_normalize:
transformations.append(
transforms.Normalize((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225)) # type: ignore
)
else:
transformations.append(
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))) # type: ignore
self.normalize = transforms.Compose(transformations)
# camera = torch.load('eval_pose.pt', map_location='cpu')
# self.eval_camera = camera
# pre-cache
# self.calc_rays_plucker()
def __len__(self):
return len(self.rgb_list)
def __getitem__(self, index) -> Any:
# return super().__getitem__(index)
rgb_fname = self.rgb_list[index]
# ! preprocess, normalize
raw_img = imageio.imread(rgb_fname)
# interpolation=cv2.INTER_AREA)
if raw_img.shape[-1] == 4:
alpha_mask = raw_img[..., 3:4] / 255.0
bg_white = np.ones_like(alpha_mask) * 255.0
raw_img = raw_img[..., :3] * alpha_mask + (
1 - alpha_mask) * bg_white #[3, reso_encoder, reso_encoder]
raw_img = raw_img.astype(np.uint8)
# raw_img = recenter(raw_img, np.ones_like(raw_img), border_ratio=0.2)
# log gt
img = cv2.resize(raw_img, (self.reso, self.reso),
interpolation=cv2.INTER_LANCZOS4)
img = torch.from_numpy(img)[..., :3].permute(
2, 0, 1
) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range
ret_dict = {
# 'rgb_fname': rgb_fname,
# 'img_to_encoder':
# img_to_encoder.unsqueeze(0).repeat_interleave(40, 0),
'img': img,
# 'c': self.eval_camera, # TODO, get pre-calculated samples
# 'ins': 'placeholder',
# 'bbox': 'placeholder',
# 'caption': 'placeholder',
}
        # ! repeat as an instance
return ret_dict
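

# Minimal usage sketch (assumption: './demo_images' is a placeholder folder
# of '*-input.png' files). RealDataset only yields [-1, 1] RGB tensors for
# in-the-wild single-image evaluation:
def _demo_real_dataset():
    dataset = RealDataset('./demo_images', reso=256, reso_encoder=224)
    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    return next(iter(loader))['img'].shape  # [1, 3, 256, 256]
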
class RealDataset_GSO(Dataset):
def __init__(
self,
file_path,
reso,
reso_encoder,
preprocess=None,
classes=False,
load_depth=False,
test=False,
scene_scale=1,
overfitting=False,
imgnet_normalize=True,
dataset_size=-1,
overfitting_bs=-1,
interval=1,
plucker_embedding=False,
shuffle_across_cls=False,
wds_split=1, # 4 splits to accelerate preprocessing
) -> None:
super().__init__()
self.file_path = file_path
self.overfitting = overfitting
self.scene_scale = scene_scale
self.reso = reso
self.reso_encoder = reso_encoder
self.classes = False
self.load_depth = load_depth
self.preprocess = preprocess
self.plucker_embedding = plucker_embedding
self.rgb_list = []
# ! for gso-rendering
all_objs = os.listdir(self.file_path)
all_objs.sort()
if True: # instant-mesh picked images
# if False:
all_instances = os.listdir(self.file_path)
# all_fname = [
# t for t in all_instances
# if t.split('.')[1] in ['png', 'jpg']
# ]
# all_fname = [name for name in all_fname if '-input' in name ]
# all_fname = ['house2-input.png', 'plant-input.png']
all_fname = ['house2-input.png']
self.rgb_list = [os.path.join(self.file_path, name) for name in all_fname]
if False:
for obj_folder in tqdm(all_objs[515:]):
# for obj_folder in tqdm(all_objs[:515]):
# for obj_folder in tqdm(all_objs[:]):
# for obj_folder in tqdm(sorted(os.listdir(self.file_path))[515:]):
# for idx in range(0,25,5):
for idx in [0]: # only query frontal view is enough
self.rgb_list.append(os.path.join(self.file_path, obj_folder, 'rgba', f'{idx:03d}.png'))
# for free-3d rendering
if False:
# if True:
# all_instances = sorted(os.listdir(self.file_path))
all_instances = ['BAGEL_WITH_CHEESE',
'BALANCING_CACTUS',
'Baby_Elements_Stacking_Cups',
'Breyer_Horse_Of_The_Year_2015',
'COAST_GUARD_BOAT',
'CONE_SORTING',
'CREATIVE_BLOCKS_35_MM',
'Cole_Hardware_Mini_Honey_Dipper',
'FAIRY_TALE_BLOCKS',
'FIRE_ENGINE',
'FOOD_BEVERAGE_SET',
'GEOMETRIC_PEG_BOARD',
'Great_Dinos_Triceratops_Toy',
'JUICER_SET',
'STACKING_BEAR',
'STACKING_RING',
'Schleich_African_Black_Rhino']
for instance in all_instances:
self.rgb_list += ([
# os.path.join(self.file_path, instance, 'rgb', f'{fname:06d}.png') for fname in range(0,250,50)
# os.path.join(self.file_path, instance, 'rgb', f'{fname:06d}.png') for fname in range(0,250,100)
# os.path.join(self.file_path, instance, f'{fname:03d}.png') for fname in range(0,25,5)
os.path.join(self.file_path, instance, 'render_mvs_25', 'model', f'{fname:03d}.png') for fname in range(0,25,4)
])
# if True: # g-objv animals images for i23d eval
if False:
# if True:
objv_dataset = '/mnt/sfs-common/yslan/Dataset/Obajverse/chunk-jpeg-normal/bs_16_fixsave3/170K/512/'
dataset_json = os.path.join(objv_dataset, 'dataset.json')
with open(dataset_json, 'r') as f:
dataset_json = json.load(f)
# all_objs = dataset_json['Animals'][::3][:6250]
all_objs = dataset_json['Animals'][::3][1100:2200][:600]
for obj_folder in tqdm(all_objs[:]):
for idx in [0]: # only query frontal view is enough
self.rgb_list.append(os.path.join(self.file_path, obj_folder, f'{idx}.jpg'))
        # ! setup normalization
transformations = [
transforms.ToTensor(), # [0,1] range
]
assert imgnet_normalize
if imgnet_normalize:
transformations.append(
transforms.Normalize((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225)) # type: ignore
)
else:
transformations.append(
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))) # type: ignore
self.normalize = transforms.Compose(transformations)
# camera = torch.load('eval_pose.pt', map_location='cpu')
# self.eval_camera = camera
# pre-cache
# self.calc_rays_plucker()
def __len__(self):
return len(self.rgb_list)
def __getitem__(self, index) -> Any:
# return super().__getitem__(index)
rgb_fname = self.rgb_list[index]
# ! preprocess, normalize
raw_img = imageio.imread(rgb_fname)
# interpolation=cv2.INTER_AREA)
if raw_img.shape[-1] == 4:
alpha_mask = raw_img[..., 3:4] / 255.0
bg_white = np.ones_like(alpha_mask) * 255.0
raw_img = raw_img[..., :3] * alpha_mask + (
1 - alpha_mask) * bg_white #[3, reso_encoder, reso_encoder]
raw_img = raw_img.astype(np.uint8)
# raw_img = recenter(raw_img, np.ones_like(raw_img), border_ratio=0.2)
# log gt
img = cv2.resize(raw_img, (self.reso, self.reso),
interpolation=cv2.INTER_LANCZOS4)
img = torch.from_numpy(img)[..., :3].permute(
2, 0, 1
) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range
ret_dict = {
'img': img,
# 'ins': str(Path(rgb_fname).parent.parent.stem), # for gso-rendering
'ins': str(Path(rgb_fname).relative_to(self.file_path)), # for gso-rendering
# 'ins': rgb_fname,
}
return ret_dict
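

# Shared preprocessing sketch (assumption: the helper name is ours, for
# illustration only). Both Real* datasets above composite RGBA inputs onto
# a white background exactly this way before resizing:
def _composite_on_white(raw_img):
    # raw_img: (H, W, 4) uint8 -> (H, W, 3) uint8 on a white background
    alpha_mask = raw_img[..., 3:4] / 255.0
    rgb = raw_img[..., :3] * alpha_mask + (1 - alpha_mask) * 255.0
    return rgb.astype(np.uint8)
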
class RealMVDataset(Dataset):
def __init__(
self,
file_path,
reso,
reso_encoder,
preprocess=None,
classes=False,
load_depth=False,
test=False,
scene_scale=1,
overfitting=False,
imgnet_normalize=True,
dataset_size=-1,
overfitting_bs=-1,
interval=1,
plucker_embedding=False,
shuffle_across_cls=False,
wds_split=1, # 4 splits to accelerate preprocessing
) -> None:
super().__init__()
self.file_path = file_path
self.overfitting = overfitting
self.scene_scale = scene_scale
self.reso = reso
self.reso_encoder = reso_encoder
self.classes = False
self.load_depth = load_depth
self.preprocess = preprocess
self.plucker_embedding = plucker_embedding
self.rgb_list = []
        all_fname = [
            t for t in os.listdir(self.file_path)
            if t.split('.')[-1] in ['png', 'jpg']
        ]
        all_fname = [name for name in all_fname if '-input' in name]
# all_fname = [name for name in all_fname if 'sorting_board-input' in name ]
# all_fname = [name for name in all_fname if 'teasure_chest-input' in name ]
# all_fname = [name for name in all_fname if 'bubble_mart_blue-input' in name ]
# all_fname = [name for name in all_fname if 'chair_comfort-input' in name ]
self.rgb_list += ([
os.path.join(self.file_path, fname) for fname in all_fname
])
# if len(self.rgb_list) == 1:
# # placeholder
# self.rgb_list = self.rgb_list * 40
        # ! setup normalization
transformations = [
transforms.ToTensor(), # [0,1] range
]
azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float)
elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float)
# zero123pp_pose, _ = generate_input_camera(1.6, [[elevations[i], azimuths[i]] for i in range(6)], fov=30)
zero123pp_pose, _ = generate_input_camera(1.8, [[elevations[i], azimuths[i]] for i in range(6)], fov=30)
        K = torch.Tensor([1.3889, 0.0000, 0.5000, 0.0000, 1.3889, 0.5000, 0.0000, 0.0000, 0.0039]).to(zero123pp_pose) # same K shared by all six views
# st()
zero123pp_pose = torch.cat([zero123pp_pose.reshape(6,-1), K.unsqueeze(0).repeat(6,1)], dim=-1)
# ! directly adopt gt input
# self.indices = np.array([0,2,4,5])
# eval_camera = zero123pp_pose[self.indices]
# self.eval_camera = torch.cat([torch.zeros_like(eval_camera[0:1]),eval_camera], 0) # first c not used as condition here, just placeholder
# ! adopt mv-diffusion output as input.
# self.indices = np.array([1,0,2,4,5])
self.indices = np.array([0,1,2,3,4,5])
eval_camera = zero123pp_pose[self.indices].float().cpu().numpy() # for normalization
# eval_camera = zero123pp_pose[self.indices]
# self.eval_camera = eval_camera
# self.eval_camera = torch.cat([torch.zeros_like(eval_camera[0:1]),eval_camera], 0) # first c not used as condition here, just placeholder
# # * normalize here
self.eval_camera = self.normalize_camera(eval_camera, eval_camera[0:1]) # the first img is not used.
# self.mv_resize_cls = torchvision.transforms.Resize(320, interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
# max_size=None, antialias=True)
def normalize_camera(self, c, c_frame0):
        # assert c.shape[0] == self.chunk_size # 8 or 10
B = c.shape[0]
        camera_poses = c[:, :16].reshape(B, 4, 4)  # 4x4 c2w poses
canonical_camera_poses = c_frame0[:, :16].reshape(1, 4, 4)
inverse_canonical_pose = np.linalg.inv(canonical_camera_poses)
inverse_canonical_pose = np.repeat(inverse_canonical_pose, B, 0)
cam_radius = np.linalg.norm(
c_frame0[:, :16].reshape(1, 4, 4)[:, :3, 3],
axis=-1,
keepdims=False) # since g-buffer adopts dynamic radius here.
frame1_fixed_pos = np.repeat(np.eye(4)[None], 1, axis=0)
frame1_fixed_pos[:, 2, -1] = -cam_radius
transform = frame1_fixed_pos @ inverse_canonical_pose
new_camera_poses = np.repeat(
transform, 1, axis=0
        ) @ camera_poses  # [V, 4, 4]; np.repeat() matches th.repeat_interleave()
c = np.concatenate([new_camera_poses.reshape(B, 16), c[:, 16:]],
axis=-1)
return c
def __len__(self):
return len(self.rgb_list)
def __getitem__(self, index) -> Any:
# return super().__getitem__(index)
rgb_fname = self.rgb_list[index]
raw_img = imageio.imread(rgb_fname)[..., :3]
raw_img = cv2.resize(raw_img, (self.reso, self.reso), interpolation=cv2.INTER_CUBIC)
        raw_img = torch.from_numpy(raw_img).permute(2,0,1).clip(0,255) # [0, 255]
img = raw_img / 127.5 - 1
# ! if loading mv-diff output views
mv_img = imageio.imread(rgb_fname.replace('-input', ''))
        mv_img = rearrange(mv_img, '(n h) (m w) c -> (n m) h w c', n=3, m=2)[self.indices] # (6, 320, 320, 3)
mv_img = np.stack([recenter(img, np.ones_like(img), border_ratio=0.1) for img in mv_img], axis=0)
        mv_img = rearrange(mv_img, 'b h w c -> b c h w') # to channel-first (torch) convention
mv_img = torch.from_numpy(mv_img) / 127.5 - 1
ret_dict = {
'img': img,
'mv_img': mv_img,
'c': self.eval_camera,
'caption': 'null',
}
return ret_dict
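

# Sketch of the zero123++ grid handling in RealMVDataset above (assumption:
# a zero array stands in for a real 960x640 3x2-tile diffusion output).
# Each 320x320 tile becomes one conditioning view:
def _demo_split_zero123pp_grid():
    grid = np.zeros((960, 640, 3), dtype=np.uint8)  # (3*320, 2*320, 3)
    views = rearrange(grid, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
    return views.shape  # (6, 320, 320, 3)
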
class NovelViewObjverseDataset(MultiViewObjverseDataset):
"""novel view prediction version.
"""
def __init__(self,
file_path,
reso,
reso_encoder,
preprocess=None,
classes=False,
load_depth=False,
test=False,
scene_scale=1,
overfitting=False,
imgnet_normalize=True,
dataset_size=-1,
overfitting_bs=-1,
**kwargs):
super().__init__(file_path, reso, reso_encoder, preprocess, classes,
load_depth, test, scene_scale, overfitting,
imgnet_normalize, dataset_size, overfitting_bs,
**kwargs)
def __getitem__(self, idx):
input_view = super().__getitem__(
idx) # get previous input view results
# get novel view of the same instance
novel_view = super().__getitem__(
(idx // self.instance_data_length) * self.instance_data_length +
random.randint(0, self.instance_data_length - 1))
# assert input_view['ins_name'] == novel_view['ins_name'], 'should sample novel view from the same instance'
input_view.update({f'nv_{k}': v for k, v in novel_view.items()})
return input_view
class MultiViewObjverseDatasetforLMDB(MultiViewObjverseDataset):
def __init__(
self,
file_path,
reso,
reso_encoder,
preprocess=None,
classes=False,
load_depth=False,
test=False,
scene_scale=1,
overfitting=False,
imgnet_normalize=True,
dataset_size=-1,
overfitting_bs=-1,
shuffle_across_cls=False,
wds_split=1,
four_view_for_latent=False,
):
super().__init__(file_path,
reso,
reso_encoder,
preprocess,
classes,
load_depth,
test,
scene_scale,
overfitting,
imgnet_normalize,
dataset_size,
overfitting_bs,
shuffle_across_cls=shuffle_across_cls,
wds_split=wds_split,
four_view_for_latent=four_view_for_latent)
# assert self.reso == 256
self.load_caption = True
with open(
# '/cpfs01/shared/V2V/V2V_hdd/yslan/aigc3d/text_captions_cap3d.json'
'/nas/shared/public/yslan/data/text_captions_cap3d.json') as f:
# '/nas/shared/V2V/yslan/aigc3d/text_captions_cap3d.json') as f:
self.caption_data = json.load(f)
# lmdb_path = '/cpfs01/user/yangpeiqing.p/yslan/data/Furnitures_uncompressed/'
# with open(os.path.join(lmdb_path, 'idx_to_ins_mapping.json')) as f:
# self.idx_to_ins_mapping = json.load(f)
def __len__(self):
return super().__len__()
# return 100 # for speed debug
def quantize_depth(self, depth):
# https://developers.google.com/depthmap-metadata/encoding
# RangeInverse encoding
        bg = depth == 0
        depth[bg] = 3  # no need to allocate disparity range to the background
disparity = 1 / depth
        far = disparity.max().item()  # np array here; max disparity = nearest point
        near = disparity.min().item()
# d_normalized = (far * (depth-near) / (depth * far - near)) # [0,1] range
d_normalized = (disparity - near) / (far - near) # [0,1] range
# imageio.imwrite('depth_negative.jpeg', (((depth - near) / (far - near) * 255)<0).numpy().astype(np.uint8))
# imageio.imwrite('depth_negative.jpeg', ((depth <0)*255).numpy().astype(np.uint8))
d_normalized = np.nan_to_num(d_normalized.cpu().numpy())
d_normalized = (np.clip(d_normalized, 0, 1) * 255).astype(np.uint8)
# imageio.imwrite('depth.png', d_normalized)
# d = 1 / ( (d_normalized / 255) * (far-near) + near)
# diff = (d[~bg.numpy()] - depth[~bg].numpy()).sum()
return d_normalized, near, far # return disp
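
    # Sketch of the inverse mapping (assumption: kept for reference only, not
    # called in this pipeline). Recovers metric depth from the 8-bit
    # RangeInverse encoding produced by quantize_depth() above:
    def dequantize_depth(self, d_normalized, near, far):
        disparity = (d_normalized.astype(np.float32) / 255) * (far - near) + near
        return 1.0 / disparity
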
def __getitem__(self, idx):
# ret_dict = super().__getitem__(idx)
rgb_fname = self.rgb_list[idx]
pose_fname = self.pose_list[idx]
raw_img = imageio.imread(rgb_fname) # [..., :3]
assert raw_img.shape[-1] == 4
# st() # cv2.imwrite('img_CV2_90.jpg', a, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
# if raw_img.shape[-1] == 4: # ! set bg to white
alpha_mask = raw_img[..., -1:] / 255 # [0,1]
raw_img = alpha_mask * raw_img[..., :3] + (
1 - alpha_mask) * np.ones_like(raw_img[..., :3]) * 255
raw_img = np.concatenate([raw_img, alpha_mask * 255], -1)
raw_img = raw_img.astype(np.uint8)
raw_img = cv2.resize(raw_img, (self.reso, self.reso),
interpolation=cv2.INTER_LANCZOS4)
alpha_mask = raw_img[..., -1] / 255
raw_img = raw_img[..., :3]
# alpha_mask = cv2.resize(alpha_mask, (self.reso, self.reso),
# interpolation=cv2.INTER_LANCZOS4)
        c2w = read_camera_matrix_single(pose_fname)  # [4, 4] c2w, flattened to 16-dim below
c = np.concatenate([c2w.reshape(16), self.intrinsics],
axis=0).reshape(25).astype(
np.float32) # 25, no '1' dim needed.
c = torch.from_numpy(c)
# c = np.concatenate([c2w, self.intrinsics], axis=0).reshape(25) # 25, no '1' dim needed.
# if self.load_depth:
# depth, depth_mask, depth_mask_sr = read_dnormal(self.depth_list[idx],
# try:
depth, normal = read_dnormal(self.depth_list[idx], c2w[:3, 3:],
self.reso, self.reso)
# ! quantize depth for fast decoding
# d_normalized, d_near, d_far = self.quantize_depth(depth)
# ! add frame_0 alignment
# try:
ins = str(
(Path(self.data_ins_list[idx]).relative_to(self.file_path)).parent)
# if self.shuffle_across_cls:
if self.load_caption:
caption = self.caption_data['/'.join(ins.split('/')[1:])]
bbox = self.load_bbox(torch.from_numpy(alpha_mask) > 0)
else:
            caption = ''  # some instances in g-alignment-xl fail to load captions.
bbox = self.load_bbox(torch.from_numpy(np.ones_like(alpha_mask)) > 0)
# else:
# caption = self.caption_data[ins]
ret_dict = {
'normal': normal,
'raw_img': raw_img,
'c': c,
# 'depth_mask': depth_mask, # 64x64 here?
'bbox': bbox,
'ins': ins,
'caption': caption,
'alpha_mask': alpha_mask,
'depth': depth, # return for pcd creation
# 'd_normalized': d_normalized,
# 'd_near': d_near,
# 'd_far': d_far,
# 'fname': rgb_fname,
}
return ret_dict
class MultiViewObjverseDatasetforLMDB_nocaption(MultiViewObjverseDatasetforLMDB):
def __init__(
self,
file_path,
reso,
reso_encoder,
preprocess=None,
classes=False,
load_depth=False,
test=False,
scene_scale=1,
overfitting=False,
imgnet_normalize=True,
dataset_size=-1,
overfitting_bs=-1,
shuffle_across_cls=False,
wds_split=1,
four_view_for_latent=False,
):
super().__init__(file_path,
reso,
reso_encoder,
preprocess,
classes,
load_depth,
test,
scene_scale,
overfitting,
imgnet_normalize,
dataset_size,
overfitting_bs,
shuffle_across_cls=shuffle_across_cls,
wds_split=wds_split,
four_view_for_latent=four_view_for_latent)
self.load_caption = False
class Objv_LMDBDataset_MV_Compressed(LMDBDataset_MV_Compressed):
def __init__(self,
lmdb_path,
reso,
reso_encoder,
imgnet_normalize=True,
dataset_size=-1,
test=False,
**kwargs):
super().__init__(lmdb_path,
reso,
reso_encoder,
imgnet_normalize,
dataset_size=dataset_size,
**kwargs)
self.instance_data_length = 40 # ! could save some key attributes in LMDB
if test:
self.length = self.instance_data_length
elif dataset_size > 0:
self.length = dataset_size * self.instance_data_length
# load caption data, and idx-to-ins mapping
with open(
'/cpfs01/shared/V2V/V2V_hdd/yslan/aigc3d/text_captions_cap3d.json'
) as f:
self.caption_data = json.load(f)
with open(os.path.join(lmdb_path, 'idx_to_ins_mapping.json')) as f:
self.idx_to_ins_mapping = json.load(f)
def _load_data(self, idx):
# '''
raw_img, depth, c, bbox = self._load_lmdb_data(idx)
# raw_img, depth, c, bbox = self._load_lmdb_data_no_decompress(idx)
# resize depth and bbox
caption = self.caption_data[self.idx_to_ins_mapping[str(idx)]]
return {
**self._post_process_sample(raw_img, depth),
'c': c,
'bbox': (bbox * (self.reso / 512.0)).astype(np.uint8),
# 'bbox': (bbox*(self.reso/256.0)).astype(np.uint8), # TODO, double check 512 in wds?
'caption': caption
}
# '''
# raw_img, depth, c, bbox = self._load_lmdb_data_no_decompress(idx)
# st()
# return {}
def __getitem__(self, idx):
return self._load_data(idx)
class Objv_LMDBDataset_MV_NoCompressed(Objv_LMDBDataset_MV_Compressed):
def __init__(self,
lmdb_path,
reso,
reso_encoder,
imgnet_normalize=True,
dataset_size=-1,
test=False,
**kwargs):
super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize,
dataset_size, test, **kwargs)
def _load_data(self, idx):
# '''
raw_img, depth, c, bbox = self._load_lmdb_data_no_decompress(idx)
# resize depth and bbox
caption = self.caption_data[self.idx_to_ins_mapping[str(idx)]]
return {
**self._post_process_sample(raw_img, depth), 'c': c,
'bbox': (bbox * (self.reso / 512.0)).astype(np.uint8),
'caption': caption
}
class Objv_LMDBDataset_NV_NoCompressed(Objv_LMDBDataset_MV_NoCompressed):
def __init__(self,
lmdb_path,
reso,
reso_encoder,
imgnet_normalize=True,
dataset_size=-1,
test=False,
**kwargs):
super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize,
dataset_size, test, **kwargs)
def __getitem__(self, idx):
input_view = self._load_data(idx) # get previous input view results
# get novel view of the same instance
try:
novel_view = self._load_data(
(idx // self.instance_data_length) *
self.instance_data_length +
random.randint(0, self.instance_data_length - 1))
except Exception as e:
raise NotImplementedError(idx)
# assert input_view['ins_name'] == novel_view['ins_name'], 'should sample novel view from the same instance'
input_view.update({f'nv_{k}': v for k, v in novel_view.items()})
return input_view
class Objv_LMDBDataset_MV_Compressed_for_lmdb(LMDBDataset_MV_Compressed):
def __init__(self,
lmdb_path,
reso,
reso_encoder,
imgnet_normalize=True,
dataset_size=-1,
test=False,
**kwargs):
super().__init__(lmdb_path,
reso,
reso_encoder,
imgnet_normalize,
dataset_size=dataset_size,
**kwargs)
self.instance_data_length = 40 # ! could save some key attributes in LMDB
if test:
self.length = self.instance_data_length
elif dataset_size > 0:
self.length = dataset_size * self.instance_data_length
# load caption data, and idx-to-ins mapping
with open(
'/cpfs01/shared/V2V/V2V_hdd/yslan/aigc3d/text_captions_cap3d.json'
) as f:
self.caption_data = json.load(f)
with open(os.path.join(lmdb_path, 'idx_to_ins_mapping.json')) as f:
self.idx_to_ins_mapping = json.load(f)
# def _load_data(self, idx):
# # '''
# raw_img, depth, c, bbox = self._load_lmdb_data(idx)
# # resize depth and bbox
# caption = self.caption_data[self.idx_to_ins_mapping[str(idx)]]
# # st()
# return {
# **self._post_process_sample(raw_img, depth), 'c': c,
# 'bbox': (bbox*(self.reso/512.0)).astype(np.uint8),
# 'caption': caption
# }
# # '''
# # raw_img, depth, c, bbox = self._load_lmdb_data_no_decompress(idx)
# # st()
# # return {}
def load_bbox(self, mask):
nonzero_value = torch.nonzero(mask)
height, width = nonzero_value.max(dim=0)[0]
top, left = nonzero_value.min(dim=0)[0]
bbox = torch.tensor([top, left, height, width], dtype=torch.float32)
return bbox
def __getitem__(self, idx):
raw_img, depth, c, bbox = self._load_lmdb_data(idx)
return {'raw_img': raw_img, 'depth': depth, 'c': c, 'bbox': bbox}
class Objv_LMDBDataset_NV_Compressed(Objv_LMDBDataset_MV_Compressed):
def __init__(self,
lmdb_path,
reso,
reso_encoder,
imgnet_normalize=True,
dataset_size=-1,
**kwargs):
super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize,
dataset_size, **kwargs)
def __getitem__(self, idx):
input_view = self._load_data(idx) # get previous input view results
# get novel view of the same instance
try:
novel_view = self._load_data(
(idx // self.instance_data_length) *
self.instance_data_length +
random.randint(0, self.instance_data_length - 1))
except Exception as e:
raise NotImplementedError(idx)
# assert input_view['ins_name'] == novel_view['ins_name'], 'should sample novel view from the same instance'
input_view.update({f'nv_{k}': v for k, v in novel_view.items()})
return input_view
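

# Worked sketch of the novel-view index math used above (assumption:
# instance_data_length=40, so index 85 belongs to instance 2, views 80-119;
# the paired novel view is drawn uniformly from that same range):
def _demo_nv_index(idx=85, instance_data_length=40):
    base = (idx // instance_data_length) * instance_data_length  # 80
    return base + random.randint(0, instance_data_length - 1)    # 80..119
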
#
# test tar loading
def load_wds_ResampledShard(file_path,
batch_size,
num_workers,
reso,
reso_encoder,
test=False,
preprocess=None,
imgnet_normalize=True,
plucker_embedding=False,
decode_encode_img_only=False,
load_instance=False,
mv_input=False,
split_chunk_input=False,
duplicate_sample=True,
append_depth=False,
append_normal=False,
gs_cam_format=False,
                            orthog_duplicate=False,
                            load_gzip=False,  # deprecated gzip path below
                            split_chunk_size=8,  # used by the shuffle/batch sizing below
                            **kwargs):
# return raw_img, depth, c, bbox, sample_pyd['ins.pyd'], sample_pyd['fname.pyd']
post_process_cls = PostProcess(
reso,
reso_encoder,
imgnet_normalize=imgnet_normalize,
plucker_embedding=plucker_embedding,
decode_encode_img_only=decode_encode_img_only,
mv_input=mv_input,
split_chunk_input=split_chunk_input,
duplicate_sample=duplicate_sample,
append_depth=append_depth,
gs_cam_format=gs_cam_format,
orthog_duplicate=orthog_duplicate,
append_normal=append_normal,
)
# ! add shuffling
if isinstance(file_path, list): # lst of shard urls
all_shards = []
for url_path in file_path:
all_shards.extend(wds.shardlists.expand_source(url_path))
logger.log('all_shards', all_shards)
else:
all_shards = file_path # to be expanded
if not load_instance: # during reconstruction training, load pair
if not split_chunk_input:
dataset = wds.DataPipeline(
wds.ResampledShards(all_shards), # url_shard
# at this point we have an iterator over all the shards
wds.shuffle(50),
wds.split_by_worker, # if multi-node
wds.tarfile_to_samples(),
# add wds.split_by_node here if you are using multiple nodes
wds.shuffle(
1000
), # shuffles in the memory, leverage large RAM for more efficient loading
wds.decode(wds.autodecode.basichandlers), # TODO
wds.to_tuple(
"sample.pyd"), # extract the pyd from top level dict
wds.map(post_process_cls.decode_zip),
wds.map(post_process_cls.paired_post_process
), # create input-novelview paired samples
# wds.map(post_process_cls._post_process_sample),
# wds.detshuffle(1000), # shuffles in the memory, leverage large RAM for more efficient loading
wds.batched(
16,
partial=True,
# collation_fn=collate
) # streaming more data at once, and rebatch later
)
        elif load_gzip:  # deprecated, no performance improvement
dataset = wds.DataPipeline(
wds.ResampledShards(all_shards), # url_shard
# at this point we have an iterator over all the shards
wds.shuffle(10),
wds.split_by_worker, # if multi-node
wds.tarfile_to_samples(),
# add wds.split_by_node here if you are using multiple nodes
# wds.shuffle(
# 100
# ), # shuffles in the memory, leverage large RAM for more efficient loading
wds.decode('rgb8'), # TODO
# wds.decode(wds.autodecode.basichandlers), # TODO
# wds.to_tuple('raw_img.jpeg', 'depth.jpeg', 'alpha_mask.jpeg',
# 'd_near.npy', 'd_far.npy', "c.npy", 'bbox.npy',
# 'ins.txt', 'caption.txt'),
wds.to_tuple('raw_img.png', 'depth_alpha.png'),
# wds.to_tuple('raw_img.jpg', "c.npy", 'bbox.npy', 'depth.pyd', 'ins.txt', 'caption.txt'),
# wds.to_tuple('raw_img.jpg', "c.npy", 'bbox.npy', 'ins.txt', 'caption.txt'),
wds.map(post_process_cls.decode_gzip),
# wds.map(post_process_cls.paired_post_process_chunk
# ), # create input-novelview paired samples
wds.batched(
20,
partial=True,
# collation_fn=collate
) # streaming more data at once, and rebatch later
)
else:
dataset = wds.DataPipeline(
wds.ResampledShards(all_shards), # url_shard
# at this point we have an iterator over all the shards
wds.shuffle(100),
wds.split_by_worker, # if multi-node
wds.tarfile_to_samples(),
# add wds.split_by_node here if you are using multiple nodes
wds.shuffle(
4000 // split_chunk_size
), # shuffles in the memory, leverage large RAM for more efficient loading
wds.decode(wds.autodecode.basichandlers), # TODO
wds.to_tuple(
"sample.pyd"), # extract the pyd from top level dict
wds.map(post_process_cls.decode_zip),
wds.map(post_process_cls.paired_post_process_chunk
), # create input-novelview paired samples
# wds.map(post_process_cls._post_process_sample),
# wds.detshuffle(1000), # shuffles in the memory, leverage large RAM for more efficient loading
wds.batched(
120 // split_chunk_size,
partial=True,
# collation_fn=collate
) # streaming more data at once, and rebatch later
)
loader_shard = wds.WebLoader(
dataset,
num_workers=num_workers,
drop_last=False,
batch_size=None,
shuffle=False,
persistent_workers=num_workers > 0).unbatched().shuffle(
1000 // split_chunk_size).batched(batch_size).map(
post_process_cls.create_dict)
if mv_input:
loader_shard = loader_shard.map(post_process_cls.prepare_mv_input)
else: # load single instance during test/eval
assert batch_size == 1
dataset = wds.DataPipeline(
wds.ResampledShards(all_shards), # url_shard
# at this point we have an iterator over all the shards
wds.shuffle(50),
wds.split_by_worker, # if multi-node
wds.tarfile_to_samples(),
# add wds.split_by_node here if you are using multiple nodes
wds.detshuffle(
100
), # shuffles in the memory, leverage large RAM for more efficient loading
wds.decode(wds.autodecode.basichandlers), # TODO
wds.to_tuple("sample.pyd"), # extract the pyd from top level dict
wds.map(post_process_cls.decode_zip),
# wds.map(post_process_cls.paired_post_process), # create input-novelview paired samples
wds.map(post_process_cls._post_process_batch_sample),
# wds.detshuffle(1000), # shuffles in the memory, leverage large RAM for more efficient loading
wds.batched(
2,
partial=True,
# collation_fn=collate
) # streaming more data at once, and rebatch later
)
loader_shard = wds.WebLoader(
dataset,
num_workers=num_workers,
drop_last=False,
batch_size=None,
shuffle=False,
persistent_workers=num_workers
> 0).unbatched().shuffle(200).batched(batch_size).map(
post_process_cls.single_instance_sample_create_dict)
# persistent_workers=num_workers > 0).unbatched().batched(batch_size).map(post_process_cls.create_dict)
# 1000).batched(batch_size).map(post_process_cls.create_dict)
# .map(collate)
# .map(collate)
# .batched(batch_size)
#
# .unbatched().shuffle(1000).batched(batch_size).map(post_process)
# # https://github.com/webdataset/webdataset/issues/187
# return next(iter(loader_shard))
#return dataset
return loader_shard
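

# Minimal usage sketch (assumption: the brace-expandable shard pattern is a
# placeholder). The returned WebLoader is an infinite iterable of batched
# dicts whose keys come from PostProcess.create_dict / prepare_mv_input:
def _demo_wds_loader():
    loader = load_wds_ResampledShard(
        './shards/objv-{000000..000009}.tar',
        batch_size=4, num_workers=2, reso=128, reso_encoder=224,
        split_chunk_input=True, split_chunk_size=8)
    return next(iter(loader))
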
class PostProcessForDiff:
def __init__(
self,
reso,
reso_encoder,
imgnet_normalize,
plucker_embedding,
decode_encode_img_only,
mv_latent_dir,
) -> None:
self.plucker_embedding = plucker_embedding
self.mv_latent_dir = mv_latent_dir
self.decode_encode_img_only = decode_encode_img_only
transformations = [
transforms.ToTensor(), # [0,1] range
]
if imgnet_normalize:
transformations.append(
transforms.Normalize((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225)) # type: ignore
)
else:
transformations.append(
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))) # type: ignore
self.normalize = transforms.Compose(transformations)
self.reso_encoder = reso_encoder
self.reso = reso
self.instance_data_length = 40
# self.pair_per_instance = 1 # compat
self.pair_per_instance = 2 # check whether improves IO
# self.pair_per_instance = 3 # check whether improves IO
# self.pair_per_instance = 4 # check whether improves IO
self.camera = torch.load('eval_pose.pt', map_location='cpu').numpy()
self.canonical_frame = self.camera[25:26] # 1, 25 # inverse this
self.canonical_frame_pos = self.canonical_frame[:, :16].reshape(4, 4)
def get_rays_kiui(self, c, opengl=True):
h, w = self.reso_encoder, self.reso_encoder
intrinsics, pose = c[16:], c[:16].reshape(4, 4)
# cx, cy, fx, fy = intrinsics[2], intrinsics[5]
fx = fy = 525 # pixel space
cx = cy = 256 # rendering default K
factor = self.reso / (cx * 2) # 128 / 512
fx = fx * factor
fy = fy * factor
x, y = torch.meshgrid(
torch.arange(w, device=pose.device),
torch.arange(h, device=pose.device),
indexing="xy",
)
x = x.flatten()
y = y.flatten()
cx = w * 0.5
cy = h * 0.5
# focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))
camera_dirs = F.pad(
torch.stack(
[
(x - cx + 0.5) / fx,
(y - cy + 0.5) / fy * (-1.0 if opengl else 1.0),
],
dim=-1,
),
(0, 1),
value=(-1.0 if opengl else 1.0),
) # [hw, 3]
rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3]
rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3]
rays_o = rays_o.view(h, w, 3)
rays_d = safe_normalize(rays_d).view(h, w, 3)
return rays_o, rays_d
def gen_rays(self, c):
# Generate rays
intrinsics, c2w = c[16:], c[:16].reshape(4, 4)
self.h = self.reso_encoder
self.w = self.reso_encoder
yy, xx = torch.meshgrid(
torch.arange(self.h, dtype=torch.float32) + 0.5,
torch.arange(self.w, dtype=torch.float32) + 0.5,
indexing='ij')
# normalize to 0-1 pixel range
yy = yy / self.h
xx = xx / self.w
# K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1]).reshape(3, 3)
cx, cy, fx, fy = intrinsics[2], intrinsics[5], intrinsics[
0], intrinsics[4]
# cx *= self.w
# cy *= self.h
# f_x = f_y = fx * h / res_raw
c2w = torch.from_numpy(c2w).float()
xx = (xx - cx) / fx
yy = (yy - cy) / fy
zz = torch.ones_like(xx)
dirs = torch.stack((xx, yy, zz), dim=-1) # OpenCV convention
dirs /= torch.norm(dirs, dim=-1, keepdim=True)
dirs = dirs.reshape(-1, 3, 1)
del xx, yy, zz
# st()
dirs = (c2w[None, :3, :3] @ dirs)[..., 0]
origins = c2w[None, :3, 3].expand(self.h * self.w, -1).contiguous()
origins = origins.view(self.h, self.w, 3)
dirs = dirs.view(self.h, self.w, 3)
return origins, dirs
def normalize_camera(self, c):
        # assert c.shape[0] == self.chunk_size # 8 or 10
c = c[None] # api compat
B = c.shape[0]
        camera_poses = c[:, :16].reshape(B, 4, 4)  # 4x4
cam_radius = np.linalg.norm(
self.canonical_frame_pos.reshape(4, 4)[:3, 3],
axis=-1,
keepdims=False) # since g-buffer adopts dynamic radius here.
frame1_fixed_pos = np.eye(4)
frame1_fixed_pos[2, -1] = -cam_radius
transform = frame1_fixed_pos @ np.linalg.inv(
self.canonical_frame_pos) # 4,4
# from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127
# transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]])
new_camera_poses = transform[None] @ camera_poses # [V, 4, 4]
c = np.concatenate([new_camera_poses.reshape(B, 16), c[:, 16:]],
axis=-1)
return c[0]
def _post_process_sample(self, data_sample):
# raw_img, depth, c, bbox, caption, ins = data_sample
raw_img, c, caption, ins = data_sample
        # c = self.normalize_camera(c)  # enable if using relative pose.
img = raw_img # 256x256
img = torch.from_numpy(img).permute(2, 0, 1) / 127.5 - 1
# load latent.
# latent_path = Path(self.mv_latent_dir, ins, 'latent.npy') # ! a converged version, before adding augmentation
# if random.random() > 0.5:
# latent_path = Path(self.mv_latent_dir, ins, 'latent.npy')
# else: # augmentation, double the dataset
latent_path = Path(
self.mv_latent_dir.replace('v=4-final', 'v=4-rotate'), ins,
'latent.npy')
latent = np.load(latent_path)
# return (img_to_encoder, img, c, caption, ins)
return (latent, img, c, caption, ins)
def rand_sample_idx(self):
return random.randint(0, self.instance_data_length - 1)
def rand_pair(self):
return (self.rand_sample_idx() for _ in range(2))
def paired_post_process(self, sample):
# repeat n times?
all_inp_list = []
all_nv_list = []
caption, ins = sample[-2:]
# expanded_return = []
for _ in range(self.pair_per_instance):
cano_idx, nv_idx = self.rand_pair()
cano_sample = self._post_process_sample(item[cano_idx]
for item in sample[:-2])
nv_sample = self._post_process_sample(item[nv_idx]
for item in sample[:-2])
all_inp_list.extend(cano_sample)
all_nv_list.extend(nv_sample)
return (*all_inp_list, *all_nv_list, caption, ins)
# return [cano_sample, nv_sample, caption, ins]
# return (*cano_sample, *nv_sample, caption, ins)
# def single_sample_create_dict(self, sample, prefix=''):
# # if len(sample) == 1:
# # sample = sample[0]
# # assert len(sample) == 6
# img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample
# return {
# # **sample,
# f'{prefix}img_to_encoder': img_to_encoder,
# f'{prefix}img': img,
# f'{prefix}depth_mask': fg_mask_reso,
# f'{prefix}depth': depth_reso,
# f'{prefix}c': c,
# f'{prefix}bbox': bbox,
# }
def single_sample_create_dict(self, sample, prefix=''):
# if len(sample) == 1:
# sample = sample[0]
# assert len(sample) == 6
# img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample
# img_to_encoder, img, c, caption, ins = sample
# img, c, caption, ins = sample
latent, img, c, caption, ins = sample
# load latent
return {
# **sample,
# 'img_to_encoder': img_to_encoder,
'latent': latent,
'img': img,
'c': c,
'caption': caption,
'ins': ins
}
def decode_zip(self, sample_pyd, shape=(256, 256)):
if isinstance(sample_pyd, tuple):
sample_pyd = sample_pyd[0]
assert isinstance(sample_pyd, dict)
raw_img = decompress_and_open_image_gzip(
sample_pyd['raw_img'],
is_img=True,
decompress=True,
decompress_fn=lz4.frame.decompress)
caption = sample_pyd['caption'].decode('utf-8')
ins = sample_pyd['ins'].decode('utf-8')
c = decompress_array(sample_pyd['c'], (25, ),
np.float32,
decompress=True,
decompress_fn=lz4.frame.decompress)
# bbox = decompress_array(
# sample_pyd['bbox'],
# (
# 40,
# 4,
# ),
# np.float32,
# # decompress=False)
# decompress=True,
# decompress_fn=lz4.frame.decompress)
# if self.decode_encode_img_only:
# depth = np.zeros(shape=(40, *shape)) # save loading time
# else:
# depth = decompress_array(sample_pyd['depth'], (40, *shape),
# np.float32,
# decompress=True,
# decompress_fn=lz4.frame.decompress)
# return {'raw_img': raw_img, 'depth': depth, 'bbox': bbox, 'caption': caption, 'ins': ins, 'c': c}
# return raw_img, depth, c, bbox, caption, ins
# return raw_img, bbox, caption, ins
# return bbox, caption, ins
return raw_img, c, caption, ins
# ! run single-instance pipeline first
# return raw_img[0], depth[0], c[0], bbox[0], caption, ins
def create_dict(self, sample):
# sample = [item[0] for item in sample] # wds wrap items in []
# cano_sample_list = [[] for _ in range(6)]
# nv_sample_list = [[] for _ in range(6)]
# for idx in range(0, self.pair_per_instance):
# cano_sample = sample[6*idx:6*(idx+1)]
# nv_sample = sample[6*self.pair_per_instance+6*idx:6*self.pair_per_instance+6*(idx+1)]
# for item_idx in range(6):
# cano_sample_list[item_idx].append(cano_sample[item_idx])
# nv_sample_list[item_idx].append(nv_sample[item_idx])
# # ! cycle input/output view for more pairs
# cano_sample_list[item_idx].append(nv_sample[item_idx])
# nv_sample_list[item_idx].append(cano_sample[item_idx])
cano_sample = self.single_sample_create_dict(sample, prefix='')
# nv_sample = self.single_sample_create_dict((torch.cat(item_list) for item_list in nv_sample_list) , prefix='nv_')
return cano_sample
# return {
# **cano_sample,
# # **nv_sample,
# 'caption': sample[-2],
# 'ins': sample[-1]
# }
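

# Sketch of how rays from gen_rays() above become a 6-channel Plucker
# embedding (assumption: standalone illustration; mirrors the usage inside
# PostProcess_forlatent._post_process_sample further below):
def _demo_plucker_from_rays(rays_o, rays_d):
    # rays_o, rays_d: [H, W, 3] -> [6, H, W] via (cross(o, d), d)
    return torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d],
                     dim=-1).permute(2, 0, 1)
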
# test tar loading
def load_wds_diff_ResampledShard(file_path,
batch_size,
num_workers,
reso,
reso_encoder,
test=False,
preprocess=None,
imgnet_normalize=True,
plucker_embedding=False,
decode_encode_img_only=False,
mv_latent_dir='',
**kwargs):
# return raw_img, depth, c, bbox, sample_pyd['ins.pyd'], sample_pyd['fname.pyd']
post_process_cls = PostProcessForDiff(
reso,
reso_encoder,
imgnet_normalize=imgnet_normalize,
plucker_embedding=plucker_embedding,
decode_encode_img_only=decode_encode_img_only,
mv_latent_dir=mv_latent_dir,
)
if isinstance(file_path, list): # lst of shard urls
all_shards = []
for url_path in file_path:
all_shards.extend(wds.shardlists.expand_source(url_path))
logger.log('all_shards', all_shards)
else:
all_shards = file_path # to be expanded
dataset = wds.DataPipeline(
wds.ResampledShards(all_shards), # url_shard
# at this point we have an iterator over all the shards
wds.shuffle(100),
wds.split_by_worker, # if multi-node
wds.tarfile_to_samples(),
# add wds.split_by_node here if you are using multiple nodes
wds.shuffle(
20000
), # shuffles in the memory, leverage large RAM for more efficient loading
wds.decode(wds.autodecode.basichandlers), # TODO
wds.to_tuple("sample.pyd"), # extract the pyd from top level dict
wds.map(post_process_cls.decode_zip),
# wds.map(post_process_cls.paired_post_process), # create input-novelview paired samples
wds.map(post_process_cls._post_process_sample),
# wds.detshuffle(1000), # shuffles in the memory, leverage large RAM for more efficient loading
wds.batched(
100,
partial=True,
# collation_fn=collate
) # streaming more data at once, and rebatch later
)
loader_shard = wds.WebLoader(
dataset,
num_workers=num_workers,
drop_last=False,
batch_size=None,
shuffle=False,
persistent_workers=num_workers
> 0).unbatched().shuffle(2500).batched(batch_size).map(
post_process_cls.create_dict)
return loader_shard
def load_wds_data(
file_path="",
reso=64,
reso_encoder=224,
batch_size=1,
num_workers=6,
plucker_embedding=False,
decode_encode_img_only=False,
load_wds_diff=False,
load_wds_latent=False,
load_instance=False, # for evaluation
mv_input=False,
split_chunk_input=False,
duplicate_sample=True,
mv_latent_dir='',
append_depth=False,
gs_cam_format=False,
orthog_duplicate=False,
**args):
if load_wds_diff:
# assert num_workers == 0 # on aliyun, worker=0 performs much much faster
wds_loader = load_wds_diff_ResampledShard(
file_path,
batch_size=batch_size,
num_workers=num_workers,
reso=reso,
reso_encoder=reso_encoder,
plucker_embedding=plucker_embedding,
decode_encode_img_only=decode_encode_img_only,
mv_input=mv_input,
split_chunk_input=split_chunk_input,
append_depth=append_depth,
mv_latent_dir=mv_latent_dir,
gs_cam_format=gs_cam_format,
orthog_duplicate=orthog_duplicate,
)
elif load_wds_latent:
# for diffusion training, cache latent
wds_loader = load_wds_latent_ResampledShard(
file_path,
batch_size=batch_size,
num_workers=num_workers,
reso=reso,
reso_encoder=reso_encoder,
plucker_embedding=plucker_embedding,
decode_encode_img_only=decode_encode_img_only,
mv_input=mv_input,
split_chunk_input=split_chunk_input,
)
# elif load_instance:
# wds_loader = load_wds_instance_ResampledShard(
# file_path,
# batch_size=batch_size,
# num_workers=num_workers,
# reso=reso,
# reso_encoder=reso_encoder,
# plucker_embedding=plucker_embedding,
# decode_encode_img_only=decode_encode_img_only
# )
else:
wds_loader = load_wds_ResampledShard(
file_path,
batch_size=batch_size,
num_workers=num_workers,
reso=reso,
reso_encoder=reso_encoder,
plucker_embedding=plucker_embedding,
decode_encode_img_only=decode_encode_img_only,
load_instance=load_instance,
mv_input=mv_input,
split_chunk_input=split_chunk_input,
duplicate_sample=duplicate_sample,
append_depth=append_depth,
gs_cam_format=gs_cam_format,
orthog_duplicate=orthog_duplicate,
)
while True:
yield from wds_loader
# yield from wds_loader
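

# Minimal usage sketch (assumption: shard path is a placeholder).
# load_wds_data() is a generator that yields batches forever, so it is
# typically consumed with next() inside the training loop:
def _demo_load_wds_data():
    data = load_wds_data('./shards/objv-{000000..000003}.tar',
                         reso=128, reso_encoder=224, batch_size=4,
                         num_workers=0, load_wds_diff=True,
                         mv_latent_dir='./latents')
    return next(data)
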
class PostProcess_forlatent:
def __init__(
self,
reso,
reso_encoder,
imgnet_normalize,
plucker_embedding,
decode_encode_img_only,
) -> None:
self.plucker_embedding = plucker_embedding
self.decode_encode_img_only = decode_encode_img_only
transformations = [
transforms.ToTensor(), # [0,1] range
]
if imgnet_normalize:
transformations.append(
transforms.Normalize((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225)) # type: ignore
)
else:
transformations.append(
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))) # type: ignore
self.normalize = transforms.Compose(transformations)
self.reso_encoder = reso_encoder
self.reso = reso
self.instance_data_length = 40
# self.pair_per_instance = 1 # compat
self.pair_per_instance = 2 # check whether improves IO
# self.pair_per_instance = 3 # check whether improves IO
# self.pair_per_instance = 4 # check whether improves IO
def _post_process_sample(self, data_sample):
# raw_img, depth, c, bbox, caption, ins = data_sample
raw_img, c, caption, ins = data_sample
# bbox = (bbox*(self.reso/256)).astype(np.uint8) # normalize bbox to the reso range
if raw_img.shape[-2] != self.reso_encoder:
img_to_encoder = cv2.resize(raw_img,
(self.reso_encoder, self.reso_encoder),
interpolation=cv2.INTER_LANCZOS4)
else:
img_to_encoder = raw_img
img_to_encoder = self.normalize(img_to_encoder)
        if self.plucker_embedding:
            # note: gen_rays() is not defined on this class; see
            # PostProcessForDiff.gen_rays above for the ray construction.
            rays_o, rays_d = self.gen_rays(c)
rays_plucker = torch.cat(
[torch.cross(rays_o, rays_d, dim=-1), rays_d],
dim=-1).permute(2, 0, 1) # [h, w, 6] -> 6,h,w
img_to_encoder = torch.cat([img_to_encoder, rays_plucker], 0)
img = cv2.resize(raw_img, (self.reso, self.reso),
interpolation=cv2.INTER_LANCZOS4)
img = torch.from_numpy(img).permute(2, 0, 1) / 127.5 - 1
return (img_to_encoder, img, c, caption, ins)
def rand_sample_idx(self):
return random.randint(0, self.instance_data_length - 1)
def rand_pair(self):
return (self.rand_sample_idx() for _ in range(2))
def paired_post_process(self, sample):
# repeat n times?
all_inp_list = []
all_nv_list = []
caption, ins = sample[-2:]
# expanded_return = []
for _ in range(self.pair_per_instance):
cano_idx, nv_idx = self.rand_pair()
cano_sample = self._post_process_sample(item[cano_idx]
for item in sample[:-2])
nv_sample = self._post_process_sample(item[nv_idx]
for item in sample[:-2])
all_inp_list.extend(cano_sample)
all_nv_list.extend(nv_sample)
return (*all_inp_list, *all_nv_list, caption, ins)
# return [cano_sample, nv_sample, caption, ins]
# return (*cano_sample, *nv_sample, caption, ins)
def single_sample_create_dict(self, sample, prefix=''):
# if len(sample) == 1:
# sample = sample[0]
# assert len(sample) == 6
# img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample
img_to_encoder, img, c, caption, ins = sample
return {
# **sample,
'img_to_encoder': img_to_encoder,
'img': img,
'c': c,
'caption': caption,
'ins': ins
}
def decode_zip(self, sample_pyd, shape=(256, 256)):
if isinstance(sample_pyd, tuple):
sample_pyd = sample_pyd[0]
assert isinstance(sample_pyd, dict)
latent = sample_pyd['latent']
caption = sample_pyd['caption'].decode('utf-8')
c = sample_pyd['c']
# img = sample_pyd['img']
# st()
return latent, caption, c
def create_dict(self, sample):
return {
# **sample,
'latent': sample[0],
'caption': sample[1],
'c': sample[2],
}
# test tar loading
def load_wds_latent_ResampledShard(file_path,
batch_size,
num_workers,
reso,
reso_encoder,
test=False,
preprocess=None,
imgnet_normalize=True,
plucker_embedding=False,
decode_encode_img_only=False,
**kwargs):
# return raw_img, depth, c, bbox, sample_pyd['ins.pyd'], sample_pyd['fname.pyd']
post_process_cls = PostProcess_forlatent(
reso,
reso_encoder,
imgnet_normalize=imgnet_normalize,
plucker_embedding=plucker_embedding,
decode_encode_img_only=decode_encode_img_only,
)
if isinstance(file_path, list): # lst of shard urls
all_shards = []
for url_path in file_path:
all_shards.extend(wds.shardlists.expand_source(url_path))
logger.log('all_shards', all_shards)
else:
all_shards = file_path # to be expanded
dataset = wds.DataPipeline(
wds.ResampledShards(all_shards), # url_shard
# at this point we have an iterator over all the shards
wds.shuffle(50),
wds.split_by_worker, # if multi-node
wds.tarfile_to_samples(),
# add wds.split_by_node here if you are using multiple nodes
wds.detshuffle(
2500
), # shuffles in the memory, leverage large RAM for more efficient loading
wds.decode(wds.autodecode.basichandlers), # TODO
wds.to_tuple("sample.pyd"), # extract the pyd from top level dict
wds.map(post_process_cls.decode_zip),
# wds.map(post_process_cls._post_process_sample),
# wds.detshuffle(1000), # shuffles in the memory, leverage large RAM for more efficient loading
wds.batched(
150,
partial=True,
# collation_fn=collate
) # streaming more data at once, and rebatch later
)
loader_shard = wds.WebLoader(
dataset,
num_workers=num_workers,
drop_last=False,
batch_size=None,
shuffle=False,
persistent_workers=num_workers
> 0).unbatched().shuffle(1000).batched(batch_size).map(
post_process_cls.create_dict)
return loader_shard
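

# Minimal usage sketch (assumption: placeholder shard URL). The latent
# loader above yields dicts with only 'latent', 'caption' and 'c':
def _demo_latent_loader():
    loader = load_wds_latent_ResampledShard(
        './shards/latent-{000000..000003}.tar',
        batch_size=8, num_workers=0, reso=128, reso_encoder=224)
    return next(iter(loader))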