# LASA/demo/api.py
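"""Demo inference script for LASA.

For each object category found in a processed scene, the script loads the
matching per-category autoencoder and finetuned triplane diffusion
checkpoints, samples shape latents conditioned on the prepared inputs,
decodes them on a dense query grid, and exports one world-space mesh per
object as a .ply file.

Example invocation (defaults mirror the argparse setup below):
    python api.py --data_dir ../example_process_data --scene_id <scene_id> \
        --save_dir ../example_output_data
"""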
import os
import sys
sys.path.append("..")

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image
import open3d as o3d
import timm
import mcubes
import trimesh

# local imports, resolved via the sys.path entry above
from configs.config_utils import CONFIG
from models import get_model
from simple_dataset import InTheWild_Dataset, classname_remap, classname_map

# older torchvision versions do not expose InterpolationMode; fall back to PIL
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC
def image_transform(n_px):
    """Resize/center-crop to n_px and normalize with the CLIP image statistics."""
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073),
                  (0.26862954, 0.26130258, 0.27577711)),
    ])
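# Example usage (hypothetical 224 px input size; the dataset code picks the
# actual resolution): preprocess a single RGB crop the same way CLIP does.
#   preprocess = image_transform(224)
#   img = preprocess(Image.open("object_crop.png").convert("RGB"))  # -> (3, 224, 224)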
MAX_IMG_LENGTH = 5  # take up to 5 images as input

# per-category checkpoints for the autoencoder (ae) and diffusion model (dm)
ae_paths = {
    "chair": "../checkpoint/ae/chair/best-checkpoint.pth",
    "table": "../checkpoint/ae/table/best-checkpoint.pth",
    "cabinet": "../checkpoint/ae/cabinet/best-checkpoint.pth",
    "shelf": "../checkpoint/ae/shelf/best-checkpoint.pth",
    "sofa": "../checkpoint/ae/sofa/best-checkpoint.pth",
    "bed": "../checkpoint/ae/bed/best-checkpoint.pth",
}
dm_paths = {
    "chair": "../checkpoint/finetune_dm/chair/best-checkpoint.pth",
    "table": "../checkpoint/finetune_dm/table/best-checkpoint.pth",
    "cabinet": "../checkpoint/finetune_dm/cabinet/best-checkpoint.pth",
    "shelf": "../checkpoint/finetune_dm/shelf/best-checkpoint.pth",
    "sofa": "../checkpoint/finetune_dm/sofa/best-checkpoint.pth",
    "bed": "../checkpoint/finetune_dm/bed/best-checkpoint.pth",
}
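# Optional helper (a sketch, not part of the original flow): verify that all
# per-category checkpoints exist up front, so a missing file fails fast with a
# readable error instead of surfacing inside torch.load mid-run.
def check_checkpoints(path_dicts=(ae_paths, dm_paths)):
    missing = [p for d in path_dicts for p in d.values() if not os.path.exists(p)]
    if missing:
        raise FileNotFoundError("missing checkpoints: %s" % ", ".join(missing))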
def inference(ae_model, dm_model, data_batch, device, reso=256):
    """Sample triplane latents with the diffusion model, decode them on a dense
    query grid with the autoencoder, and extract meshes via marching cubes."""
    density = reso
    gap = 2.2 / density
    # query grid covering [-1.1, 1.1]^3 with (density + 1) samples per axis
    x = np.linspace(-1.1, 1.1, int(density + 1))
    y = np.linspace(-1.1, 1.1, int(density + 1))
    z = np.linspace(-1.1, 1.1, int(density + 1))
    xv, yv, zv = np.meshgrid(x, y, z, indexing='ij')
    grid = torch.from_numpy(np.stack([xv, yv, zv]).astype(np.float32)).view(3, -1).transpose(0, 1)[None].to(
        device, non_blocking=True)

    with torch.no_grad():
        sample_input = dm_model.prepare_sample_data(data_batch)
        sampled_array = dm_model.sample(sample_input, num_steps=36).float()
        # upsample the sampled triplane features before decoding
        sampled_array = torch.nn.functional.interpolate(sampled_array, scale_factor=2, mode="bilinear")

    model_ids = data_batch['model_id']
    tran_mats = data_batch['tran_mat']
    output_meshes = {}
    for j in range(sampled_array.shape[0]):
        # decode the grid in chunks of 128^3 points to bound memory usage
        grid_list = torch.split(grid, 128 ** 3, dim=1)
        output_list = []
        with torch.no_grad():
            for sub_grid in grid_list:
                output_list.append(ae_model.decode(sampled_array[j:j + 1], sub_grid))
        output = torch.cat(output_list, dim=1)
        logits = output[0].detach()  # decode is fed a single sample, so the batch dim is 1
        volume = logits.view(density + 1, density + 1, density + 1).cpu().numpy()
        verts, faces = mcubes.marching_cubes(volume, 0)
        # map voxel indices back into the [-1.1, 1.1]^3 query cube
        verts *= gap
        verts -= 1.1
        # transform the mesh into world coordinates
        tran_mat = tran_mats[j].numpy()
        verts_homo = np.concatenate([verts, np.ones((verts.shape[0], 1))], axis=1)
        verts_inwrd = np.dot(verts_homo, tran_mat.T)[:, 0:3]
        m_inwrd = trimesh.Trimesh(verts_inwrd, faces[:, ::-1])  # flip winding to keep outward normals
        output_meshes[model_ids[j]] = m_inwrd
    return output_meshes
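# A coarser preview mesh can be produced by lowering the grid resolution, e.g.
# (hypothetical call, trading surface detail for speed and memory):
#   preview_meshes = inference(ae_model, dm_model, data_batch, device, reso=128)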
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="../example_process_data")
    parser.add_argument("--scene_id", type=str, default="all")
    parser.add_argument("--save_dir", type=str, default="../example_output_data")
    args = parser.parse_args()

    config_path = "../configs/finetune_triplane_diffusion.yaml"
    config = CONFIG(config_path).config

    '''creating save folder'''
    save_folder = os.path.join(args.save_dir, args.scene_id)
    os.makedirs(save_folder, exist_ok=True)

    '''prepare model'''
    device = torch.device("cuda")
    ae_config = config['model']['ae']
    dm_config = config['model']['dm']
    dm_model = get_model(dm_config).to(device)
    ae_model = get_model(ae_config).to(device)
    dm_model.eval()
    ae_model.eval()

    '''preparing data: find out which object classes are present in the whole scene'''
    images_folder = os.path.join(args.data_dir, args.scene_id, "6_images")
    object_id_list = os.listdir(images_folder)
    object_class_list = [item.split("_")[0] for item in object_id_list]
    all_object_class = list(set(object_class_list))
    exist_super_categories = []
    for object_class in all_object_class:
        if object_class not in classname_remap:
            continue
        # find which category-specific models should be employed
        exist_super_categories.append(classname_remap[object_class])
    exist_super_categories = list(set(exist_super_categories))

    for super_category in exist_super_categories:
        print("processing %s" % super_category)
        # swap in the per-category checkpoints for both models
        ae_ckpt = torch.load(ae_paths[super_category], map_location="cpu")["model"]
        dm_ckpt = torch.load(dm_paths[super_category], map_location="cpu")["model"]
        ae_model.load_state_dict(ae_ckpt)
        dm_model.load_state_dict(dm_ckpt)
        dataset = InTheWild_Dataset(data_dir=args.data_dir, scene_id=args.scene_id,
                                    category=super_category, max_n_imgs=MAX_IMG_LENGTH)
        dataloader = DataLoader(
            dataset=dataset,
            num_workers=1,
            batch_size=1,
            shuffle=False,
        )
        for data_batch in dataloader:
            output_meshes = inference(ae_model, dm_model, data_batch, device)
            for model_id in output_meshes:
                mesh = output_meshes[model_id]
                save_path = os.path.join(save_folder, model_id + ".ply")
                print("saving to %s" % save_path)
                mesh.export(save_path)
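# Exported meshes land in "<save_dir>/<scene_id>/<model_id>.ply" and are
# already in world coordinates, so they can be inspected jointly, e.g.
# (hypothetical quick check; show() needs a display backend such as pyglet):
#   scene = trimesh.Scene([trimesh.load(os.path.join(save_folder, f))
#                          for f in os.listdir(save_folder)
#                          if f.endswith(".ply")])
#   scene.show()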