import os
import sys

os.environ["TOKENIZERS_PARALLELISM"] = "false"
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.insert(0, project_root)

import gc
import resource
import argparse
import cv2
import tqdm
import torch
from torch.multiprocessing import Pool, set_start_method

import mmcv
from mmcv.transforms import Compose
from mmengine.utils import track_iter_progress
from mmdet.apis import init_detector
from mmdet.registry import VISUALIZERS
from mmcv.ops.nms import batched_nms

import masa
from masa.apis import inference_masa, init_masa, inference_detector, build_test_pipeline
from masa.models.sam import SamPredictor, sam_model_registry
from utils import filter_and_update_tracks

import warnings
warnings.filterwarnings('ignore')

# Ensure the right start method for multiprocessing
try:
    set_start_method('spawn')
except RuntimeError:
    pass


def set_file_descriptor_limit(limit):
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    resource.setrlimit(resource.RLIMIT_NOFILE, (limit, hard))


# Set the file descriptor limit to 65536
set_file_descriptor_limit(65536)


def visualize_frame(args, visualizer, frame, track_result, frame_idx, fps=None):
    visualizer.add_datasample(
        name='video_' + str(frame_idx),
        image=frame[:, :, ::-1],
        data_sample=track_result[0],
        draw_gt=False,
        show=False,
        out_file=None,
        pred_score_thr=args.score_thr,
        fps=fps)
    frame = visualizer.get_image()
    gc.collect()
    return frame


def parse_args():
    parser = argparse.ArgumentParser(description='MASA video demo')
    parser.add_argument('video', help='Video file')
    parser.add_argument('--det_config', help='Detector config file')
    parser.add_argument('--masa_config', help='MASA config file')
    parser.add_argument('--det_checkpoint', help='Detector checkpoint file')
    parser.add_argument('--masa_checkpoint', help='MASA checkpoint file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--score-thr', type=float, default=0.2, help='Bbox score threshold')
    parser.add_argument('--out', type=str, help='Output video file')
    parser.add_argument('--save_dir', type=str, help='Output directory for video frames')
    parser.add_argument('--texts', help='Text prompt')
    parser.add_argument('--line_width', type=int, default=5, help='Line width')
    parser.add_argument(
        '--unified', action='store_true',
        help='Use the unified model, i.e. the MASA adapter is built upon the detector model')
    parser.add_argument('--detector_type', type=str, default='mmdet', help='Choose detector type')
    parser.add_argument('--fp16', action='store_true', help='Activate fp16 mode')
    parser.add_argument('--no-post', action='store_true', help='Do not post-process the results')
    parser.add_argument('--show_fps', action='store_true', help='Visualize the fps')
    parser.add_argument(
        '--sam_mask', action='store_true',
        help='Use SAM to generate masks for segmentation tracking')
    parser.add_argument(
        '--sam_path', type=str,
        default='saved_models/pretrain_weights/sam_vit_h_4b8939.pth',
        help='Default path for SAM models')
    parser.add_argument(
        '--sam_type', type=str, default='vit_h', help='Default type for SAM models')
    parser.add_argument(
        '--wait-time', type=float, default=1,
        help='The interval of show (s), 0 is block')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    assert args.out, \
        ('Please specify at least one operation (save the '
         'video) with the argument "--out"')

    # Build the model from a config file and a checkpoint file
    if args.unified:
        masa_model = init_masa(args.masa_config,
                               args.masa_checkpoint, device=args.device)
    else:
        det_model = init_detector(
            args.det_config, args.det_checkpoint, palette='random', device=args.device)
        masa_model = init_masa(args.masa_config, args.masa_checkpoint, device=args.device)
        # Build the detector test pipeline
        det_model.cfg.test_dataloader.dataset.pipeline[0].type = 'mmdet.LoadImageFromNDArray'
        test_pipeline = Compose(det_model.cfg.test_dataloader.dataset.pipeline)

    if args.sam_mask:
        print('Loading SAM model...')
        device = args.device
        sam_model = sam_model_registry[args.sam_type](args.sam_path)
        sam_predictor = SamPredictor(sam_model.to(device))

    video_reader = mmcv.VideoReader(args.video)
    video_writer = None

    # Parse the text input
    texts = args.texts
    if texts is not None:
        masa_test_pipeline = build_test_pipeline(masa_model.cfg, with_text=True)
    else:
        masa_test_pipeline = build_test_pipeline(masa_model.cfg)

    if texts is not None:
        masa_model.cfg.visualizer['texts'] = texts
    else:
        masa_model.cfg.visualizer['texts'] = det_model.dataset_meta['classes']

    # Init visualizer
    masa_model.cfg.visualizer['save_dir'] = args.save_dir
    masa_model.cfg.visualizer['line_width'] = args.line_width
    if args.sam_mask:
        masa_model.cfg.visualizer['alpha'] = 0.5
    visualizer = VISUALIZERS.build(masa_model.cfg.visualizer)

    if args.out:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video_writer = cv2.VideoWriter(
            args.out, fourcc, video_reader.fps,
            (video_reader.width, video_reader.height))

    frame_idx = 0
    instances_list = []
    frames = []
    fps_list = []
    for frame in track_iter_progress((video_reader, len(video_reader))):

        # Unified models mean that MASA is built upon, and reuses, the foundation
        # model's backbone features for tracking.
        if args.unified:
            track_result = inference_masa(masa_model, frame,
                                          frame_id=frame_idx,
                                          video_len=len(video_reader),
                                          test_pipeline=masa_test_pipeline,
                                          text_prompt=texts,
                                          fp16=args.fp16,
                                          detector_type=args.detector_type,
                                          show_fps=args.show_fps)
            if args.show_fps:
                track_result, fps = track_result
        else:
            if args.detector_type == 'mmdet':
                result = inference_detector(det_model, frame,
                                            text_prompt=texts,
                                            test_pipeline=test_pipeline,
                                            fp16=args.fp16)

            # Perform inter-class NMS to remove noisy detections
            det_bboxes, keep_idx = batched_nms(boxes=result.pred_instances.bboxes,
                                               scores=result.pred_instances.scores,
                                               idxs=result.pred_instances.labels,
                                               class_agnostic=True,
                                               nms_cfg=dict(type='nms',
                                                            iou_threshold=0.5,
                                                            class_agnostic=True,
                                                            split_thr=100000))
            det_bboxes = torch.cat([det_bboxes,
                                    result.pred_instances.scores[keep_idx].unsqueeze(1)],
                                   dim=1)
            det_labels = result.pred_instances.labels[keep_idx]

            track_result = inference_masa(masa_model, frame,
                                          frame_id=frame_idx,
                                          video_len=len(video_reader),
                                          test_pipeline=masa_test_pipeline,
                                          det_bboxes=det_bboxes,
                                          det_labels=det_labels,
                                          fp16=args.fp16,
                                          show_fps=args.show_fps)
            if args.show_fps:
                track_result, fps = track_result

        frame_idx += 1
        if 'masks' in track_result[0].pred_track_instances:
            if len(track_result[0].pred_track_instances.masks) > 0:
                # Stack per-instance masks into a single array for visualization
                track_result[0].pred_track_instances.masks = torch.stack(
                    track_result[0].pred_track_instances.masks, dim=0)
                track_result[0].pred_track_instances.masks = \
                    track_result[0].pred_track_instances.masks.cpu().numpy()

        track_result[0].pred_track_instances.bboxes = \
            track_result[0].pred_track_instances.bboxes.to(torch.float32)
        instances_list.append(track_result.to('cpu'))
        frames.append(frame)
        if args.show_fps:
            fps_list.append(fps)

    if not args.no_post:
        instances_list = filter_and_update_tracks(instances_list, (frame.shape[1], frame.shape[0]))

    if args.sam_mask:
        print('Start to generate masks using SAM!')
        for idx, (frame,
                  track_result) in tqdm.tqdm(enumerate(zip(frames, instances_list))):
            track_result = track_result.to(device)
            track_result[0].pred_track_instances.instances_id = \
                track_result[0].pred_track_instances.instances_id.to(device)
            track_result[0].pred_track_instances = track_result[0].pred_track_instances[
                (track_result[0].pred_track_instances.scores.float() > args.score_thr).to(device)]
            input_boxes = track_result[0].pred_track_instances.bboxes
            if len(input_boxes) == 0:
                continue
            sam_predictor.set_image(frame)
            transformed_boxes = sam_predictor.transform.apply_boxes_torch(input_boxes, frame.shape[:2])
            masks, _, _ = sam_predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=transformed_boxes,
                multimask_output=False,
            )
            track_result[0].pred_track_instances.masks = masks.squeeze(1).cpu().numpy()
            instances_list[idx] = track_result

    if args.out:
        print('Start to visualize the results...')
        num_cores = max(1, min(os.cpu_count() - 1, 16))
        print('Using {} cores for visualization'.format(num_cores))

        if args.show_fps:
            with Pool(processes=num_cores) as pool:
                frames = pool.starmap(
                    visualize_frame,
                    [(args, visualizer, frame, track_result.to('cpu'), idx, fps)
                     for idx, (frame, fps, track_result) in
                     enumerate(zip(frames, fps_list, instances_list))]
                )
        else:
            with Pool(processes=num_cores) as pool:
                frames = pool.starmap(
                    visualize_frame,
                    [(args, visualizer, frame, track_result.to('cpu'), idx)
                     for idx, (frame, track_result) in enumerate(zip(frames, instances_list))]
                )

        for frame in frames:
            if args.out:
                video_writer.write(frame[:, :, ::-1])

    if video_writer:
        video_writer.release()
    print('Done')


if __name__ == '__main__':
    main()
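
# Example invocation (a minimal sketch, not a prescribed command: the script name,
# config/checkpoint paths, and the text-prompt format are placeholders -- substitute
# whatever MASA config, weights, and input video you actually use):
#
#   python video_demo_with_text.py path/to/input_video.mp4 \
#       --unified \
#       --masa_config path/to/masa_config.py \
#       --masa_checkpoint path/to/masa_checkpoint.pth \
#       --texts "person . car" \
#       --score-thr 0.3 \
#       --out outputs/input_video_tracked.mp4
#
# Add --sam_mask (optionally with --sam_path / --sam_type) to refine the tracked
# boxes into segmentation masks with SAM, and --show_fps to overlay the runtime.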