import os
import argparse

import torch
import numpy as np
import pytorch_lightning as pl
from omegaconf import OmegaConf

from StructDiffusion.data.semantic_arrangement import SemanticArrangementDataset
from StructDiffusion.language.tokenizer import Tokenizer
from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel
from StructDiffusion.diffusion.sampler import Sampler
from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
from StructDiffusion.utils.files import get_checkpoint_path_from_dir
from StructDiffusion.utils.batch_inference import move_pc_and_create_scene_simple, visualize_batch_pcs

# Evaluation modes that main() knows how to run.
SUPPORTED_EVAL_MODES = ("infer",)


def main(args, cfg):
    """Sample object arrangements from a trained pose-diffusion checkpoint and visualize them.

    For every datum of the "test" split (visited in a random order), the
    language instruction is printed, ``args.num_samples`` samples are drawn
    from the sampler, and the resulting scenes are passed to
    ``visualize_batch_pcs``.

    Args:
        args: Parsed command-line namespace. Reads ``eval_mode``,
            ``eval_random_seed``, ``checkpoint_id`` and ``num_samples``.
        cfg: Merged OmegaConf config. Reads ``cfg.WANDB.save_dir``,
            ``cfg.WANDB.project`` and the ``cfg.DATASET`` section; note that
            ``cfg.DATASET.ignore_rgb`` is overwritten in place.

    Raises:
        ValueError: If ``args.eval_mode`` is not one of ``SUPPORTED_EVAL_MODES``.
    """
    # Fail loudly instead of silently doing nothing on an unknown mode.
    if args.eval_mode not in SUPPORTED_EVAL_MODES:
        raise ValueError(
            f"Unsupported eval_mode {args.eval_mode!r}; "
            f"expected one of {SUPPORTED_EVAL_MODES}"
        )

    pl.seed_everything(args.eval_random_seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Checkpoints are looked up under <save_dir>/<project>/<checkpoint_id>/checkpoints.
    checkpoint_dir = os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.checkpoint_id, "checkpoints")
    checkpoint_path = get_checkpoint_path_from_dir(checkpoint_dir)

    tokenizer = Tokenizer(cfg.DATASET.vocab_dir)
    # override ignore_rgb for visualization
    cfg.DATASET.ignore_rgb = False
    dataset = SemanticArrangementDataset(split="test", tokenizer=tokenizer, **cfg.DATASET)

    sampler = Sampler(ConditionalPoseDiffusionModel, checkpoint_path, device)

    # Random visiting order; reproducible because the seed was fixed above.
    data_idxs = np.random.permutation(len(dataset))
    for di in data_idxs:
        raw_datum = dataset.get_raw_data(di)
        print(tokenizer.convert_structure_params_to_natural_language(raw_datum["sentence"]))

        datum = dataset.convert_to_tensors(raw_datum, tokenizer)
        # The single datum is expanded into a batch sized by num_samples.
        batch = dataset.single_datum_to_batch(datum, args.num_samples, device, inference_mode=True)

        num_poses = datum["goal_poses"].shape[0]
        xs = sampler.sample(batch, num_poses)

        # NOTE(review): xs[0] is presumably the final denoised sample - confirm against Sampler.sample.
        struct_pose, pc_poses_in_struct = get_struct_objs_poses(xs[0])
        new_obj_xyzs = move_pc_and_create_scene_simple(batch["pcs"], struct_pose, pc_poses_in_struct)
        visualize_batch_pcs(new_obj_xyzs, args.num_samples, limit_B=10, trimesh=True)


def _build_arg_parser():
    """Create the command-line parser for this inference script."""
    parser = argparse.ArgumentParser(description="infer")
    parser.add_argument("--base_config_file", help='base config yaml file', default='../configs/base.yaml', type=str)
    parser.add_argument("--config_file", help='config yaml file', default='../configs/conditional_pose_diffusion.yaml', type=str)
    parser.add_argument("--checkpoint_id", default="ConditionalPoseDiffusion", type=str)
    # Restrict to known modes so a typo is reported by argparse right away.
    parser.add_argument("--eval_mode", default="infer", type=str, choices=SUPPORTED_EVAL_MODES)
    parser.add_argument("--eval_random_seed", default=42, type=int)
    parser.add_argument("--num_samples", default=10, type=int)
    return parser


if __name__ == "__main__":
    args = _build_arg_parser().parse_args()

    # The model-specific config is merged on top of the base config.
    base_cfg = OmegaConf.load(args.base_config_file)
    cfg = OmegaConf.load(args.config_file)
    cfg = OmegaConf.merge(base_cfg, cfg)

    main(args, cfg)