import argparse
import cv2
import os
from PIL import Image, ImageDraw, ImageFont, ImageOps
import numpy as np
from pathlib import Path
import gradio as gr
import matplotlib.pyplot as plt
from loguru import logger
import subprocess
import copy
import time
import warnings
import torch

warnings.filterwarnings("ignore")

# Grounding DINO
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from groundingdino.util.inference import annotate, load_image, predict
import groundingdino.datasets.transforms as T
from torchvision.ops import box_convert

# Segment Anything
from segment_anything import build_sam, SamPredictor
from huggingface_hub import hf_hub_download

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
if not os.path.exists('./sam_vit_h_4b8939.pth'):
    logger.info("downloading sam_vit_h_4b8939.pth...")
    result = subprocess.run(['wget', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'], check=True)
    print(f'wget sam_vit_h_4b8939.pth result = {result}')
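# The SAM ViT-H checkpoint is large (a few GB), so the first launch can take a while to
# download it; subsequent launches reuse the local file.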
# GroundingDINO config and checkpoint, plus the SAM checkpoint downloaded above
config_file = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
ckpt_repo_id = "ShilongLiu/GroundingDINO"
ckpt_filename = "groundingdino_swint_ogc.pth"
sam_checkpoint = './sam_vit_h_4b8939.pth'
output_dir = "outputs"
groundingdino_device = 'cpu'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'device={device}')

# make output dir
os.makedirs(output_dir, exist_ok=True)
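# Note: GroundingDINO is pinned to CPU via groundingdino_device above; the `device` string
# only reflects CUDA availability and is not otherwise applied to the models in this script.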

def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
    args = SLConfig.fromfile(model_config_path)
    model = build_model(args)
    args.device = device
    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    checkpoint = torch.load(cache_file, map_location='cpu')
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    print("Model loaded from {} \n => {}".format(cache_file, log))
    _ = model.eval()
    return model
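# Note: hf_hub_download caches the checkpoint locally (by default under ~/.cache/huggingface),
# so the GroundingDINO weights are only fetched from the Hub on the first run.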

def load_image_and_transform(init_image):
    init_image = init_image.convert("RGB")
    transform = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    image, _ = transform(init_image, None)  # 3, h, w
    return init_image, image
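# Note: the mean/std above are the standard ImageNet normalization constants, and the resize
# (shorter side to 800, capped at 1333) mirrors GroundingDINO's usual evaluation preprocessing.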

def image_transform_grounding_for_vis(init_image):
    transform = T.Compose([
        T.RandomResize([800], max_size=1333),
    ])
    image, _ = transform(init_image, None)  # resized PIL image (no tensor conversion here)
    return image
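# Note: image_transform_grounding_for_vis appears unused in this script; it is left in place
# as a visualization helper.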

def plot_boxes_to_image(image_pil, tgt):
    H, W = tgt["size"]
    boxes = tgt["boxes"]
    labels = tgt["labels"]
    assert len(boxes) == len(labels), "boxes and labels must have same length"
    draw = ImageDraw.Draw(image_pil)
    mask = Image.new("L", image_pil.size, 0)
    mask_draw = ImageDraw.Draw(mask)
    # draw boxes and masks
    for box, label in zip(boxes, labels):
        # from normalized 0..1 to pixel coordinates 0..W, 0..H
        box = box * torch.Tensor([W, H, W, H])
        # from cxcywh to xyxy
        box[:2] -= box[2:] / 2
        box[2:] += box[:2]
        # random color per box
        color = tuple(np.random.randint(0, 255, size=3).tolist())
        # draw the box
        x0, y0, x1, y1 = box
        x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
        draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
        # label background: textbbox on newer Pillow, textsize as a fallback for older versions
        font = ImageFont.load_default()
        if hasattr(font, "getbbox"):
            bbox = draw.textbbox((x0, y0), str(label), font)
        else:
            w, h = draw.textsize(str(label), font)
            bbox = (x0, y0, w + x0, y0 + h)
        draw.rectangle(bbox, fill=color)
        # label text, using the DejaVuSans font bundled with OpenCV
        font = os.path.join(cv2.__path__[0], 'qt', 'fonts', 'DejaVuSans.ttf')
        font_size = 36
        new_font = ImageFont.truetype(font, font_size)
        draw.text((x0 + 2, y0 + 2), str(label), font=new_font, fill="white")
        mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)
    return image_pil, mask
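# Expected layout of `tgt` (as built in grounding_sam below):
#   {"boxes": Tensor[N, 4] of normalized cxcywh boxes, "size": [H, W], "labels": list of N phrases}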

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_box(box, ax, label):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
    ax.text(x0, y0, label, fontdict={'fontsize': 7})

def get_grounding_box(image_tensor, grounding_caption, box_threshold, text_threshold):
    # run GroundingDINO on the preprocessed image tensor
    boxes, logits, phrases = predict(groundingDino_model, image_tensor, grounding_caption,
                                     box_threshold, text_threshold, device=groundingdino_device)
    return boxes, phrases
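# Note: GroundingDINO's predict() returns boxes as normalized cxcywh values in [0, 1];
# grounding_sam below rescales them to pixels and converts them to xyxy before handing them to SAM.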

def grounding_sam(input_image, text_prompt, task_type, box_threshold, text_threshold, iou_threshold):
    text_prompt = text_prompt.strip()
    # user guidance messages
    if not (task_type == 'inpainting' or task_type == 'remove'):
        if text_prompt == '':
            return [], gr.Gallery.update(label='Please input a detection prompt')
        if input_image is None:
            return [], gr.Gallery.update(label='Please upload an image')
    file_temp = int(time.time())
    image_pil, image_tensor = load_image_and_transform(input_image['image'])
    # get GroundingDINO bounding boxes
    boxes, phrases = get_grounding_box(image_tensor, text_prompt, box_threshold, text_threshold)
    if boxes.size(0) == 0:
        logger.info(f'run_grounded_sam_[]_{task_type}_[{text_prompt}]_1_[No objects detected, please try others.]_')
        return [], gr.Gallery.update(label='No objects detected, please try another prompt!')
    size = image_pil.size
    pred_dict = {
        "boxes": boxes,
        "size": [size[1], size[0]],  # H, W
        "labels": phrases,
    }
    # draw and save the GroundingDINO detection result
    output_images = []
    image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
    image_path = os.path.join(output_dir, f"grounding_dino_output_{file_temp}.jpg")
    image_with_box.save(image_path)
    detection_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    os.remove(image_path)
    output_images.append(detection_image_result)
    if task_type == 'segment':
        image = np.array(input_image['image'])
        sam_predictor.set_image(image)
        # scale the normalized cxcywh boxes to pixel coordinates and convert to xyxy
        h, w = size[1], size[0]
        boxes = boxes * torch.Tensor([w, h, w, h])
        boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy")
        # equivalent manual conversion:
        # for i in range(boxes.size(0)):
        #     boxes[i] = boxes[i] * torch.Tensor([w, h, w, h])
        #     boxes[i][:2] -= boxes[i][2:] / 2  # top-left corner
        #     boxes[i][2:] += boxes[i][:2]      # bottom-right corner
        # map the boxes from the original image resolution to SAM's input resolution
        transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes, image.shape[:2])
        # predict one mask per box; masks: [number of boxes, 1, H, W] at the original image size
        masks, _, _ = sam_predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
        # draw the segmentation result
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        for mask in masks:
            show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
        for box, label in zip(boxes, phrases):
            show_box(box.numpy(), plt.gca(), label)
        plt.axis('off')
        image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.jpg")
        plt.savefig(image_path, bbox_inches="tight")
        plt.close()  # free the figure so repeated requests do not accumulate open figures
        segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
        os.remove(image_path)
        output_images.append(segment_image_result)
    return output_images, gr.Gallery.update(label='result images')

groundingDino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filename, groundingdino_device)
sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
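# Note (assumption, not in the original script): the SAM model stays on CPU here. If a GPU is
# available, SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device)) should speed up
# segmentation considerably.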

if __name__ == "__main__":
    parser = argparse.ArgumentParser("Grounding SAM demo", add_help=True)
    parser.add_argument("--debug", action="store_true", help="use debug mode")
    parser.add_argument("--share", action="store_true", help="share the app")
    args = parser.parse_args()
    print(f'args = {args}')

    block = gr.Blocks().queue()
    with block:
        gr.Markdown("# GroundingDINO and SAM")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload")
                task_type = gr.Radio(["segment"], value="segment",
                                     label='Task type', interactive=True, visible=True)
                text_prompt = gr.Textbox(label="Detection prompt, separating each name with ',', e.g.: cat,dog,chair",
                                         placeholder="Cannot be empty")
                run_button = gr.Button("Run")
                with gr.Accordion("Advanced options", open=False):
                    box_threshold = gr.Slider(
                        label="Box Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
                    )
                    text_threshold = gr.Slider(
                        label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
                    )
                    iou_threshold = gr.Slider(
                        label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001
                    )
            with gr.Column():
                gallery = gr.Gallery(
                    label="result images", show_label=True, elem_id="gallery"
                ).style(grid=[2], full_width=True, full_height=True)

        DESCRIPTION = '### This demo is based on [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything); kudos to their excellent work. Welcome everyone to try it out and learn together!'
        gr.Markdown(DESCRIPTION)

        run_button.click(fn=grounding_sam, inputs=[
            input_image, text_prompt, task_type, box_threshold, text_threshold, iou_threshold],
            outputs=[gallery, gallery])

    block.launch(share=args.share, debug=args.debug, show_api=False, show_error=True)
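# Example local run (assuming the file is saved as app.py, as is typical for a Gradio Space):
#   python app.py            # local only
#   python app.py --share    # additionally create a public Gradio share link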