import os
import subprocess
from typing import Optional

# Install flash-attn at runtime before the model utilities are imported.
# Merging os.environ keeps PATH and the rest of the environment visible to
# pip; passing only the single variable would wipe everything else.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': 'TRUE'},
    shell=True,
)

import gradio as gr
import supervision as sv
import torch
from PIL import Image

from utils.florence import load_florence_model, run_florence_inference, \
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
from utils.sam import load_sam_image_model, run_sam_inference
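
# Pipeline: Florence-2 performs open-vocabulary detection (text prompt ->
# bounding boxes), then SAM turns each detected box into a pixel-level mask.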

# Use the GPU when one is available (flash-attn requires CUDA); otherwise
# fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
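
# Load both models once at startup so every request reuses the same weights.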
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)


def process_image(image_input, text_input) -> Optional[Image.Image]:
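    """Segment the object named by text_input in image_input.

    Returns a black-and-white PIL mask for the first detection, or None when
    an input is missing or nothing is detected.
    """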
    if image_input is None:
        gr.Info("Please upload an image.")
        return None

    if not text_input:
        gr.Info("Please enter a text prompt.")
        return None

    # Florence-2: turn the free-text prompt into bounding boxes.
    _, result = run_florence_inference(
        model=FLORENCE_MODEL,
        processor=FLORENCE_PROCESSOR,
        device=DEVICE,
        image=image_input,
        task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
        text=text_input
    )
    detections = sv.Detections.from_lmm(
        lmm=sv.LMM.FLORENCE_2,
        result=result,
        resolution_wh=image_input.size
    )
    # SAM: refine each detected box into a pixel-level mask.
    detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
    if len(detections) == 0:
        gr.Info("No objects detected.")
        return None
    # Boolean mask -> 0/255 grayscale image for display.
    return Image.fromarray(detections.mask[0].astype("uint8") * 255)
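

# Two-column UI: image and prompt inputs on the left, mask output on the right.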
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_input_component = gr.Image(
                type='pil', label='Upload image')
            text_input_component = gr.Textbox(
                label='Text prompt',
                placeholder='Enter a text prompt')
            submit_button_component = gr.Button(
                value='Submit', variant='primary')
        with gr.Column():
            image_output_component = gr.Image(label='Output mask')
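
    # Run the same handler from the Submit button and from pressing Enter in
    # the prompt textbox.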
    submit_button_component.click(
        fn=process_image,
        inputs=[image_input_component, text_input_component],
        outputs=[image_output_component],
    )
    text_input_component.submit(
        fn=process_image,
        inputs=[image_input_component, text_input_component],
        outputs=[image_output_component],
    )

demo.launch(debug=False, show_error=True)