import argparse
import requests
import logging
import os
import gradio as gr
import numpy as np
import cv2
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform
from config import get_config
from collections import OrderedDict

# NOTE: several of the imports in this file are not used by the demo itself.
# The detectron2 imports assume RegionCLIP's customized detectron2
# (e.g. FLICKR30KEvaluator does not exist in upstream detectron2).
import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultTrainer as Trainer
from detectron2.engine import default_argument_parser, default_setup, hooks, launch
from detectron2.evaluation import (
    CityscapesInstanceEvaluator,
    CityscapesSemSegEvaluator,
    COCOEvaluator,
    COCOPanopticEvaluator,
    DatasetEvaluators,
    LVISEvaluator,
    PascalVOCDetectionEvaluator,
    SemSegEvaluator,
    verify_results,
    FLICKR30KEvaluator,
)
from detectron2.modeling import GeneralizedRCNNWithTTA
def parse_option():
    parser = argparse.ArgumentParser('RegionCLIP demo script', add_help=False)
    parser.add_argument('--config-file', type=str, default="configs/CLIP_fast_rcnn_R_50_C4.yaml",
                        metavar="FILE", help='path to config file')
    args, unparsed = parser.parse_known_args()
    return args
def build_transforms(img_size, center_crop=True):
    t = []
    if center_crop:
        size = int((256 / 224) * img_size)
        t.append(transforms.Resize(size))
        t.append(transforms.CenterCrop(img_size))
    else:
        t.append(transforms.Resize(img_size))
    t.append(transforms.ToTensor())
    return transforms.Compose(t)
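# NOTE: transforms.ToTensor() produces a float CHW tensor with values in [0, 1];
# localize_object below multiplies it by 255, presumably because the
# detectron2-style model expects pixels in the 0-255 range.
# Hypothetical usage:
#   t = build_transforms(800, center_crop=False)
#   x = t(Image.open("./elephants.png").convert("RGB"))  # shape (3, H, W), values in [0, 1]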
def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg
# build model
args = parse_option()
cfg = setup(args)
model = Trainer.build_model(cfg)
DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
    cfg.MODEL.WEIGHTS, resume=False
)
if cfg.MODEL.META_ARCHITECTURE in ['CLIPRCNN', 'CLIPFastRCNN', 'PretrainFastRCNN'] \
        and cfg.MODEL.CLIP.BB_RPN_WEIGHTS is not None \
        and cfg.MODEL.CLIP.CROP_REGION_TYPE == 'RPN':  # load 2nd pretrained model
    DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR, bb_rpn_weights=True).resume_or_load(
        cfg.MODEL.CLIP.BB_RPN_WEIGHTS, resume=False
    )

# build data transform
eval_transforms = build_transforms(800, center_crop=False)
# display_transforms = build_transforms4display(960, center_crop=False)
def localize_object(image, texts):
    # `image` arrives from Gradio as an RGB numpy array; `texts` is the query string.
    print(texts)
    img_t = eval_transforms(Image.fromarray(image).convert("RGB")) * 255  # scale to the 0-255 range
    print(img_t.shape)
    model.eval()
    with torch.no_grad():
        print(img_t[0][:10, :10])  # debug: top-left patch of the first channel
        res = model(texts, [{"image": img_t}])
    # NOTE: `res` is the raw model output, while the Gradio output below is declared
    # as a PIL image, so the detections still need to be rendered onto the image
    # (see the sketch below) before being returned.
    return res
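# A minimal sketch (an assumption, not confirmed by this script): if the model
# returns the standard detectron2 format [{"instances": Instances}], a helper
# along these lines could turn a prediction into the PIL image the Interface
# expects. The name draw_predictions and the score threshold are illustrative;
# note that the boxes would be in the coordinate frame of the resized image
# unless "height"/"width" are added to the model's input dict.
from detectron2.utils.visualizer import Visualizer

def draw_predictions(image_rgb, outputs, score_thresh=0.5):
    """Draw predicted boxes on an RGB numpy image and return a PIL image."""
    instances = outputs[0]["instances"].to("cpu")            # assumed output format
    instances = instances[instances.scores > score_thresh]   # keep confident detections
    vis = Visualizer(image_rgb)                               # Visualizer expects an RGB array
    out = vis.draw_instance_predictions(instances)
    return Image.fromarray(out.get_image())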
# NOTE: this component is unused; the Interface below uses the "image"/"text"
# string shortcuts instead. gr.inputs/gr.outputs are the legacy Gradio API.
image = gr.inputs.Image()

gr.Interface(
    description="RegionCLIP for Open-Vocabulary Object Detection",
    fn=localize_object,
    inputs=["image", "text"],
    outputs=[
        gr.outputs.Image(type="pil", label="grounding results"),
    ],
    examples=[
        ["./elephants.png", "an elephant"],
        ["./apple_with_ipod.jpg", "an apple"],
    ],
).launch()
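# Usage note: as a Spaces-style Gradio script, everything above runs at import
# time, so launching the demo is just running the file (e.g. `python app.py`,
# assuming this file is named app.py and the config, weights, and example
# images are available relative to the working directory); launch() starts a
# local web server and prints its URL.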