Spaces:
Sleeping
Sleeping
try:
    import detectron2  # noqa: F401 -- only probing whether it is installed
except ImportError:
    # detectron2 is installed from its git repository on first start.
    # Use the running interpreter's pip (not whatever `pip` is on PATH) so the
    # package lands in this environment, and pass an argument list so no
    # shell is involved. check=False keeps the original best-effort behavior:
    # a failed install does not abort start-up here.
    import subprocess
    import sys

    subprocess.run(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "git+https://github.com/facebookresearch/detectron2.git",
        ],
        check=False,
    )
import gradio as gr | |
import torch | |
from PIL import ImageDraw | |
from PIL import Image | |
import numpy as np | |
from torchvision.transforms import ToPILImage | |
import matplotlib.pyplot as plt | |
import cv2 | |
from regionspot.modeling.regionspot import build_regionspot_model | |
from regionspot import RegionSpot_Predictor | |
from regionspot import SamAutomaticMaskGenerator | |
import ast | |
# Matplotlib `fontdict` used by auto_show_box for the class-name labels.
# Other font keys ("family", "style", "color", "weight") may be added here.
fdic = dict(size=15)
def show_mask(mask, ax, random_color=False):
    """Overlay one mask on ``ax`` as a translucent RGBA layer.

    Args:
        mask: Array whose last two dimensions are (height, width); it is
            reshaped to (height, width, 1) and multiplied by the colour.
        ax: Matplotlib axes (anything with an ``imshow`` method).
        random_color: Use a random RGB colour instead of the default blue.
            Alpha is 0.6 in both cases.
    """
    if random_color:
        rgb = np.random.random(3)
    else:
        rgb = np.array([30 / 255, 144 / 255, 255 / 255])
    rgba = np.append(rgb, 0.6)

    height, width = mask.shape[-2:]
    ax.imshow(mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1))
# Function to show points on an image
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points: label 1 as green stars, label 0 as red stars.

    Args:
        coords: (N, 2) array of x, y coordinates.
        labels: (N,) array of point labels, compared against 1 and 0.
        ax: Matplotlib axes to draw on.
        marker_size: Marker area passed as ``s`` to ``scatter``.
    """
    style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    for wanted, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == wanted]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, **style)
# Function to show bounding boxes on an image
def show_box(box, ax):
    """Draw an unfilled green rectangle for a corner-format box.

    Args:
        box: Sequence ``(x0, y0, x1, y1)``; width and height are derived as
            ``x1 - x0`` and ``y1 - y0``.
        ax: Matplotlib axes receiving the patch.
    """
    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x0, y0), x1 - x0, y1 - y0,
        edgecolor='green', facecolor='none', lw=2,
    )
    ax.add_patch(outline)
def auto_show_box(box, label, ax):
    """Draw a green box outline and write ``label`` at its top-left corner.

    Args:
        box: Sequence ``(x, y, w, h)``; elements 2 and 3 are used directly as
            the rectangle's width and height.
        label: Text drawn at ``(x, y)`` using the module-level ``fdic`` font.
        ax: Matplotlib axes to draw on.
    """
    x, y, width, height = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x, y), width, height,
        edgecolor='green', facecolor=(0, 0, 0, 0), lw=2,
    )
    ax.add_patch(outline)
    ax.text(x, y, f"{label}", fontdict=fdic)
def show_anns(image, anns, custom_vocabulary):
    """Render mask-generator annotations on top of ``image``.

    Each annotation whose class is not ``'background'`` is painted as a
    random-colour translucent mask (alpha 0.35) plus a labelled box.
    Larger areas are painted first, so smaller regions end up on top.

    Args:
        image: Image array to draw on (callers pass ``np.asarray(pil_image)``).
        anns: Result of ``SamAutomaticMaskGenerator.generate`` /
            ``generate_point``: a list of dicts read for ``'area'``,
            ``'segmentation'``, ``'pred_class'`` and ``'bbox'``. Any falsy
            value (``False``, ``[]``, ``None``) is treated as "nothing
            detected" and renders the plain image.
            NOTE(review): ``'segmentation'`` is used to index an (H, W, 4)
            overlay, presumably a boolean (H, W) mask; ``'bbox'`` goes to
            auto_show_box, which treats it as (x, y, w, h) -- confirm
            against the generator.
        custom_vocabulary: Class names indexed by ``int(ann['pred_class'])``.

    Returns:
        PIL.Image.Image: RGB rendering of the figure.
    """
    # Use a fresh figure per call. Drawing on plt.gcf() let output from
    # earlier requests (including process_box's figure) bleed into later
    # renders, and the figures were never released.
    fig, ax = plt.subplots()
    try:
        ax.imshow(image)
        ax.set_autoscale_on(False)
        if anns:
            sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
            seg_h, seg_w = sorted_anns[0]['segmentation'].shape[:2]
            # Fully transparent RGBA canvas; painted pixels get alpha 0.35.
            overlay = np.ones((seg_h, seg_w, 4))
            overlay[:, :, 3] = 0
            for ann in sorted_anns:
                label = custom_vocabulary[int(ann['pred_class'])]
                if label == 'background':
                    continue
                overlay[ann['segmentation']] = np.concatenate([np.random.random(3), [0.35]])
                auto_show_box(ann['bbox'], label, ax)
            ax.imshow(overlay)
        ax.axis('off')
        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb(), which is deprecated since
        # Matplotlib 3.8 and removed in later releases.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return Image.fromarray(rgba).convert('RGB')
    finally:
        plt.close(fig)
def process_box(image, input_box, masks, mask_iou_score, class_score, class_index, custom_vocabulary):
    """Render box-prompt predictions: mask, box outline and class name per box.

    Args:
        image: Image array to draw on.
        input_box: Sequence of corner-format boxes ``(x0, y0, x1, y1)``
            (drawn with show_box).
        masks: One mask per box, drawn with show_mask.
        mask_iou_score: Accepted for interface compatibility; not used.
        class_score: Accepted for interface compatibility; not used.
        class_index: One class index per box, looked up in
            ``custom_vocabulary`` via ``int(...)``.
        custom_vocabulary: Class names.

    Returns:
        PIL.Image.Image: RGB rendering of a 10x10-inch figure.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(image)
        for box, mask, cls in zip(input_box, masks, class_index):
            show_mask(mask, ax)
            show_box(box, ax)
            class_name = custom_vocabulary[int(cls)]
            # Label sits just above the box's top-left corner.
            ax.text(box[0], box[1] - 10, class_name, color='white', fontsize=14,
                    bbox=dict(facecolor='green', edgecolor='green', alpha=0.6))
        ax.axis('off')
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated tostring_rgb().
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return Image.fromarray(rgba).convert('RGB')
    finally:
        # Release the figure; otherwise every request leaks one.
        plt.close(fig)
# Pick the preferred torch backend: CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
# Page title (HTML rendered by gr.Markdown) and per-tab description blurbs.
title = "<center><strong><font size='8'> RegionSpot: Recognize Any Regions </font></strong></center>"
# The Text (e), Points (p) and Box (b) tabs currently share one text; they
# are separate names so each tab's blurb can diverge later.
description_e = description_p = description_b = (
    " This is a demo on Github project "
    "[Recognize Any Regions](https://github.com/Surrey-UPLab/Recognize-Any-Regions). "
    "Welcome to give a star to it.\n"
)
# Example images offered under each tab; gr.Examples expects one list of
# input values per example, hence the single-element lists.
examples = [
    [f"examples/{filename}"]
    for filename in (
        "dogs.jpg", "fruits.jpg", "flowers.jpg",
        "000000190756.jpg", "image.jpg", "000000263860.jpg",
        "000000298738.jpg", "000000027620.jpg", "000000112634.jpg",
        "000000377814.jpg", "000000516143.jpg",
    )
]
default_example = examples[0]
css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
def segment_sementic(image, text):
    """Recognise the region inside a user-drawn box (Box mode tab).

    Args:
        image: Value of the sketch-enabled ``gr.Image``: a dict whose
            ``"image"`` entry is the picture and whose ``"mask"`` entry holds
            the user's strokes. The bounding box of all non-zero mask pixels
            becomes the box prompt.
        text: Comma-separated vocabulary such as ``"dog,cat"``. Surrounding
            whitespace and empty entries are ignored.

    Returns:
        PIL.Image.Image: Rendering from process_box.

    Raises:
        gr.Error: If nothing was drawn or the vocabulary is empty.
    """
    mask_threshold = 0.0
    sketch_mask = image.get("mask") if image else None
    if sketch_mask is None:
        raise gr.Error("Please upload an image and draw a box on it.")
    nonzero = np.nonzero(np.asarray(sketch_mask))
    rows, cols = nonzero[0], nonzero[1]
    if rows.size == 0:
        raise gr.Error("Please draw a box around the region to recognise.")
    # Box prompt as (x_min, y_min, x_max, y_max): columns are x, rows are y.
    input_box = np.array([[cols.min(), rows.min(), cols.max(), rows.max()]])

    custom_vocabulary = [word.strip() for word in (text or "").split(',') if word.strip()]
    if not custom_vocabulary:
        raise gr.Error("Please enter at least one class name.")

    input_image = np.asarray(image["image"])
    ckpt_path = 'regionspot_bl_336.pth'
    clip_type = 'CLIP_400M_Large_336'
    # NOTE(review): box mode uses 224 px CLIP input while the checkpoint and
    # the other tabs use 336. Kept as-is; confirm this is intended.
    clip_input_size = 224

    # Build and initialize the model on the selected device (was hard-coded .cuda()).
    model, _ = build_regionspot_model(is_training=False, image_size=clip_input_size, clip_type=clip_type,
                                      pretrain_ckpt=ckpt_path, custom_vocabulary=custom_vocabulary)
    predictor = RegionSpot_Predictor(model.to(device))
    try:
        predictor.set_image(input_image, clip_input_size=clip_input_size)
        masks, mask_iou_score, class_score, class_index = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_box,
            multimask_output=False,
            mask_threshold=mask_threshold,
        )
        return process_box(input_image, input_box, masks, mask_iou_score, class_score,
                           class_index, custom_vocabulary)
    finally:
        # Drop the references first so the cache can actually be released.
        del predictor, model
        torch.cuda.empty_cache()
def text_segment_sementic(image, text, conf_threshold, box_threshold, crop_n_layers, crop_nms_threshold):
    """Segment and classify every region in the image (Text mode tab).

    Args:
        image: Input PIL image.
        text: Comma-separated vocabulary; whitespace and empty entries are
            ignored. ``'background'`` is prepended as class 0.
        conf_threshold: Passed to the generator as ``box_thresh``.
        box_threshold: Passed as ``box_nms_thresh``.
        crop_n_layers: Number of crop layers (cast to ``int``; it comes from
            a slider).
        crop_nms_threshold: Passed as ``crop_nms_thresh``.

    Returns:
        PIL.Image.Image: Rendering from show_anns.

    Raises:
        gr.Error: If no image is given or the vocabulary is empty.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    mask_threshold = 0.0
    input_image = np.asarray(image)

    classes = [word.strip() for word in (text or "").split(',') if word.strip()]
    if not classes:
        raise gr.Error("Please enter at least one class name.")
    # Class 0 is 'background'; show_anns skips regions assigned to it.
    custom_vocabulary = ['background'] + classes

    ckpt_path = 'regionspot_bl_336.pth'
    clip_type = 'CLIP_400M_Large_336'
    clip_input_size = 336
    model, _ = build_regionspot_model(is_training=False, image_size=clip_input_size, clip_type=clip_type,
                                      pretrain_ckpt=ckpt_path, custom_vocabulary=custom_vocabulary)
    mask_generator = SamAutomaticMaskGenerator(
        model.to(device),
        box_thresh=conf_threshold,
        mask_threshold=mask_threshold,
        box_nms_thresh=box_threshold,
        # Slider values may arrive as floats; a layer count must be an int.
        crop_n_layers=int(crop_n_layers),
        crop_nms_thresh=crop_nms_threshold,
    )
    try:
        masks = mask_generator.generate(input_image)
        return show_anns(input_image, masks, custom_vocabulary)
    finally:
        # Drop the references first so the cache can actually be released.
        del mask_generator, model
        torch.cuda.empty_cache()
def point_segment_sementic(image, text, box_threshold, crop_nms_threshold):
    """Segment and classify the regions picked by clicked points (Points mode).

    The prompt points and labels come from the module-level
    ``global_points`` / ``global_point_label`` lists that get_points_with_draw
    fills. The model runs on ``image_temp``, the copy captured before any
    markers were drawn. ``image``, which shows the markers, is the rendering
    background.

    Args:
        image: Input image as currently displayed (with point markers).
        text: Comma-separated vocabulary; whitespace and empty entries are
            ignored. ``'background'`` is prepended as class 0.
        box_threshold: Passed to the generator as ``box_nms_thresh``.
        crop_nms_threshold: Passed as ``crop_nms_thresh``.

    Returns:
        PIL.Image.Image: Rendering from show_anns.

    Raises:
        gr.Error: If no image or no points are available, or the vocabulary
            is empty.
    """
    # Only reading the module-level state here, so no `global` is needed.
    if image is None or not global_points:
        raise gr.Error("Please click at least one point on the image first.")
    mask_threshold = 0.0
    input_image = image_temp
    output_image = np.asarray(image)

    classes = [word.strip() for word in (text or "").split(',') if word.strip()]
    if not classes:
        raise gr.Error("Please enter at least one class name.")
    # Class 0 is 'background'; show_anns skips regions assigned to it.
    custom_vocabulary = ['background'] + classes

    ckpt_path = 'regionspot_bl_336.pth'
    clip_type = 'CLIP_400M_Large_336'
    clip_input_size = 336
    model, _ = build_regionspot_model(is_training=False, image_size=clip_input_size, clip_type=clip_type,
                                      pretrain_ckpt=ckpt_path, custom_vocabulary=custom_vocabulary)
    mask_generator = SamAutomaticMaskGenerator(
        model.to(device),
        crop_thresh=0.0,
        box_thresh=0.0,
        mask_threshold=mask_threshold,
        box_nms_thresh=box_threshold,
        crop_nms_thresh=crop_nms_threshold,
    )
    try:
        masks = mask_generator.generate_point(input_image, point=np.asarray(global_points),
                                              label=np.asarray(global_point_label))
        return show_anns(output_image, masks, custom_vocabulary)
    finally:
        # Drop the references first so the cache can actually be released.
        del mask_generator, model
        torch.cuda.empty_cache()
def get_points_with_draw(image, label, evt: gr.SelectData):
    """Record a clicked prompt point and mark it on the displayed image.

    Appends ``[x, y]`` to ``global_points`` and 1 ('Mask') or 0 (any other
    label) to ``global_point_label``. On the first click (empty label list)
    an unmarked array copy of the image is saved in ``image_temp`` for the
    model. A filled circle of radius 15 is then drawn in place on ``image``:
    yellow for 'Mask', magenta otherwise.

    Returns:
        The same image object, now with the marker drawn.
    """
    global global_points
    global global_point_label
    global image_temp

    foreground = label == 'Mask'
    if not global_point_label:
        image_temp = np.asarray(image)

    x, y = evt.index[0], evt.index[1]
    global_points.append([x, y])
    global_point_label.append(1 if foreground else 0)

    radius = 15
    fill = (255, 255, 0) if foreground else (255, 0, 255)
    ImageDraw.Draw(image).ellipse(
        [(x - radius, y - radius), (x + radius, y + radius)], fill=fill
    )
    return image
# Module-level click state written by get_points_with_draw and read by
# point_segment_sementic.
# NOTE(review): this state is process-global, so concurrent users share it.
global_points = []          # [[x, y], ...] clicked coordinates
global_point_label = []     # 1 for 'Mask' clicks, 0 otherwise
image_temp = np.array([])   # unmarked image copy captured on the first click

# Input images, one per tab (rendered inside the Blocks layout).
cond_img_p = gr.Image(type='pil', value="examples/dogs.jpg", label="Input with points")
cond_img_t = gr.Image(type='pil', value="examples/dogs.jpg", label="Input with text")
# The sketch tool lets the user draw the box prompt for Box mode.
cond_img_b = gr.Image(type="pil", tool='sketch', label="Input with box")
# NOTE(review): img_p is not rendered by the layout in this file.
img_p = gr.Image(type='pil', label="Input with points P")

# Read-only result images, one per tab.
segm_img_p = gr.Image(type='pil', interactive=False, label="Segmented Image with points")
segm_img_t = gr.Image(type='pil', interactive=False, label="Segmented Image with text")
segm_img_b = gr.Image(type='pil', interactive=False, label="Segmented Image with box")
with gr.Blocks(css=css, title='Region Spot') as demo:
    with gr.Row():
        with gr.Column(scale=1):
            # Title
            gr.Markdown(title)

    with gr.Tab("Points mode"):
        # Images
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                cond_img_p.render()
            with gr.Column(scale=1):
                segm_img_p.render()
        # Submit & Clear
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        add_or_remove = gr.Radio(["Mask", "Background"], value="Mask", label="Point_label (foreground/background)")
                        text_box_p = gr.Textbox(label="vocabulary", value="dog,cat")
                    with gr.Column():
                        segment_btn_p = gr.Button("Segment with points prompt", variant='primary')
                        clear_btn_p = gr.Button("Clear", variant='secondary')
                gr.Markdown("Try some of the examples below")
                # Load examples into this tab's own input (previously routed
                # to the Text-mode image cond_img_t).
                gr.Examples(examples=examples,
                            inputs=[cond_img_p],
                            examples_per_page=18)
            with gr.Column():
                with gr.Accordion("Advanced options", open=True):
                    box_threshold_p = gr.Slider(0.0, 0.9, 0.7, step=0.05, label='box threshold', info='box nms threshold')
                    crop_threshold_p = gr.Slider(0.0, 0.9, 0.7, step=0.05, label='crop threshold', info='crop nms threshold')
                # Description
                gr.Markdown(description_p)

        # Each click on the input records a prompt point and draws a marker.
        cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)
        segment_btn_p.click(point_segment_sementic,
                            inputs=[
                                cond_img_p,
                                text_box_p,
                                box_threshold_p,
                                crop_threshold_p,
                            ],
                            outputs=[segm_img_p])

    with gr.Tab("Text mode"):
        # Images
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                cond_img_t.render()
            with gr.Column(scale=1):
                segm_img_t.render()
        # Submit & Clear
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        # NOTE: not wired to any handler in this file.
                        contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
                        text_box_t = gr.Textbox(label="text prompt", value="dog,cat")
                    with gr.Column():
                        segment_btn_t = gr.Button("Segment with text", variant='primary')
                        clear_btn_t = gr.Button("Clear", variant="secondary")
                gr.Markdown("Try some of the examples below")
                gr.Examples(examples=examples,
                            inputs=[cond_img_t],
                            examples_per_page=18)
            with gr.Column():
                with gr.Accordion("Advanced options", open=True):
                    conf_threshold_t = gr.Slider(0.0, 0.9, 0.8, step=0.05, label='clip threshold', info='object confidence threshold of clip')
                    box_threshold_t = gr.Slider(0.0, 0.9, 0.5, step=0.05, label='box threshold', info='box nms threshold')
                    crop_n_layers_t = gr.Slider(0, 3, 0, step=1, label='crop_n_layers', info='crop_n_layers of auto generator')
                    crop_threshold_t = gr.Slider(0.0, 0.9, 0.5, step=0.05, label='crop threshold', info='crop nms threshold')
                # Description
                gr.Markdown(description_e)

        segment_btn_t.click(text_segment_sementic,
                            inputs=[
                                cond_img_t,
                                text_box_t,
                                conf_threshold_t,
                                box_threshold_t,
                                crop_n_layers_t,
                                crop_threshold_t,
                            ],
                            outputs=[segm_img_t])

    with gr.Tab("Box mode"):
        # Images
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                cond_img_b.render()
            with gr.Column(scale=1):
                segm_img_b.render()
        # Submit & Clear
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        # NOTE: not wired to any handler in this file.
                        contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
                        text_box_b = gr.Textbox(label="vocabulary", value="dog,cat")
                    with gr.Column():
                        segment_btn_b = gr.Button("Segment with box", variant='primary')
                        clear_btn_b = gr.Button("Clear", variant="secondary")
                gr.Markdown("Try some of the examples below")
                # Load examples into this tab's own input (previously routed
                # to the Text-mode image cond_img_t).
                gr.Examples(examples=examples,
                            inputs=[cond_img_b],
                            examples_per_page=18)
            with gr.Column():
                # Description
                gr.Markdown(description_b)

        segment_btn_b.click(segment_sementic,
                            inputs=[
                                cond_img_b,
                                text_box_b,
                            ],
                            outputs=[segm_img_b])

    def clear():
        """Reset the Points tab: input, output, vocabulary and click state."""
        global image_temp
        # Without this, points clicked on a previous image would be reused
        # for the next one.
        global_points.clear()
        global_point_label.clear()
        image_temp = np.array([])
        return None, None, None

    def clear_text():
        """Reset a tab's input image, output image and vocabulary box."""
        return None, None, None

    clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p, text_box_p])
    clear_btn_t.click(clear_text, outputs=[cond_img_t, segm_img_t, text_box_t])
    clear_btn_b.click(clear_text, outputs=[cond_img_b, segm_img_b, text_box_b])
# Enable Gradio's request queue on the app, then start the web server.
demo.queue().launch()