Weiyu Liu
add demo
8c02843
import os
import argparse
import torch
import numpy as np
import pytorch_lightning as pl
from omegaconf import OmegaConf
from StructDiffusion.data.semantic_arrangement import SemanticArrangementDataset
from StructDiffusion.language.tokenizer import Tokenizer
from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel
from StructDiffusion.diffusion.sampler import Sampler
from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
from StructDiffusion.utils.files import get_checkpoint_path_from_dir
from StructDiffusion.utils.batch_inference import move_pc_and_create_scene_simple, visualize_batch_pcs
def main(args, cfg):
    """Run the inference demo over the test split and visualize the samples.

    Seeds the random number generators through ``pl.seed_everything``, then
    resolves a checkpoint from the directory
    ``<cfg.WANDB.save_dir>/<cfg.WANDB.project>/<args.checkpoint_id>/checkpoints``.
    When ``args.eval_mode`` is ``"infer"``, the test examples are visited in a
    random order; for each one the natural-language form of its ``"sentence"``
    entry is printed, the sampler is run on a batch built from that example,
    and the rearranged point clouds are passed to the visualizer.

    Args:
        args: Parsed command-line namespace. This function reads
            ``eval_random_seed``, ``checkpoint_id``, ``eval_mode`` and
            ``num_samples``.
        cfg: Configuration object. This function reads ``cfg.WANDB.save_dir``,
            ``cfg.WANDB.project`` and ``cfg.DATASET``;
            ``cfg.DATASET.ignore_rgb`` is overwritten in place with ``False``.

    Returns:
        None. For any ``eval_mode`` other than ``"infer"`` nothing happens
        beyond seeding and the checkpoint lookup.
    """
    pl.seed_everything(args.eval_random_seed)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda") if use_cuda else torch.device("cpu")

    # The checkpoint is resolved before the mode check, so a missing
    # checkpoint directory surfaces regardless of the selected mode.
    checkpoint_path = get_checkpoint_path_from_dir(
        os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.checkpoint_id, "checkpoints")
    )

    if args.eval_mode != "infer":
        return

    tokenizer = Tokenizer(cfg.DATASET.vocab_dir)
    # Force ignore_rgb off for visualization, whatever the config file says.
    cfg.DATASET.ignore_rgb = False
    dataset = SemanticArrangementDataset(split="test", tokenizer=tokenizer, **cfg.DATASET)
    sampler = Sampler(ConditionalPoseDiffusionModel, checkpoint_path, device)

    # Visit every test example once, in a shuffled order.
    for example_idx in np.random.permutation(len(dataset)):
        raw_example = dataset.get_raw_data(example_idx)
        description = tokenizer.convert_structure_params_to_natural_language(raw_example["sentence"])
        print(description)

        example = dataset.convert_to_tensors(raw_example, tokenizer)
        # NOTE(review): presumably yields one batch entry per requested
        # sample -- confirm against single_datum_to_batch.
        batch = dataset.single_datum_to_batch(example, args.num_samples, device, inference_mode=True)

        # The sampler is told how many poses to generate, taken from the
        # first dimension of the example's goal poses.
        samples = sampler.sample(batch, example["goal_poses"].shape[0])
        struct_pose, obj_poses_in_struct = get_struct_objs_poses(samples[0])
        moved_obj_xyzs = move_pc_and_create_scene_simple(batch["pcs"], struct_pose, obj_poses_in_struct)
        visualize_batch_pcs(moved_obj_xyzs, args.num_samples, limit_B=10, trimesh=True)
if __name__ == "__main__":
    # Command-line entry point: parse the options, build the merged
    # configuration and hand both to main().
    cli = argparse.ArgumentParser(description="infer")
    cli.add_argument(
        "--base_config_file",
        type=str,
        default='../configs/base.yaml',
        help='base config yaml file',
    )
    cli.add_argument(
        "--config_file",
        type=str,
        default='../configs/conditional_pose_diffusion.yaml',
        help='config yaml file',
    )
    # Remaining options carry no help text; register them from a table,
    # keeping the original declaration order.
    for option, fallback, caster in (
        ("--checkpoint_id", "ConditionalPoseDiffusion", str),
        ("--eval_mode", "infer", str),
        ("--eval_random_seed", 42, int),
        ("--num_samples", 10, int),
    ):
        cli.add_argument(option, default=fallback, type=caster)
    args = cli.parse_args()

    # The base config is loaded first; the model-specific config is merged
    # on top of it.
    cfg = OmegaConf.merge(
        OmegaConf.load(args.base_config_file),
        OmegaConf.load(args.config_file),
    )
    main(args, cfg)