import os
import sys
sys.path.append("..")
from configs.config_utils import CONFIG
from models import get_model
import torch
import numpy as np
import open3d as o3d
import timm
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from simple_dataset import InTheWild_Dataset, classname_remap, classname_map

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

import mcubes
import trimesh
from torch.utils.data import DataLoader


def image_transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073),
                  (0.26862954, 0.26130258, 0.27577711)),  # CLIP image statistics
    ])


MAX_IMG_LENGTH = 5  # take up to 5 images as inputs

ae_paths = {
    "chair": "../checkpoint/ae/chair/best-checkpoint.pth",
    "table": "../checkpoint/ae/table/best-checkpoint.pth",
    "cabinet": "../checkpoint/ae/cabinet/best-checkpoint.pth",
    "shelf": "../checkpoint/ae/shelf/best-checkpoint.pth",
    "sofa": "../checkpoint/ae/sofa/best-checkpoint.pth",
    "bed": "../checkpoint/ae/bed/best-checkpoint.pth",
}
dm_paths = {
    "chair": "../checkpoint/finetune_dm/chair/best-checkpoint.pth",
    "table": "../checkpoint/finetune_dm/table/best-checkpoint.pth",
    "cabinet": "../checkpoint/finetune_dm/cabinet/best-checkpoint.pth",
    "shelf": "../checkpoint/finetune_dm/shelf/best-checkpoint.pth",
    "sofa": "../checkpoint/finetune_dm/sofa/best-checkpoint.pth",
    "bed": "../checkpoint/finetune_dm/bed/best-checkpoint.pth",
}


def inference(ae_model, dm_model, data_batch, device, reso=256):
    '''build a (reso+1)^3 query grid spanning [-1.1, 1.1]^3'''
    density = reso
    gap = 2.2 / density
    x = np.linspace(-1.1, 1.1, density + 1)
    y = np.linspace(-1.1, 1.1, density + 1)
    z = np.linspace(-1.1, 1.1, density + 1)
    xv, yv, zv = np.meshgrid(x, y, z, indexing='ij')
    grid = torch.from_numpy(np.stack([xv, yv, zv]).astype(np.float32)) \
        .view(3, -1).transpose(0, 1)[None].to(device, non_blocking=True)

    with torch.no_grad():
        # sample triplane latents from the diffusion model, conditioned on the input images
        sample_input = dm_model.prepare_sample_data(data_batch)
        sampled_array = dm_model.sample(sample_input, num_steps=36).float()
        # upsample the sampled triplanes to the decoder's working resolution
        sampled_array = torch.nn.functional.interpolate(sampled_array, scale_factor=2, mode="bilinear")

    model_ids = data_batch['model_id']
    tran_mats = data_batch['tran_mat']
    output_meshes = {}
    # decode occupancy in chunks of 128^3 query points to bound GPU memory
    grid_list = torch.split(grid, 128 ** 3, dim=1)
    for j in range(sampled_array.shape[0]):
        output_list = []
        with torch.no_grad():
            for sub_grid in grid_list:
                output_list.append(ae_model.decode(sampled_array[j:j + 1], sub_grid))
        output = torch.cat(output_list, dim=1)
        logits = output[0].detach()  # decode is run with batch size 1, so index 0 (not j)
        volume = logits.view(density + 1, density + 1, density + 1).cpu().numpy()
        # extract the zero isosurface and map voxel indices back to [-1.1, 1.1]^3
        verts, faces = mcubes.marching_cubes(volume, 0)
        verts *= gap
        verts -= 1.1
        # transform the mesh into world coordinates
        tran_mat = tran_mats[j].numpy()
        verts_homo = np.concatenate([verts, np.ones((verts.shape[0], 1))], axis=1)
        verts_inwrd = np.dot(verts_homo, tran_mat.T)[:, 0:3]
        m_inwrd = trimesh.Trimesh(verts_inwrd, faces[:, ::-1])  # flip winding to keep outward-facing normals
        output_meshes[model_ids[j]] = m_inwrd
    return output_meshes


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="../example_process_data")
    parser.add_argument("--scene_id", type=str, default="all")
    parser.add_argument("--save_dir", type=str, default="../example_output_data")
    args = parser.parse_args()

    config_path = "../configs/finetune_triplane_diffusion.yaml"
    config = CONFIG(config_path).config

    '''creating save folder'''
    save_folder = os.path.join(args.save_dir, args.scene_id)
    os.makedirs(save_folder, exist_ok=True)
    '''prepare model'''
    device = torch.device("cuda")
    ae_config = config['model']['ae']
    dm_config = config['model']['dm']
    dm_model = get_model(dm_config).to(device)
    ae_model = get_model(ae_config).to(device)
    dm_model.eval()
    ae_model.eval()

    '''preparing data'''
    '''find out which classes appear in the whole scene'''
    images_folder = os.path.join(args.data_dir, args.scene_id, "6_images")
    object_id_list = os.listdir(images_folder)
    object_class_list = [item.split("_")[0] for item in object_id_list]
    all_object_class = list(set(object_class_list))

    # map raw class names to super categories to decide which
    # category-specific models should be employed
    exist_super_categories = []
    for object_class in all_object_class:
        if object_class not in classname_remap:
            continue
        exist_super_categories.append(classname_remap[object_class])
    exist_super_categories = list(set(exist_super_categories))

    for super_category in exist_super_categories:
        print("processing %s" % (super_category))
        # load the category-specific autoencoder and diffusion checkpoints
        ae_ckpt = torch.load(ae_paths[super_category], map_location="cpu")["model"]
        dm_ckpt = torch.load(dm_paths[super_category], map_location="cpu")["model"]
        ae_model.load_state_dict(ae_ckpt)
        dm_model.load_state_dict(dm_ckpt)

        dataset = InTheWild_Dataset(data_dir=args.data_dir, scene_id=args.scene_id,
                                    category=super_category, max_n_imgs=MAX_IMG_LENGTH)
        dataloader = DataLoader(
            dataset=dataset,
            num_workers=1,
            batch_size=1,
            shuffle=False,
        )
        for data_batch in dataloader:
            output_meshes = inference(ae_model, dm_model, data_batch, device)
            for model_id in output_meshes:
                mesh = output_meshes[model_id]
                save_path = os.path.join(save_folder, model_id + ".ply")
                print("saving to %s" % (save_path))
                mesh.export(save_path)
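
# Example invocation (a sketch; assumes this file is saved as inference.py and
# that the repo's example data layout is in place, i.e. each scene folder under
# --data_dir contains a "6_images" subfolder with per-object crops named
# "<class>_<id>"):
#
#   python inference.py --data_dir ../example_process_data \
#       --scene_id all --save_dir ../example_output_data
#
# Reconstructed meshes are written per object to <save_dir>/<scene_id>/<model_id>.ply.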