# Hugging Face Space (runs on ZeroGPU): Florence-2 + SAM mask-generation demo.
from io import BytesIO
from typing import Optional
import json

import cv2
import gradio as gr
import numpy as np
import requests
import spaces
import supervision as sv
import torch
from PIL import Image

from utils.florence import load_florence_model, run_florence_inference, \
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
from utils.sam import load_sam_image_model, run_sam_inference
DEVICE = torch.device("cuda") | |
# DEVICE = torch.device("cpu") | |
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() | |
if torch.cuda.get_device_properties(0).major >= 8: | |
torch.backends.cuda.matmul.allow_tf32 = True | |
torch.backends.cudnn.allow_tf32 = True | |
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
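
# Pipeline sketch: Florence-2 proposes text-grounded bounding boxes, then
# (unless `return_rectangles` is set) SAM converts each box into a pixel mask.
# The helpers in `utils.florence` and `utils.sam` belong to this Space's repo;
# their exact signatures are assumed from the calls below.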
@spaces.GPU  # ZeroGPU allocates a GPU per call; assumed since the Space runs on Zero.
def process_image(
    image_input,
    image_url,
    task_prompt,
    text_prompt=None,
    dilate=0,
    merge_masks=False,
    return_rectangles=False,
) -> Optional[list]:
    if image_input is None and not image_url:
        gr.Info("Please upload an image or provide an image URL.")
        return None
    if not task_prompt:
        gr.Info("Please enter a task prompt.")
        return None
    if image_url:
        print("fetching image from url", image_url)
        response = requests.get(image_url)
        response.raise_for_status()
        image_input = Image.open(BytesIO(response.content))
        print("image fetched successfully")
    # Parse the prompt with Florence-2.
    _, result = run_florence_inference(
        model=FLORENCE_MODEL,
        processor=FLORENCE_PROCESSOR,
        device=DEVICE,
        image=image_input,
        task=task_prompt,
        text=text_prompt
    )
    # Convert the Florence-2 output into detections.
    detections = sv.Detections.from_lmm(
        lmm=sv.LMM.FLORENCE_2,
        result=result,
        resolution_wh=image_input.size
    )
    # numpy arrays are not JSON-serializable, so convert them to lists first.
    json_result = json.dumps({
        "bbox": detections.xyxy.tolist(),
        "data": {k: v.tolist() if isinstance(v, np.ndarray) else v
                 for k, v in detections.data.items()},
    })
    images = []
    if return_rectangles:
        # Skip SAM: turn each bounding box into a filled rectangular mask.
        (image_width, image_height) = image_input.size
        bboxes = detections.xyxy
        merge_mask_image = np.zeros((image_height, image_width), dtype=np.uint8)
        for bbox in bboxes:
            x1, y1, x2, y2 = map(int, bbox)
            cv2.rectangle(merge_mask_image, (x1, y1), (x2, y2), 255, thickness=cv2.FILLED)
            clip_mask = np.zeros((image_height, image_width), dtype=np.uint8)
            cv2.rectangle(clip_mask, (x1, y1), (x2, y2), 255, thickness=cv2.FILLED)
            images.append(clip_mask)
        if merge_masks:
            images = [merge_mask_image] + images
    else:
        # Use SAM to turn each detected box into a segmentation mask.
        detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
        if len(detections) == 0:
            gr.Info("No objects detected.")
            return None
        print("masks generated:", len(detections.mask))
        kernel_size = dilate
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        for i in range(len(detections.mask)):
            mask = detections.mask[i].astype(np.uint8) * 255
            if dilate > 0:
                # Optionally grow each mask with a square kernel of side `dilate`.
                mask = cv2.dilate(mask, kernel, iterations=1)
            images.append(mask)
        if merge_masks:
            merged_mask = np.zeros_like(images[0], dtype=np.uint8)
            for mask in images:
                merged_mask = cv2.bitwise_or(merged_mask, mask)
            images = [merged_mask]
    return [images, json_result]
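
# A minimal sketch of calling the pipeline directly (outside the UI), assuming
# a local test image; the file names here are hypothetical:
#
#   masks, meta = process_image(
#       Image.open("example.jpg"), None,
#       "<CAPTION_TO_PHRASE_GROUNDING>", "a dog",
#       dilate=10, merge_masks=True,
#   )
#   Image.fromarray(masks[0]).save("mask.png")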
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image(type='pil', label='Upload image')
            image_url = gr.Textbox(label='Image URL', placeholder='Enter an image URL (optional)')
            task_prompt = gr.Dropdown(
                ['<OD>', '<CAPTION_TO_PHRASE_GROUNDING>', '<DENSE_REGION_CAPTION>', '<REGION_PROPOSAL>', '<OCR_WITH_REGION>', '<REFERRING_EXPRESSION_SEGMENTATION>', '<REGION_TO_SEGMENTATION>', '<OPEN_VOCABULARY_DETECTION>', '<REGION_TO_CATEGORY>', '<REGION_TO_DESCRIPTION>'],
                value="<CAPTION_TO_PHRASE_GROUNDING>", label="Task prompt", info="Florence-2 task token"
            )
            dilate = gr.Slider(label="Dilate mask", minimum=0, maximum=50, value=10, step=1)
            merge_masks = gr.Checkbox(label="Merge masks", value=False)
            return_rectangles = gr.Checkbox(label="Return rectangles", value=False)
            text_prompt = gr.Textbox(label='Text prompt', placeholder='Enter text prompts')
            submit_button = gr.Button(value='Submit', variant='primary')
        with gr.Column():
            image_gallery = gr.Gallery(label="Generated masks", show_label=False, elem_id="gallery",
                                       columns=[3], rows=[1], object_fit="contain", height="auto")
            json_result = gr.Code(label="JSON Result", language="json")
    submit_button.click(
        fn=process_image,
        inputs=[image, image_url, task_prompt, text_prompt, dilate, merge_masks, return_rectangles],
        outputs=[image_gallery, json_result],
        show_api=False
    )

demo.launch(debug=True, show_error=True)