|
import os, sys |
|
import math |
|
import json |
|
import glm |
|
from pathlib import Path |
|
|
|
import random |
|
import numpy as np |
|
from PIL import Image |
|
import webdataset as wds |
|
import pytorch_lightning as pl |
|
import sys |
|
from src.utils import obj, render_utils |
|
import torch |
|
import torch.nn.functional as F |
|
from torch.utils.data import Dataset |
|
from torch.utils.data.distributed import DistributedSampler |
|
import random |
|
import itertools |
|
from src.utils.train_util import instantiate_from_config |
|
from src.utils.camera_util import ( |
|
FOV_to_intrinsics, |
|
center_looking_at_camera_pose, |
|
get_circular_camera_poses, |
|
) |
|
os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" |
|
import re |
|
|
|
def spherical_camera_pose(azimuths: np.ndarray, elevations: np.ndarray, radius=2.5):
    """Build camera-to-world poses for cameras on a sphere, looking at the origin.

    Angles are given in degrees; ``radius`` is the distance from the origin.
    Returns the c2w matrices produced by ``center_looking_at_camera_pose``.
    """
    az = np.deg2rad(azimuths)
    el = np.deg2rad(elevations)

    # Spherical -> Cartesian (z is the "up" component of the elevation here).
    cos_el = np.cos(el)
    locations = np.stack(
        [
            radius * cos_el * np.cos(az),
            radius * cos_el * np.sin(az),
            radius * np.sin(el),
        ],
        axis=-1,
    )

    eye_positions = torch.from_numpy(locations).float()
    return center_looking_at_camera_pose(eye_positions)
|
|
|
def find_matching_files(base_path, idx):
    """Return the path of one rendered view ``idx`` inside ``base_path``.

    Files are expected to be named ``<idx:03d>_<n>.png`` (``n`` is the light
    index).  When several lights exist for the same view, the lexicographically
    smallest filename is returned so the choice is deterministic (os.listdir
    order is arbitrary).

    Raises:
        FileNotFoundError: if ``base_path`` does not exist or no file matches.
            (The original code crashed with an opaque ``IndexError`` here.)
    """
    formatted_idx = '%03d' % idx
    pattern = re.compile(r'^%s_\d+\.png$' % formatted_idx)

    if os.path.exists(base_path):
        matching_files = sorted(
            filename for filename in os.listdir(base_path)
            if pattern.match(filename)
        )
        if matching_files:
            return os.path.join(base_path, matching_files[0])

    raise FileNotFoundError(
        "no file matching '%s_<n>.png' found in %r" % (formatted_idx, base_path)
    )
|
|
|
def load_mipmap(env_path):
    """Load a pre-filtered environment map from ``env_path``.

    The folder is expected to contain ``diffuse.pth`` plus six specular mip
    levels ``specular_0.pth`` .. ``specular_5.pth``.  Everything is loaded onto
    the CPU.  Returns ``[specular_levels, diffuse]``.
    """
    cpu = torch.device('cpu')
    diffuse = torch.load(os.path.join(env_path, "diffuse.pth"), map_location=cpu)
    specular = [
        torch.load(os.path.join(env_path, f"specular_{level}.pth"), map_location=cpu)
        for level in range(6)
    ]
    return [specular, diffuse]
|
|
|
def convert_to_white_bg(image, write_bg=True):
    """Composite an (H, W, 4) RGBA image onto a background.

    With ``write_bg=True`` the background is white; otherwise the RGB channels
    are simply premultiplied by alpha (black background).
    """
    rgb = image[:, :, :3]
    alpha = image[:, :, 3:]
    premultiplied = rgb * alpha
    if not write_bg:
        return premultiplied
    return premultiplied + (1.0 - alpha)
|
|
|
def load_obj(path, return_attributes=False, scale_factor=1.0):
    """Thin wrapper over the project's obj loader.

    Always clears Ks and disables MTL override; forwards the remaining options
    unchanged.
    """
    return obj.load_obj(
        path,
        clear_ks=True,
        mtl_override=None,
        return_attributes=return_attributes,
        scale_factor=scale_factor,
    )
|
|
|
def custom_collate_fn(batch):
    """Identity collate: hand the list of samples through untouched.

    Used so the DataLoader does not try to stack heterogeneous per-sample
    structures into tensors.
    """
    return batch
|
|
|
|
|
def collate_fn_wrapper(batch):
    """Module-level collate hook forwarding to ``custom_collate_fn``.

    Kept as a named top-level function so it stays picklable for DataLoader
    worker processes.
    """
    return custom_collate_fn(batch)
|
|
|
class DataModuleFromConfig(pl.LightningDataModule):
    """Lightning data module that builds datasets from config dictionaries.

    Dataset configs passed to ``__init__`` are instantiated lazily in
    ``setup`` via ``instantiate_from_config``; the dataloaders wrap them in
    ``webdataset.WebLoader`` with a ``DistributedSampler``.
    """

    def __init__(
        self,
        batch_size=8,
        num_workers=4,
        train=None,       # config for the training dataset (optional)
        validation=None,  # config for the validation dataset (optional)
        test=None,        # config for the test dataset (optional)
        **kwargs,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.num_workers = num_workers

        # Keep only the dataset configs that were actually provided.
        self.dataset_configs = dict()
        if train is not None:
            self.dataset_configs['train'] = train
        if validation is not None:
            self.dataset_configs['validation'] = validation
        if test is not None:
            self.dataset_configs['test'] = test

    def setup(self, stage):
        # NOTE(review): only the 'fit' stage is handled; calling setup with
        # 'test'/'validate'/'predict' raises even though test_dataloader
        # exists below — confirm whether other stages should be supported.
        if stage in ['fit']:
            self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
        else:
            raise NotImplementedError

    def custom_collate_fn(self, batch):
        """Collate dict samples: stack tensor fields, keep env fields as lists.

        NOTE(review): this method appears unused — the dataloaders below pass
        the module-level ``collate_fn_wrapper`` instead.
        """
        collated_batch = {}
        for key in batch[0].keys():
            if key == 'input_env' or key == 'target_env':
                # Environment maps are nested lists of tensors and cannot be
                # stacked, so they are passed through as a plain list.
                collated_batch[key] = [d[key] for d in batch]
            else:
                collated_batch[key] = torch.stack([d[key] for d in batch], dim=0)
        return collated_batch

    def convert_to_white_bg(self, image):
        """Composite an (H, W, 4) RGBA image onto a white background."""
        alpha = image[:, :, 3:]
        return image[:, :, :3] * alpha + 1. * (1 - alpha)

    def load_obj(self, path):
        """Load a mesh via the project's obj loader (Ks cleared, no MTL override)."""
        return obj.load_obj(path, clear_ks=True, mtl_override=None)

    def train_dataloader(self):
        # NOTE(review): DistributedSampler assumes torch.distributed is
        # initialized — confirm single-GPU/debug runs are not expected here.
        sampler = DistributedSampler(self.datasets['train'])
        return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler, collate_fn=collate_fn_wrapper)

    def val_dataloader(self):
        # Validation always runs with batch_size=1.
        sampler = DistributedSampler(self.datasets['validation'])
        return wds.WebLoader(self.datasets['validation'], batch_size=1, num_workers=self.num_workers, shuffle=False, sampler=sampler, collate_fn=collate_fn_wrapper)

    def test_dataloader(self):
        # NOTE(review): setup() only populates self.datasets for the 'fit'
        # stage, so 'test' will KeyError unless setup is extended — verify.
        return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
|
|
|
|
|
class ObjaverseData(Dataset):
    """Training dataset: Objaverse meshes with randomized cameras/lights/materials.

    Each item loads the raw mesh attributes from disk and samples
    ``input_view_num + target_view_num`` random cameras; for every view it
    records the MVP matrix, camera position, environment mipmaps and one
    (metallic, roughness) material pair, so actual rendering can happen
    downstream.
    """

    def __init__(self,
                 root_dir='Objaverse_highQuality',
                 light_dir='env_mipmap',
                 input_view_num=6,
                 target_view_num=4,
                 total_view_n=18,
                 distance=3.5,
                 fov=50,
                 camera_random=False,
                 validation=False,
                 ):
        self.root_dir = Path(root_dir)
        self.light_dir = light_dir

        # Keep only environment folders that actually contain mipmap files.
        self.all_env_name = []
        for temp_dir in os.listdir(light_dir):
            if os.listdir(os.path.join(self.light_dir, temp_dir)):
                self.all_env_name.append(temp_dir)

        # input_view_num may be a single int or a list/tuple of candidate
        # counts (one is sampled per item in __getitem__).
        self.input_view_num = input_view_num
        self.target_view_num = target_view_num
        self.total_view_n = total_view_n
        self.fov = fov
        self.camera_random = camera_random

        self.train_res = [512, 512]          # render resolution (H, W)
        self.cam_near_far = [0.1, 1000.0]    # near/far clip planes
        self.fov_rad = np.deg2rad(fov)
        self.fov_deg = fov
        self.spp = 1                         # samples per pixel
        self.cam_radius = distance
        self.layers = 1

        # All (metallic, roughness) pairs on a 0.0..1.0 grid with step 0.1.
        numbers = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
        self.combinations = list(itertools.product(numbers, repeat=2))

        self.paths = os.listdir(self.root_dir)

        print('total training object num:', len(self.paths))

        self.depth_scale = 6.0

        total_objects = len(self.paths)
        print('============= length of dataset %d =============' % total_objects)

    def __len__(self):
        return len(self.paths)

    def load_obj(self, path):
        """Load a mesh via the project's obj loader (Ks cleared, no MTL override)."""
        return obj.load_obj(path, clear_ks=True, mtl_override=None)

    def sample_spherical(self, phi, theta, cam_radius):
        """Convert spherical angles (degrees) to a Cartesian camera position.

        ``phi`` is the azimuth and ``theta`` the polar angle measured from the
        +y axis (y-up, matching the glm.lookAt call in _random_scene).
        """
        theta = np.deg2rad(theta)
        phi = np.deg2rad(phi)

        z = cam_radius * np.cos(phi) * np.sin(theta)
        x = cam_radius * np.sin(phi) * np.sin(theta)
        y = cam_radius * np.cos(theta)

        return x, y, z

    def _random_scene(self, cam_radius, fov_rad):
        """Sample one random camera.

        Returns (mv, mvp, campos, mv_embedding, iter_res, spp), with the
        matrix tensors carrying a leading batch dimension of 1.
        """
        iter_res = self.train_res
        proj_mtx = render_utils.perspective(fov_rad, iter_res[1] / iter_res[0], self.cam_near_far[0], self.cam_near_far[1])

        # Azimuth over the full circle; polar angle kept away from the poles.
        azimuths = random.uniform(0, 360)
        elevations = random.uniform(30, 150)
        # The embedding pose uses elevation measured from the equator (90 - theta).
        mv_embedding = spherical_camera_pose(azimuths, 90 - elevations, cam_radius)
        x, y, z = self.sample_spherical(azimuths, elevations, cam_radius)
        eye = glm.vec3(x, y, z)
        at = glm.vec3(0.0, 0.0, 0.0)
        up = glm.vec3(0.0, 1.0, 0.0)
        view_matrix = glm.lookAt(eye, at, up)
        mv = torch.from_numpy(np.array(view_matrix))
        mvp = proj_mtx @ (mv)  # model-view-projection
        # Camera position in world space = translation column of mv^-1.
        campos = torch.linalg.inv(mv)[:3, 3]
        return mv[None, ...], mvp[None, ...], campos[None, ...], mv_embedding[None, ...], iter_res, self.spp

    def load_im(self, path, color):
        '''
        Load an RGBA image and composite it onto the given background color.
        Returns (image, alpha) as float CHW tensors.
        '''
        pil_img = Image.open(path)

        image = np.asarray(pil_img, dtype=np.float32) / 255.
        alpha = image[:, :, 3:]
        image = image[:, :, :3] * alpha + color * (1 - alpha)

        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
        return image, alpha

    def load_albedo(self, path, color, mask):
        '''
        Load an albedo map and replace masked-out pixels with white.
        NOTE: the ``color`` argument is ignored — it is overwritten with ones.
        '''
        pil_img = Image.open(path)

        image = np.asarray(pil_img, dtype=np.float32) / 255.
        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()

        color = torch.ones_like(image)
        image = image * mask + color * (1 - mask)
        return image

    def convert_to_white_bg(self, image):
        """Composite an (H, W, 4) RGBA image onto a white background."""
        alpha = image[:, :, 3:]
        return image[:, :, :3] * alpha + 1. * (1 - alpha)

    def calculate_fov(self, initial_distance, initial_fov, new_distance):
        """Return the FOV (degrees) that keeps the visible frustum height
        constant when the camera moves to ``new_distance``."""
        initial_fov_rad = math.radians(initial_fov)

        # Frustum height at the initial distance.
        height = 2 * initial_distance * math.tan(initial_fov_rad / 2)

        new_fov_rad = 2 * math.atan(height / (2 * new_distance))

        new_fov = math.degrees(new_fov_rad)

        return new_fov

    def __getitem__(self, index):
        """Load one mesh and sample cameras/lights/materials for all views."""
        obj_path = os.path.join(self.root_dir, self.paths[index])
        mesh_attributes = torch.load(obj_path, map_location=torch.device('cpu'))

        pose_list = []
        env_list = []
        material_list = []
        camera_pos = []
        c2w_list = []
        camera_embedding_list = []

        # With probability 0.5 each: resample the environment / the material
        # per view instead of sharing one across all views of this item.
        random_env = random.random() > 0.5
        random_mr = random.random() > 0.5

        selected_env = random.randint(0, len(self.all_env_name) - 1)
        materials = random.choice(self.combinations)

        if self.camera_random:
            # Perturb the camera distance and compensate the FOV so the
            # object keeps roughly the same size in frame.
            random_perturbation = random.uniform(-1.5, 1.5)
            cam_radius = self.cam_radius + random_perturbation
            fov_deg = self.calculate_fov(initial_distance=self.cam_radius, initial_fov=self.fov_deg, new_distance=cam_radius)
            fov_rad = np.deg2rad(fov_deg)
        else:
            cam_radius = self.cam_radius
            fov_rad = self.fov_rad
            fov_deg = self.fov_deg

        # BUGFIX: input_view_num may be a plain int (the default, 6) or a
        # list/tuple of candidate counts.  The original code called
        # len(self.input_view_num) unconditionally, which raises TypeError
        # for the int case.
        if isinstance(self.input_view_num, (list, tuple)) and len(self.input_view_num) >= 1:
            input_view_num = random.choice(self.input_view_num)
        else:
            input_view_num = self.input_view_num

        for _ in range(input_view_num + self.target_view_num):
            mv, mvp, campos, mv_embedding, iter_res, iter_spp = self._random_scene(cam_radius, fov_rad)
            if random_env:
                selected_env = random.randint(0, len(self.all_env_name) - 1)
            env_path = os.path.join(self.light_dir, self.all_env_name[selected_env])
            env = load_mipmap(env_path)
            if random_mr:
                materials = random.choice(self.combinations)
            pose_list.append(mvp)
            camera_pos.append(campos)
            c2w_list.append(mv)
            env_list.append(env)
            material_list.append(materials)
            camera_embedding_list.append(mv_embedding)

        data = {
            'mesh_attributes': mesh_attributes,
            'input_view_num': input_view_num,
            'target_view_num': self.target_view_num,
            'obj_path': obj_path,
            'pose_list': pose_list,
            'camera_pos': camera_pos,
            'c2w_list': c2w_list,
            'env_list': env_list,
            'material_list': material_list,
            'camera_embedding_list': camera_embedding_list,
            'fov_deg': fov_deg,
            # NOTE: key intentionally kept misspelled ('raduis') for
            # backward compatibility with existing consumers.
            'raduis': cam_radius,
        }

        return data
|
|
|
class ValidationData(Dataset):
    """Validation dataset with a fixed ring of input cameras.

    Each item picks one (metallic, roughness) render folder for an object,
    loads ``input_view_num`` pre-rendered views (rgb / albedo / metallic /
    roughness) plus the matching environment mipmaps, and returns them with
    fixed input camera poses and a circular set of novel-view render poses.
    """

    def __init__(self,
                 root_dir='objaverse/',
                 input_view_num=6,
                 input_image_size=320,
                 fov=30,
                 ):
        self.root_dir = Path(root_dir)
        self.input_view_num = input_view_num
        self.input_image_size = input_image_size
        self.fov = fov
        # Pre-filtered environment mipmaps live in numbered subfolders here.
        self.light_dir = 'env_mipmap'

        self.paths = os.listdir(self.root_dir)

        print('============= length of dataset %d =============' % len(self.paths))

        # Fixed 6 input cameras on a radius-4 ring, alternating elevations.
        cam_distance = 4.0
        azimuths = np.array([30, 90, 150, 210, 270, 330])
        elevations = np.array([20, -10, 20, -10, 20, -10])
        azimuths = np.deg2rad(azimuths)
        elevations = np.deg2rad(elevations)

        x = cam_distance * np.cos(elevations) * np.cos(azimuths)
        y = cam_distance * np.cos(elevations) * np.sin(azimuths)
        z = cam_distance * np.sin(elevations)

        cam_locations = np.stack([x, y, z], axis=-1)
        cam_locations = torch.from_numpy(cam_locations).float()
        c2ws = center_looking_at_camera_pose(cam_locations)
        self.c2ws = c2ws.float()
        self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float()

        # Novel-view render cameras: circular path of 8 poses at 20 deg elevation.
        render_c2ws = get_circular_camera_poses(M=8, radius=cam_distance, elevation=20.0)
        render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
        self.render_c2ws = render_c2ws.float()
        self.render_Ks = render_Ks.float()

    def __len__(self):
        return len(self.paths)

    def load_im(self, path, color):
        '''
        Load a view image, resize it to the input resolution, and composite
        RGBA images onto the given background color.
        Returns (image, alpha) as float CHW tensors.
        '''
        pil_img = Image.open(path)
        pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC)

        image = np.asarray(pil_img, dtype=np.float32) / 255.
        if image.shape[-1] == 4:
            # RGBA: composite onto the background color.
            alpha = image[:, :, 3:]
            image = image[:, :, :3] * alpha + color * (1 - alpha)
        else:
            # No alpha channel: treat the whole image as opaque.
            alpha = np.ones_like(image[:, :, :1])

        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
        return image, alpha

    def load_mat(self, path, color):
        '''
        Load a material map (metallic / roughness) at a fixed 384x384
        resolution and composite RGBA inputs onto the background color.
        Returns (image, alpha) as float CHW tensors.
        '''
        pil_img = Image.open(path)
        pil_img = pil_img.resize((384,384), resample=Image.BICUBIC)

        image = np.asarray(pil_img, dtype=np.float32) / 255.
        if image.shape[-1] == 4:
            alpha = image[:, :, 3:]
            image = image[:, :, :3] * alpha + color * (1 - alpha)
        else:
            alpha = np.ones_like(image[:, :, :1])

        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
        return image, alpha

    def load_albedo(self, path, color, mask):
        '''
        Load an albedo map and replace masked-out pixels with white.
        NOTE(review): the ``color`` argument is ignored — it is overwritten
        with ones below; confirm whether a configurable background was intended.
        '''
        pil_img = Image.open(path)
        pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC)

        image = np.asarray(pil_img, dtype=np.float32) / 255.
        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()

        color = torch.ones_like(image)
        image = image * mask + color * (1 - mask)
        return image

    def __getitem__(self, index):
        input_image_path = os.path.join(self.root_dir, self.paths[index])

        '''background color, default: white'''
        bkg_color = [1.0, 1.0, 1.0]

        image_list = []
        albedo_list = []
        alpha_list = []
        specular_list = []
        diffuse_list = []
        metallic_list = []
        roughness_list = []

        # Collect material subfolders named "<metallic>_<roughness>" (two
        # numeric parts), skipping 'specular'/'diffuse' and anything that
        # fails to parse as numbers.
        exist_comb_list = []
        for subfolder in os.listdir(input_image_path):
            found_numeric_subfolder=False
            subfolder_path = os.path.join(input_image_path, subfolder)
            if os.path.isdir(subfolder_path) and '_' in subfolder and 'specular' not in subfolder and 'diffuse' not in subfolder:
                try:
                    parts = subfolder.split('_')
                    float(parts[0])
                    float(parts[1])
                    found_numeric_subfolder = True
                except ValueError:
                    continue
            if found_numeric_subfolder:
                exist_comb_list.append(subfolder)

        # NOTE(review): random.choice raises IndexError if no material
        # subfolder was found — assumes the data layout guarantees at least one.
        selected_one_comb = random.choice(exist_comb_list)

        for idx in range(self.input_view_num):
            # Renders live in sibling folders named rgb/albedo/metallic/roughness
            # with identical file names.
            img_path = find_matching_files(os.path.join(input_image_path, selected_one_comb, 'rgb'), idx)
            albedo_path = img_path.replace('rgb', 'albedo')
            metallic_path = img_path.replace('rgb', 'metallic')
            roughness_path = img_path.replace('rgb', 'roughness')

            image, alpha = self.load_im(img_path, bkg_color)
            albedo = self.load_albedo(albedo_path, bkg_color, alpha)
            metallic,_ = self.load_mat(metallic_path, bkg_color)
            roughness,_ = self.load_mat(roughness_path, bkg_color)

            # File name is "<view>_<light>.png"; env folders are 1-based,
            # hence the +1 offset.
            light_num = os.path.basename(img_path).split('_')[1].split('.')[0]
            light_path = os.path.join(self.light_dir, str(int(light_num)+1))

            specular, diffuse = load_mipmap(light_path)

            image_list.append(image)
            alpha_list.append(alpha)
            albedo_list.append(albedo)
            metallic_list.append(metallic)
            roughness_list.append(roughness)
            specular_list.append(specular)
            diffuse_list.append(diffuse)

        # Stack per-view tensors into (V, C, H, W).
        images = torch.stack(image_list, dim=0).float()
        alphas = torch.stack(alpha_list, dim=0).float()
        albedo = torch.stack(albedo_list, dim=0).float()
        metallic = torch.stack(metallic_list, dim=0).float()
        roughness = torch.stack(roughness_list, dim=0).float()

        data = {
            'input_images': images,
            'input_alphas': alphas,
            'input_c2ws': self.c2ws,
            'input_Ks': self.Ks,

            'input_albedos': albedo[:self.input_view_num],
            'input_metallics': metallic[:self.input_view_num],
            'input_roughness': roughness[:self.input_view_num],

            # Environment mipmaps stay as nested lists (not stackable).
            'specular': specular_list[:self.input_view_num],
            'diffuse': diffuse_list[:self.input_view_num],

            'render_c2ws': self.render_c2ws,
            'render_Ks': self.render_Ks,
        }
        return data
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build the training dataset and pull one sample.
    # BUGFIX: the previous code called dataset.new(1), a method that does not
    # exist on ObjaverseData and raised AttributeError.
    dataset = ObjaverseData()
    sample = dataset[0]
    print('sample keys:', sorted(sample.keys()))
|
|