import os
import math
import json
import re
import random
import itertools
from pathlib import Path

import glm
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import webdataset as wds
import pytorch_lightning as pl
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler

from src.utils import obj, render_utils
from src.utils.train_util import instantiate_from_config
from src.utils.camera_util import (
    FOV_to_intrinsics,
    center_looking_at_camera_pose,
    get_circular_camera_poses,
)

# Allow OpenCV to read .exr environment maps.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"


def spherical_camera_pose(azimuths: np.ndarray, elevations: np.ndarray, radius=2.5):
    """Build camera-to-world poses looking at the origin from spherical coordinates (degrees)."""
    azimuths = np.deg2rad(azimuths)
    elevations = np.deg2rad(elevations)

    xs = radius * np.cos(elevations) * np.cos(azimuths)
    ys = radius * np.cos(elevations) * np.sin(azimuths)
    zs = radius * np.sin(elevations)

    cam_locations = np.stack([xs, ys, zs], axis=-1)
    cam_locations = torch.from_numpy(cam_locations).float()
    c2ws = center_looking_at_camera_pose(cam_locations)
    return c2ws


def find_matching_files(base_path, idx):
    """Return the first file named '<idx:03d>_<light>.png' under base_path."""
    formatted_idx = '%03d' % idx
    pattern = re.compile(r'^%s_\d+\.png$' % formatted_idx)
    matching_files = []
    if os.path.exists(base_path):
        for filename in os.listdir(base_path):
            if pattern.match(filename):
                matching_files.append(filename)
    if not matching_files:
        raise FileNotFoundError(f'no file matching view index {idx} under {base_path}')
    return os.path.join(base_path, matching_files[0])


def load_mipmap(env_path):
    """Load a pre-filtered environment map: one diffuse map and six specular mip levels."""
    diffuse_path = os.path.join(env_path, "diffuse.pth")
    diffuse = torch.load(diffuse_path, map_location=torch.device('cpu'))

    specular = []
    for i in range(6):
        specular_path = os.path.join(env_path, f"specular_{i}.pth")
        specular_tensor = torch.load(specular_path, map_location=torch.device('cpu'))
        specular.append(specular_tensor)
    return [specular, diffuse]


def convert_to_white_bg(image, write_bg=True):
    """Composite an RGBA image onto a white background (or just premultiply alpha)."""
    alpha = image[:, :, 3:]
    if write_bg:
        return image[:, :, :3] * alpha + 1. * (1 - alpha)
    else:
        return image[:, :, :3] * alpha


def load_obj(path, return_attributes=False, scale_factor=1.0):
    return obj.load_obj(path, clear_ks=True, mtl_override=None,
                        return_attributes=return_attributes, scale_factor=scale_factor)


def custom_collate_fn(batch):
    # Keep samples as a plain list; nested env/material structures are not stackable.
    return batch


def collate_fn_wrapper(batch):
    return custom_collate_fn(batch)
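
# Illustrative sketch (assumption, not part of the training pipeline): the directory
# layout `load_mipmap` expects and how its return value unpacks. The path below is a
# placeholder; real environments live under the dataset's `env_mipmap` folder.
def _load_mipmap_example(env_root='env_mipmap/1'):
    # env_root must hold diffuse.pth and specular_0.pth .. specular_5.pth.
    specular, diffuse = load_mipmap(env_root)
    assert len(specular) == 6  # six specular mip levels for roughness-dependent lookup
    return specular, diffuse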
class DataModuleFromConfig(pl.LightningDataModule):
    def __init__(
        self,
        batch_size=8,
        num_workers=4,
        train=None,
        validation=None,
        test=None,
        **kwargs,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.num_workers = num_workers

        self.dataset_configs = dict()
        if train is not None:
            self.dataset_configs['train'] = train
        if validation is not None:
            self.dataset_configs['validation'] = validation
        if test is not None:
            self.dataset_configs['test'] = test

    def setup(self, stage):
        if stage in ['fit']:
            self.datasets = dict(
                (k, instantiate_from_config(self.dataset_configs[k]))
                for k in self.dataset_configs
            )
        else:
            raise NotImplementedError

    def custom_collate_fn(self, batch):
        collated_batch = {}
        for key in batch[0].keys():
            if key in ('input_env', 'target_env'):
                # Environment mipmaps are nested lists of tensors with varying
                # shapes; keep them as lists instead of stacking.
                collated_batch[key] = [d[key] for d in batch]
            else:
                collated_batch[key] = torch.stack([d[key] for d in batch], dim=0)
        return collated_batch

    def convert_to_white_bg(self, image):
        alpha = image[:, :, 3:]
        return image[:, :, :3] * alpha + 1. * (1 - alpha)

    def load_obj(self, path):
        return obj.load_obj(path, clear_ks=True, mtl_override=None)

    def train_dataloader(self):
        sampler = DistributedSampler(self.datasets['train'])
        return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size,
                             num_workers=self.num_workers, shuffle=False,
                             sampler=sampler, collate_fn=collate_fn_wrapper)

    def val_dataloader(self):
        sampler = DistributedSampler(self.datasets['validation'])
        return wds.WebLoader(self.datasets['validation'], batch_size=1,
                             num_workers=self.num_workers, shuffle=False,
                             sampler=sampler, collate_fn=collate_fn_wrapper)

    def test_dataloader(self):
        return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size,
                             num_workers=self.num_workers, shuffle=False)
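
# Illustrative sketch (assumption): driving DataModuleFromConfig by hand. The config
# dict layout ({'target': ..., 'params': ...}) assumes the usual
# `instantiate_from_config` convention, and the target path below is hypothetical.
# Note that the train/val loaders build a DistributedSampler, so torch.distributed
# must be initialized before they are called.
def _datamodule_example():
    train_cfg = {
        'target': 'src.data.objaverse.ObjaverseData',  # hypothetical module path
        'params': {'root_dir': 'Objaverse_highQuality', 'input_view_num': [6]},
    }
    datamodule = DataModuleFromConfig(batch_size=2, num_workers=0, train=train_cfg)
    datamodule.setup('fit')
    return datamodule.train_dataloader()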
class ObjaverseData(Dataset):
    def __init__(self,
                 root_dir='Objaverse_highQuality',
                 light_dir='env_mipmap',
                 input_view_num=6,
                 target_view_num=4,
                 total_view_n=18,
                 distance=3.5,
                 fov=50,
                 camera_random=False,
                 validation=False,
                 ):
        self.root_dir = Path(root_dir)
        self.light_dir = light_dir

        # Keep only environment folders that actually contain mipmap files.
        self.all_env_name = []
        for temp_dir in os.listdir(light_dir):
            if os.listdir(os.path.join(self.light_dir, temp_dir)):
                self.all_env_name.append(temp_dir)

        self.input_view_num = input_view_num
        self.target_view_num = target_view_num
        self.total_view_n = total_view_n
        self.fov = fov
        self.camera_random = camera_random

        self.train_res = [512, 512]
        self.cam_near_far = [0.1, 1000.0]
        self.fov_rad = np.deg2rad(fov)
        self.fov_deg = fov
        self.spp = 1
        self.cam_radius = distance
        self.layers = 1

        # All material pairs on a 0.1 grid in [0, 1].
        numbers = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
        self.combinations = list(itertools.product(numbers, repeat=2))

        self.paths = os.listdir(self.root_dir)
        # with open("BJ_Mesh_list.json", 'r') as file:
        #     self.paths = json.load(file)
        print('total training object num:', len(self.paths))

        self.depth_scale = 6.0

        total_objects = len(self.paths)
        print('============= length of dataset %d =============' % total_objects)

    def __len__(self):
        return len(self.paths)

    def load_obj(self, path):
        return obj.load_obj(path, clear_ks=True, mtl_override=None)

    def sample_spherical(self, phi, theta, cam_radius):
        theta = np.deg2rad(theta)
        phi = np.deg2rad(phi)

        z = cam_radius * np.cos(phi) * np.sin(theta)
        x = cam_radius * np.sin(phi) * np.sin(theta)
        y = cam_radius * np.cos(theta)
        return x, y, z

    def _random_scene(self, cam_radius, fov_rad):
        iter_res = self.train_res
        proj_mtx = render_utils.perspective(fov_rad, iter_res[1] / iter_res[0],
                                            self.cam_near_far[0], self.cam_near_far[1])

        # Sample a random camera on the sphere and build its view matrices.
        azimuths = random.uniform(0, 360)
        elevations = random.uniform(30, 150)
        mv_embedding = spherical_camera_pose(azimuths, 90 - elevations, cam_radius)
        x, y, z = self.sample_spherical(azimuths, elevations, cam_radius)
        eye = glm.vec3(x, y, z)
        at = glm.vec3(0.0, 0.0, 0.0)
        up = glm.vec3(0.0, 1.0, 0.0)
        view_matrix = glm.lookAt(eye, at, up)
        mv = torch.from_numpy(np.array(view_matrix))  # world-to-camera
        mvp = proj_mtx @ mv  # world-to-clip
        campos = torch.linalg.inv(mv)[:3, 3]
        # Prepend a batch dimension to every per-view tensor.
        return mv[None, ...], mvp[None, ...], campos[None, ...], mv_embedding[None, ...], iter_res, self.spp

    def load_im(self, path, color):
        '''
        Replace background pixels with the given color.
        '''
        pil_img = Image.open(path)

        image = np.asarray(pil_img, dtype=np.float32) / 255.
        alpha = image[:, :, 3:]
        image = image[:, :, :3] * alpha + color * (1 - alpha)

        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
        return image, alpha

    def load_albedo(self, path, color, mask):
        '''
        Replace background pixels with white using the given mask.
        '''
        pil_img = Image.open(path)

        image = np.asarray(pil_img, dtype=np.float32) / 255.
        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()

        color = torch.ones_like(image)
        image = image * mask + color * (1 - mask)
        return image

    def convert_to_white_bg(self, image):
        alpha = image[:, :, 3:]
        return image[:, :, :3] * alpha + 1. * (1 - alpha)

    def calculate_fov(self, initial_distance, initial_fov, new_distance):
        # Keep the object's apparent size constant: scale the FOV so the frustum
        # height at the object matches the one at the initial distance.
        initial_fov_rad = math.radians(initial_fov)
        height = 2 * initial_distance * math.tan(initial_fov_rad / 2)
        new_fov_rad = 2 * math.atan(height / (2 * new_distance))
        new_fov = math.degrees(new_fov_rad)
        return new_fov

    def __getitem__(self, index):
        obj_path = os.path.join(self.root_dir, self.paths[index])
        mesh_attributes = torch.load(obj_path, map_location=torch.device('cpu'))

        pose_list = []
        env_list = []
        material_list = []
        camera_pos = []
        c2w_list = []
        camera_embedding_list = []

        # With probability 0.5 each, resample the environment / material per view.
        random_env = random.random() > 0.5
        random_mr = random.random() > 0.5

        selected_env = random.randint(0, len(self.all_env_name) - 1)
        materials = random.choice(self.combinations)

        if self.camera_random:
            # Perturb the camera distance and compensate the FOV so the object
            # keeps roughly the same size in the image.
            random_perturbation = random.uniform(-1.5, 1.5)
            cam_radius = self.cam_radius + random_perturbation
            fov_deg = self.calculate_fov(initial_distance=self.cam_radius,
                                         initial_fov=self.fov_deg,
                                         new_distance=cam_radius)
            fov_rad = np.deg2rad(fov_deg)
        else:
            cam_radius = self.cam_radius
            fov_rad = self.fov_rad
            fov_deg = self.fov_deg

        # `input_view_num` may be a single int or a list of candidate counts.
        if isinstance(self.input_view_num, (list, tuple)):
            input_view_num = random.choice(self.input_view_num)
        else:
            input_view_num = self.input_view_num

        for _ in range(input_view_num + self.target_view_num):
            mv, mvp, campos, mv_embedding, iter_res, iter_spp = self._random_scene(cam_radius, fov_rad)
            if random_env:
                selected_env = random.randint(0, len(self.all_env_name) - 1)
            env_path = os.path.join(self.light_dir, self.all_env_name[selected_env])
            env = load_mipmap(env_path)
            if random_mr:
                materials = random.choice(self.combinations)
            pose_list.append(mvp)
            camera_pos.append(campos)
            c2w_list.append(mv)
            env_list.append(env)
            material_list.append(materials)
            camera_embedding_list.append(mv_embedding)

        data = {
            'mesh_attributes': mesh_attributes,
            'input_view_num': input_view_num,
            'target_view_num': self.target_view_num,
            'obj_path': obj_path,
            'pose_list': pose_list,
            'camera_pos': camera_pos,
            'c2w_list': c2w_list,
            'env_list': env_list,
            'material_list': material_list,
            'camera_embedding_list': camera_embedding_list,
            'fov_deg': fov_deg,
            'radius': cam_radius,
        }
        return data
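
# Illustrative sketch (assumption): unpacking one ObjaverseData sample. Shapes follow
# from `_random_scene`, which prepends a batch dimension to every per-view tensor;
# reading the material pair as (metallic, roughness) is an assumption based on how
# the validation data is organized.
def _objaverse_sample_example(dataset: ObjaverseData):
    data = dataset[0]
    mvp = data['pose_list'][0]                      # (1, 4, 4) world-to-clip matrix
    campos = data['camera_pos'][0]                  # (1, 3) camera position
    specular, diffuse = data['env_list'][0]         # 6 specular mips + diffuse map
    metallic, roughness = data['material_list'][0]  # scalars from the 0.1 grid
    return mvp, campos, specular, diffuse, metallic, roughness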
class ValidationData(Dataset):
    def __init__(self,
                 root_dir='objaverse/',
                 input_view_num=6,
                 input_image_size=320,
                 fov=30,
                 ):
        self.root_dir = Path(root_dir)
        self.input_view_num = input_view_num
        self.input_image_size = input_image_size
        self.fov = fov
        self.light_dir = 'env_mipmap'

        # with open('Mesh_list.json') as f:
        #     filtered_dict = json.load(f)
        self.paths = os.listdir(self.root_dir)
        # self.paths = filtered_dict
        print('============= length of dataset %d =============' % len(self.paths))

        # Six fixed input views on a sphere, alternating between two elevations.
        cam_distance = 4.0
        azimuths = np.array([30, 90, 150, 210, 270, 330])
        elevations = np.array([20, -10, 20, -10, 20, -10])
        azimuths = np.deg2rad(azimuths)
        elevations = np.deg2rad(elevations)

        x = cam_distance * np.cos(elevations) * np.cos(azimuths)
        y = cam_distance * np.cos(elevations) * np.sin(azimuths)
        z = cam_distance * np.sin(elevations)

        cam_locations = np.stack([x, y, z], axis=-1)
        cam_locations = torch.from_numpy(cam_locations).float()
        c2ws = center_looking_at_camera_pose(cam_locations)
        self.c2ws = c2ws.float()
        self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float()

        # A circular trajectory of novel views for rendering.
        render_c2ws = get_circular_camera_poses(M=8, radius=cam_distance, elevation=20.0)
        render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
        self.render_c2ws = render_c2ws.float()
        self.render_Ks = render_Ks.float()

    def __len__(self):
        return len(self.paths)

    def load_im(self, path, color):
        '''
        Replace background pixels with the given color.
        '''
        pil_img = Image.open(path)
        pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC)

        image = np.asarray(pil_img, dtype=np.float32) / 255.
        if image.shape[-1] == 4:
            alpha = image[:, :, 3:]
            image = image[:, :, :3] * alpha + color * (1 - alpha)
        else:
            alpha = np.ones_like(image[:, :, :1])

        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
        return image, alpha

    def load_mat(self, path, color):
        '''
        Replace background pixels with the given color.
        '''
        pil_img = Image.open(path)
        pil_img = pil_img.resize((384, 384), resample=Image.BICUBIC)

        image = np.asarray(pil_img, dtype=np.float32) / 255.
        if image.shape[-1] == 4:
            alpha = image[:, :, 3:]
            image = image[:, :, :3] * alpha + color * (1 - alpha)
        else:
            alpha = np.ones_like(image[:, :, :1])

        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
        return image, alpha

    def load_albedo(self, path, color, mask):
        '''
        Replace background pixels with white using the given mask.
        '''
        pil_img = Image.open(path)
        pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC)

        image = np.asarray(pil_img, dtype=np.float32) / 255.
        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()

        color = torch.ones_like(image)
        image = image * mask + color * (1 - mask)
        return image

    def __getitem__(self, index):
        input_image_path = os.path.join(self.root_dir, self.paths[index])

        # Background color, default: white.
        bkg_color = [1.0, 1.0, 1.0]

        image_list = []
        albedo_list = []
        alpha_list = []
        specular_list = []
        diffuse_list = []
        metallic_list = []
        roughness_list = []

        # Collect subfolders named '<metallic>_<roughness>' (two floats joined by
        # '_'), skipping the 'specular'/'diffuse' mipmap folders.
        exist_comb_list = []
        for subfolder in os.listdir(input_image_path):
            found_numeric_subfolder = False
            subfolder_path = os.path.join(input_image_path, subfolder)
            if os.path.isdir(subfolder_path) and '_' in subfolder and 'specular' not in subfolder and 'diffuse' not in subfolder:
                try:
                    parts = subfolder.split('_')
                    # Try to parse the strings on both sides of the separator as floats.
                    float(parts[0])
                    float(parts[1])
                    found_numeric_subfolder = True
                except ValueError:
                    continue
            if found_numeric_subfolder:
                exist_comb_list.append(subfolder)
        selected_one_comb = random.choice(exist_comb_list)

        for idx in range(self.input_view_num):
            img_path = find_matching_files(os.path.join(input_image_path, selected_one_comb, 'rgb'), idx)
            albedo_path = img_path.replace('rgb', 'albedo')
            metallic_path = img_path.replace('rgb', 'metallic')
            roughness_path = img_path.replace('rgb', 'roughness')

            image, alpha = self.load_im(img_path, bkg_color)
            albedo = self.load_albedo(albedo_path, bkg_color, alpha)
            metallic, _ = self.load_mat(metallic_path, bkg_color)
            roughness, _ = self.load_mat(roughness_path, bkg_color)

            # The light index is encoded in the filename: '<view>_<light>.png'.
            light_num = os.path.basename(img_path).split('_')[1].split('.')[0]
            light_path = os.path.join(self.light_dir, str(int(light_num) + 1))
            specular, diffuse = load_mipmap(light_path)

            image_list.append(image)
            alpha_list.append(alpha)
            albedo_list.append(albedo)
            metallic_list.append(metallic)
            roughness_list.append(roughness)
            specular_list.append(specular)
            diffuse_list.append(diffuse)

        images = torch.stack(image_list, dim=0).float()
        alphas = torch.stack(alpha_list, dim=0).float()
        albedo = torch.stack(albedo_list, dim=0).float()
        metallic = torch.stack(metallic_list, dim=0).float()
        roughness = torch.stack(roughness_list, dim=0).float()

        data = {
            'input_images': images,
            'input_alphas': alphas,
            'input_c2ws': self.c2ws,
            'input_Ks': self.Ks,
            'input_albedos': albedo[:self.input_view_num],
            'input_metallics': metallic[:self.input_view_num],
            'input_roughness': roughness[:self.input_view_num],
            'specular': specular_list[:self.input_view_num],
            'diffuse': diffuse_list[:self.input_view_num],
            'render_c2ws': self.render_c2ws,
            'render_Ks': self.render_Ks,
        }
        return data
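
# Illustrative sketch (assumption): the on-disk layout ValidationData expects, as
# inferred from `__getitem__` and `find_matching_files`; the names below are
# placeholders, not real files.
#
#   objaverse/<object_uid>/<metallic>_<roughness>/rgb/000_<light>.png
#                                                /albedo/000_<light>.png
#                                                /metallic/000_<light>.png
#                                                /roughness/000_<light>.png
#
# with the matching environment mipmaps under 'env_mipmap/<light + 1>/'.
def _validation_item_example():
    dataset = ValidationData(root_dir='objaverse/')
    item = dataset[0]
    return item['input_images'].shape  # (input_view_num, 3, 320, 320)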
if __name__ == '__main__':
    # Quick smoke test: build the training dataset and fetch one sample.
    dataset = ObjaverseData()
    data = dataset[0]
    print({k: type(v) for k, v in data.items()})