# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
import os
import argparse
import sys
import numpy as np
from collections import OrderedDict
import torch

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.engine import default_argument_parser, default_setup, launch
from detectron2.data import transforms as T

logger = logging.getLogger("detectron2")

sys.dont_write_bytecode = True
sys.path.append(os.getcwd())
np.set_printoptions(suppress=True)

from cubercnn.config import get_cfg_defaults
# these imports are needed for their side effects: they register the Cube R-CNN
# components (RPN, ROI heads, meta-architecture, backbone) with detectron2's registries
from cubercnn.modeling.proposal_generator import RPNWithIgnore
from cubercnn.modeling.roi_heads import ROIHeads3D
from cubercnn.modeling.meta_arch import RCNN3D, build_model
from cubercnn.modeling.backbone import build_dla_from_vision_fpn_backbone
from cubercnn import util, vis


def do_test(args, cfg, model):

    list_of_ims = util.list_files(os.path.join(args.input_folder, ''), '*')
    model.eval()

    focal_length = args.focal_length
    principal_point = args.principal_point
    thres = args.threshold

    output_dir = cfg.OUTPUT_DIR
    min_size = cfg.INPUT.MIN_SIZE_TEST
    max_size = cfg.INPUT.MAX_SIZE_TEST
    augmentations = T.AugmentationList([T.ResizeShortestEdge(min_size, max_size, "choice")])

    util.mkdir_if_missing(output_dir)

    category_path = os.path.join(util.file_parts(args.config_file)[0], 'category_meta.json')

    # store locally if needed
    if category_path.startswith(util.CubeRCNNHandler.PREFIX):
        category_path = util.CubeRCNNHandler._get_local_path(util.CubeRCNNHandler, category_path)

    metadata = util.load_json(category_path)
    cats = metadata['thing_classes']

    for path in list_of_ims:

        im_name = util.file_parts(path)[1]
        im = util.imread(path)

        if im is None:
            continue

        image_shape = im.shape[:2]  # h, w
        h, w = image_shape

        # if no focal length is given, approximate one from the image height
        # (checked against args so one image's estimate does not carry over to the next)
        if args.focal_length == 0:
            focal_length_ndc = 4.0
            focal_length = focal_length_ndc * h / 2

        if len(principal_point) == 0:
            px, py = w / 2, h / 2
        else:
            px, py = principal_point

        # camera intrinsics for this image
        K = np.array([
            [focal_length, 0.0, px],
            [0.0, focal_length, py],
            [0.0, 0.0, 1.0]
        ])

        # ground and depth maps must be generated beforehand by the preprocessing scripts
        is_ground = os.path.exists(f'datasets/ground_maps/{im_name}.jpg.npz')
        if is_ground:
            ground_map = np.load(f'datasets/ground_maps/{im_name}.jpg.npz')['mask']
        depth_map = np.load(f'datasets/depth_maps/{im_name}.jpg.npz')['depth']

        aug_input = T.AugInput(im)
        tfms = augmentations(aug_input)
        image = aug_input.image

        # apply the same resize to the ground and depth maps as to the image
        if is_ground:
            ground_map = tfms.apply_image(ground_map * 1.0)
            ground_map = torch.as_tensor(ground_map)
        else:
            ground_map = None
        depth_map = tfms.apply_image(depth_map)

        # original Cube R-CNN input, without ground/depth maps:
        # batched = [{
        #     'image': torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))).cuda(),
        #     'height': image_shape[0], 'width': image_shape[1], 'K': K
        # }]

        batched = [{
            'image': torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))),
            'depth_map': torch.as_tensor(depth_map),
            'ground_map': ground_map,
            'height': image_shape[0], 'width': image_shape[1], 'K': K
        }]

        dets = model(batched)[0]['instances']
        n_det = len(dets)

        meshes = []
        meshes_text = []

        if n_det > 0:
            for idx, (corners3D, center_cam, center_2D, dimensions, pose, score, cat_idx) in enumerate(zip(
                    dets.pred_bbox3D, dets.pred_center_cam, dets.pred_center_2D, dets.pred_dimensions,
                    dets.pred_pose, dets.scores, dets.pred_classes)):

                # skip low-scoring detections
                if score < thres:
                    continue

                cat = cats[cat_idx]

                bbox3D = center_cam.tolist() + dimensions.tolist()
                meshes_text.append('{} {:.2f}'.format(cat, score))
                color = [c / 255.0 for c in util.get_color(idx)]
                box_mesh = util.mesh_cuboid(bbox3D, pose.tolist(), color=color)
                meshes.append(box_mesh)

        print('File: {} with {} dets'.format(im_name, len(meshes)))

        if len(meshes) > 0:
            im_drawn_rgb, im_topdown, _ = vis.draw_scene_view(
                im, K, meshes, text=meshes_text, scale=im.shape[0],
                blend_weight=0.5, blend_weight_overlay=0.85
            )

            if args.display:
                im_concat = np.concatenate((im_drawn_rgb, im_topdown), axis=1)
                vis.imshow(im_concat)

            util.imwrite(im_drawn_rgb, os.path.join(output_dir, im_name + '_boxes.jpg'))
            util.imwrite(im_topdown, os.path.join(output_dir, im_name + '_novel.jpg'))
        else:
            util.imwrite(im, os.path.join(output_dir, im_name + '_boxes.jpg'))


def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    get_cfg_defaults(cfg)

    config_file = args.config_file

    # store locally if needed
    if config_file.startswith(util.CubeRCNNHandler.PREFIX):
        config_file = util.CubeRCNNHandler._get_local_path(util.CubeRCNNHandler, config_file)

    cfg.merge_from_file(config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg


def main(args):
    cfg = setup(args)
    model = build_model(cfg)

    logger.info("Model:\n{}".format(model))
    DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
        cfg.MODEL.WEIGHTS, resume=True
    )

    with torch.no_grad():
        do_test(args, cfg, model)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        epilog=None, formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument("--input-folder", type=str, help="folder of input images to process", required=True)
    parser.add_argument("--focal-length", type=float, default=0, help="focal length for image inputs (in px)")
    parser.add_argument("--principal-point", type=float, default=[], nargs=2, help="principal point for image inputs (in px)")
    parser.add_argument("--threshold", type=float, default=0.25, help="threshold on score for visualizing")
    parser.add_argument("--display", default=False, action="store_true", help="whether to show the images in matplotlib")

    parser.add_argument("--eval-only", default=True, action="store_true", help="perform evaluation only")
    parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
    parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
    parser.add_argument(
        "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
    )
    port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
    parser.add_argument(
        "--dist-url",
        default="tcp://127.0.0.1:{}".format(port),
        help="initialization URL for pytorch distributed backend. See "
        "https://pytorch.org/docs/stable/distributed.html for details.",
    )
    parser.add_argument(
        "opts",
        help="Modify config options by adding 'KEY VALUE' pairs at the end of the command. "
        "See config references at "
        "https://detectron2.readthedocs.io/modules/config.html#config-references",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    print("Command Line Args:", args)

    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )
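
# ---------------------------------------------------------------------------
# Example usage (a sketch, not taken from this repo's docs): the config,
# checkpoint, and folder paths below are illustrative placeholders. Before
# running, the per-image depth maps (and optional ground maps) referenced in
# do_test() must exist under datasets/, saved with the same keys the script
# loads, e.g.:
#
#   np.savez('datasets/depth_maps/<im_name>.jpg.npz', depth=depth)   # HxW depth array
#   np.savez('datasets/ground_maps/<im_name>.jpg.npz', mask=mask)    # HxW ground mask (optional)
#
# Then (adjust the script path to wherever this file lives):
#
#   python demo.py \
#       --config-file <path/to/config.yaml> \
#       --input-folder <path/to/images> \
#       --threshold 0.25 \
#       MODEL.WEIGHTS <path/to/checkpoint.pth> \
#       OUTPUT_DIR output/demo
# ---------------------------------------------------------------------------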