#!/usr/bin/env python from __future__ import annotations import argparse import os import pathlib import subprocess import sys if os.getenv('SYSTEM') == 'spaces': import mim mim.uninstall('mmcv-full', confirm_yes=True) mim.install('mmcv-full==1.5.0', is_yes=True) subprocess.run('pip uninstall -y opencv-python'.split()) subprocess.run('pip uninstall -y opencv-python-headless'.split()) subprocess.run('pip install opencv-python-headless==4.5.5.64'.split()) with open('patch') as f: subprocess.run('patch -p1'.split(), cwd='CBNetV2', stdin=f) subprocess.run('mv palette.py CBNetV2/mmdet/core/visualization/'.split()) import gradio as gr import numpy as np import torch import torch.nn as nn sys.path.insert(0, 'CBNetV2/') from mmdet.apis import inference_detector, init_detector DESCRIPTION = '''# CBNetV2 This is an unofficial demo for [https://github.com/VDIGPKU/CBNetV2](https://github.com/VDIGPKU/CBNetV2).''' FOOTER = 'visitor badge' def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cpu') parser.add_argument('--theme', type=str) parser.add_argument('--share', action='store_true') parser.add_argument('--port', type=int) parser.add_argument('--disable-queue', dest='enable_queue', action='store_false') return parser.parse_args() class Model: def __init__(self, device: str | torch.device): self.device = torch.device(device) self.models = self._load_models() self.model_name = 'Improved HTC (DB-Swin-B)' def _load_models(self) -> dict[str, nn.Module]: model_dict = { 'Faster R-CNN (DB-ResNet50)': { 'config': 'CBNetV2/configs/cbnet/faster_rcnn_cbv2d1_r50_fpn_1x_coco.py', 'model': 'https://github.com/CBNetwork/storage/releases/download/v1.0.0/faster_rcnn_cbv2d1_r50_fpn_1x_coco.pth.zip', }, 'Mask R-CNN (DB-Swin-T)': { 'config': 'CBNetV2/configs/cbnet/mask_rcnn_cbv2_swin_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco.py', 'model': 'https://github.com/CBNetwork/storage/releases/download/v1.0.0/mask_rcnn_cbv2_swin_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco.pth.zip', }, # 'Cascade Mask R-CNN (DB-Swin-S)': { # 'config': # 'CBNetV2/configs/cbnet/cascade_mask_rcnn_cbv2_swin_small_patch4_window7_mstrain_400-1400_adamw_3x_coco.py', # 'model': # 'https://github.com/CBNetwork/storage/releases/download/v1.0.0/cascade_mask_rcnn_cbv2_swin_small_patch4_window7_mstrain_400-1400_adamw_3x_coco.pth.zip', # }, 'Improved HTC (DB-Swin-B)': { 'config': 'CBNetV2/configs/cbnet/htc_cbv2_swin_base_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_20e_coco.py', 'model': 'https://github.com/CBNetwork/storage/releases/download/v1.0.0/htc_cbv2_swin_base22k_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_20e_coco.pth.zip', }, 'Improved HTC (DB-Swin-L)': { 'config': 'CBNetV2/configs/cbnet/htc_cbv2_swin_large_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_1x_coco.py', 'model': 'https://github.com/CBNetwork/storage/releases/download/v1.0.0/htc_cbv2_swin_large22k_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_1x_coco.pth.zip', }, 'Improved HTC (DB-Swin-L (TTA))': { 'config': 'CBNetV2/configs/cbnet/htc_cbv2_swin_large_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_1x_coco.py', 'model': 'https://github.com/CBNetwork/storage/releases/download/v1.0.0/htc_cbv2_swin_large22k_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_1x_coco.pth.zip', }, } weight_dir = pathlib.Path('weights') weight_dir.mkdir(exist_ok=True) def _download(model_name: str, out_dir: pathlib.Path) -> None: import zipfile model_url = model_dict[model_name]['model'] zip_name = model_url.split('/')[-1] out_path = out_dir / zip_name if out_path.exists(): return torch.hub.download_url_to_file(model_url, out_path) with zipfile.ZipFile(out_path) as f: f.extractall(out_dir) def _get_model_path(model_name: str) -> str: model_url = model_dict[model_name]['model'] model_name = model_url.split('/')[-1][:-4] return (weight_dir / model_name).as_posix() for model_name in model_dict: _download(model_name, weight_dir) models = { key: init_detector(dic['config'], _get_model_path(key), device=self.device) for key, dic in model_dict.items() } return models def set_model_name(self, name: str) -> None: self.model_name = name def detect_and_visualize( self, image: np.ndarray, score_threshold: float) -> tuple[list[np.ndarray], np.ndarray]: out = self.detect(image) vis = self.visualize_detection_results(image, out, score_threshold) return out, vis def detect(self, image: np.ndarray) -> list[np.ndarray]: image = image[:, :, ::-1] # RGB -> BGR model = self.models[self.model_name] out = inference_detector(model, image) return out def visualize_detection_results( self, image: np.ndarray, detection_results: list[np.ndarray], score_threshold: float = 0.3) -> np.ndarray: image = image[:, :, ::-1] # RGB -> BGR model = self.models[self.model_name] vis = model.show_result(image, detection_results, score_thr=score_threshold, bbox_color=None, text_color=(200, 200, 200), mask_color=None) return vis[:, :, ::-1] # BGR -> RGB def set_example_image(example: list) -> dict: return gr.Image.update(value=example[0]) def main(): args = parse_args() model = Model(args.device) with gr.Blocks(theme=args.theme, css='style.css') as demo: gr.Markdown(DESCRIPTION) with gr.Row(): with gr.Column(): with gr.Row(): input_image = gr.Image(label='Input Image', type='numpy') with gr.Row(): detector_name = gr.Dropdown(list(model.models.keys()), value=model.model_name, label='Detector') with gr.Row(): detect_button = gr.Button(value='Detect') detection_results = gr.Variable() with gr.Column(): with gr.Row(): detection_visualization = gr.Image( label='Detection Result', type='numpy') with gr.Row(): visualization_score_threshold = gr.Slider( 0, 1, step=0.05, value=0.3, label='Visualization Score Threshold') with gr.Row(): redraw_button = gr.Button(value='Redraw') with gr.Row(): paths = sorted(pathlib.Path('images').rglob('*.jpg')) example_images = gr.Dataset(components=[input_image], samples=[[path.as_posix()] for path in paths]) gr.Markdown(FOOTER) detector_name.change(fn=model.set_model_name, inputs=[detector_name], outputs=None) detect_button.click(fn=model.detect_and_visualize, inputs=[ input_image, visualization_score_threshold, ], outputs=[ detection_results, detection_visualization, ]) redraw_button.click(fn=model.visualize_detection_results, inputs=[ input_image, detection_results, visualization_score_threshold, ], outputs=[detection_visualization]) example_images.click(fn=set_example_image, inputs=[example_images], outputs=[input_image]) demo.launch( enable_queue=args.enable_queue, server_port=args.port, share=args.share, ) if __name__ == '__main__': main()