"""Evaluate a trained PL_RelPose checkpoint on the test split of a dataset.

Usage:
    python <this script> CONFIG.yaml CHECKPOINT.ckpt [--devices 0 1 ...]
"""
import argparse

from torch.utils.data import DataLoader
import lightning as L

from datasets import dataset_dict
from model import PL_RelPose, keypoint_dict
from configs.default import get_cfg_defaults


def main(args):
    """Run ``Trainer.test`` for the config/checkpoint described by *args*.

    The config file is merged over ``get_cfg_defaults()``. The test split is
    built with ``dataset_dict[DATASET.TASK][DATASET.DATA_SOURCE]``. The model
    is restored from the checkpoint, and its ``extractor`` is replaced by a
    fresh one sized with ``MODEL.TEST_NUM_KEYPOINTS``.

    Args:
        args: Namespace with
            ``config`` -- path to the YAML config file;
            ``ckpt_path`` -- Lightning checkpoint to evaluate;
            ``devices`` (optional) -- list of device indices for
            ``L.Trainer``. It falls back to ``[0]`` when the attribute is
            missing, so namespaces built without it behave as before.
    """
    config = get_cfg_defaults()
    config.merge_from_file(args.config)

    task = config.DATASET.TASK
    dataset = config.DATASET.DATA_SOURCE
    batch_size = config.TRAINER.BATCH_SIZE
    num_workers = config.TRAINER.NUM_WORKERS
    pin_memory = config.TRAINER.PIN_MEMORY
    test_num_keypoints = config.MODEL.TEST_NUM_KEYPOINTS

    # dataset_dict is indexed by task, then by data source; the resulting
    # builder is called with the split name and the config.
    build_fn = dataset_dict[task][dataset]
    testset = build_fn('test', config)
    testloader = DataLoader(
        testset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    # Restore onto the CPU. Without map_location, torch.load puts tensors back
    # on the device they were saved from (e.g. cuda:0) even when testing on a
    # different device. The Trainer moves the module to the requested
    # device(s) itself.
    pl_relpose = PL_RelPose.load_from_checkpoint(args.ckpt_path, map_location='cpu')

    # Replace the extractor with one built for the test-time keypoint budget.
    # NOTE(review): presumably this overrides whatever keypoint count was used
    # in training. detection_threshold=0.0 presumably keeps every detection, so
    # only max_num_keypoints limits the count -- confirm against the extractor.
    feature_name = pl_relpose.hparams['features']
    pl_relpose.extractor = keypoint_dict[feature_name](
        max_num_keypoints=test_num_keypoints,
        detection_threshold=0.0,
    ).eval()

    devices = getattr(args, 'devices', None) or [0]
    trainer = L.Trainer(
        devices=devices,
    )
    trainer.test(pl_relpose, dataloaders=testloader)


def get_parser():
    """Build the CLI parser.

    Positional arguments are the config path and the checkpoint path.
    ``--devices`` optionally selects the device indices (default: 0).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', type=str, help='.yaml configure file path')
    parser.add_argument('ckpt_path', type=str, help='path to the Lightning checkpoint to evaluate')
    parser.add_argument(
        '--devices',
        type=int,
        nargs='+',
        default=[0],
        help='device indices passed to lightning.Trainer (default: 0)',
    )
    return parser


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    main(args)