import os, sys

sys.path.append("..")

from configs.config_utils import CONFIG
from models import get_model

import torch
import numpy as np
import open3d as o3d
import timm
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from simple_dataset import InTheWild_Dataset, classname_remap, classname_map

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    # older torchvision releases expose the resampling filters via PIL
    BICUBIC = Image.BICUBIC

import mcubes
import trimesh
from torch.utils.data import DataLoader
|
|
|
def image_transform(n_px):
    """CLIP-style preprocessing: resize, center-crop to n_px, and normalize
    with the CLIP RGB mean/std."""
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073),
                  (0.26862954, 0.26130258, 0.27577711)),
    ])
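# Example usage (a sketch; "photo.jpg" is a hypothetical path, and 224 px is
# the usual input size for the CLIP backbones this normalization comes from):
#   preprocess = image_transform(224)
#   img = preprocess(Image.open("photo.jpg").convert("RGB"))  # tensor of shape (3, 224, 224)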
|
|
|
MAX_IMG_LENGTH = 5

# Per-category checkpoints for the shape autoencoder (ae) and the
# finetuned triplane diffusion model (dm).
ae_paths = {
    "chair": "../checkpoint/ae/chair/best-checkpoint.pth",
    "table": "../checkpoint/ae/table/best-checkpoint.pth",
    "cabinet": "../checkpoint/ae/cabinet/best-checkpoint.pth",
    "shelf": "../checkpoint/ae/shelf/best-checkpoint.pth",
    "sofa": "../checkpoint/ae/sofa/best-checkpoint.pth",
    "bed": "../checkpoint/ae/bed/best-checkpoint.pth",
}
dm_paths = {
    "chair": "../checkpoint/finetune_dm/chair/best-checkpoint.pth",
    "table": "../checkpoint/finetune_dm/table/best-checkpoint.pth",
    "cabinet": "../checkpoint/finetune_dm/cabinet/best-checkpoint.pth",
    "shelf": "../checkpoint/finetune_dm/shelf/best-checkpoint.pth",
    "sofa": "../checkpoint/finetune_dm/sofa/best-checkpoint.pth",
    "bed": "../checkpoint/finetune_dm/bed/best-checkpoint.pth",
}
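# Optional sanity check (a minimal sketch, not part of the original pipeline):
# report missing checkpoint files up front instead of failing inside torch.load.
def missing_checkpoints(paths):
    """Return the categories in `paths` whose checkpoint file does not exist."""
    return [category for category, path in paths.items() if not os.path.isfile(path)]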
|
|
|
def inference(ae_model, dm_model, data_batch, device, reso=256):
    """Sample triplane latents with the diffusion model, decode them into a
    logit volume with the autoencoder, and extract the zero level set as a
    mesh via marching cubes."""
    density = reso
    gap = 2.2 / density
    # Build a (density+1)^3 query grid covering [-1.1, 1.1]^3.
    x = np.linspace(-1.1, 1.1, int(density + 1))
    y = np.linspace(-1.1, 1.1, int(density + 1))
    z = np.linspace(-1.1, 1.1, int(density + 1))
    xv, yv, zv = np.meshgrid(x, y, z, indexing='ij')
    grid = torch.from_numpy(np.stack([xv, yv, zv]).astype(np.float32)).view(3, -1).transpose(0, 1)[None].to(
        device, non_blocking=True)

    with torch.no_grad():
        sample_input = dm_model.prepare_sample_data(data_batch)
        sampled_array = dm_model.sample(sample_input, num_steps=36).float()
        # Upsample the sampled latent feature planes 2x before decoding.
        sampled_array = torch.nn.functional.interpolate(sampled_array, scale_factor=2, mode="bilinear")

    model_ids = data_batch['model_id']
    tran_mats = data_batch['tran_mat']

    output_meshes = {}

    for j in range(sampled_array.shape[0]):
        # Decode the query grid in chunks of 128^3 points to bound GPU memory.
        grid_list = torch.split(grid, 128 ** 3, dim=1)
        output_list = []
        with torch.no_grad():
            for sub_grid in grid_list:
                output_list.append(ae_model.decode(sampled_array[j:j + 1], sub_grid))
        output = torch.cat(output_list, dim=1)
        # The decoder sees a single-sample batch, so index 0 here; indexing
        # with j would go out of range for batch sizes > 1.
        logits = output[0].detach()

        volume = logits.view(density + 1, density + 1, density + 1).cpu().numpy()
        verts, faces = mcubes.marching_cubes(volume, 0)

        # Map voxel indices back into the [-1.1, 1.1]^3 sampling cube.
        verts *= gap
        verts -= 1.1

        # Transform from canonical object space into world space.
        tran_mat = tran_mats[j].numpy()
        verts_homo = np.concatenate([verts, np.ones((verts.shape[0], 1))], axis=1)
        verts_inwrd = np.dot(verts_homo, tran_mat.T)[:, 0:3]
        # Reverse the face winding so normals point outward after the transform.
        m_inwrd = trimesh.Trimesh(verts_inwrd, faces[:, ::-1])

        output_meshes[model_ids[j]] = m_inwrd
    return output_meshes
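# Typical call (a sketch; the 'model_id' and 'tran_mat' batch keys come from
# InTheWild_Dataset, exactly as in the __main__ block below):
#   data_batch = next(iter(dataloader))
#   meshes = inference(ae_model, dm_model, data_batch, device, reso=256)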
|
|
|
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="../example_process_data")
    parser.add_argument("--scene_id", type=str, default="all")
    parser.add_argument("--save_dir", type=str, default="../example_output_data")
    args = parser.parse_args()

    config_path = "../configs/finetune_triplane_diffusion.yaml"
    config = CONFIG(config_path).config

    '''create the save folder'''
    save_folder = os.path.join(args.save_dir, args.scene_id)
    os.makedirs(save_folder, exist_ok=True)

    '''prepare models'''
    device = torch.device("cuda")
    ae_config = config['model']['ae']
    dm_config = config['model']['dm']
    dm_model = get_model(dm_config).to(device)
    ae_model = get_model(ae_config).to(device)
    dm_model.eval()
    ae_model.eval()

    '''prepare data'''
    '''find out which object classes occur in the whole scene'''
    images_folder = os.path.join(args.data_dir, args.scene_id, "6_images")
    object_id_list = os.listdir(images_folder)
    object_class_list = [item.split("_")[0] for item in object_id_list]
    all_object_class = list(set(object_class_list))

    exist_super_categories = []
    for object_class in all_object_class:
        if object_class in classname_remap:
            exist_super_categories.append(classname_remap[object_class])
    exist_super_categories = list(set(exist_super_categories))

    for super_category in exist_super_categories:
        print("processing %s" % (super_category))
        # Load the category-specific weights into the shared model instances.
        ae_ckpt = torch.load(ae_paths[super_category], map_location="cpu")["model"]
        dm_ckpt = torch.load(dm_paths[super_category], map_location="cpu")["model"]
        ae_model.load_state_dict(ae_ckpt)
        dm_model.load_state_dict(dm_ckpt)
        dataset = InTheWild_Dataset(data_dir=args.data_dir, scene_id=args.scene_id,
                                    category=super_category, max_n_imgs=MAX_IMG_LENGTH)
        dataloader = DataLoader(
            dataset=dataset,
            num_workers=1,
            batch_size=1,
            shuffle=False,
        )
        for data_batch in dataloader:
            output_meshes = inference(ae_model, dm_model, data_batch, device)

            for model_id in output_meshes:
                mesh = output_meshes[model_id]
                save_path = os.path.join(save_folder, model_id + ".ply")
                print("saving to %s" % (save_path))
                mesh.export(save_path)