|
import gradio as gr |
|
|
|
from grounded_sam.inference import grounded_segmentation |
|
from grounded_sam.plot import plot_detections, plot_detections_plotly |
|
|
|
def app_fn(
    image: gr.Image,
    labels: str,
    threshold: float,
    bounding_box_selection: bool,
):
    """Run grounded segmentation on *image* and return an interactive plot.

    Args:
        image: Input image coming from the ``gr.Image(type="pil")`` component.
        labels: Newline-separated text labels. Grounding DINO expects each
            text prompt to end with a period, so one is appended when missing.
        threshold: Box-confidence threshold passed to Grounding DINO.
        bounding_box_selection: Whether the output plot allows box selection.

    Returns:
        The Plotly figure produced by ``plot_detections_plotly`` with the
        detections and masks overlaid on the image.
    """
    # Split on newlines, drop blank/whitespace-only lines (they would
    # otherwise become bare "." prompts), and ensure the trailing "."
    # required by Grounding DINO's text-prompt format.
    label_list = [
        label if label.endswith(".") else label + "."
        for label in (line.strip() for line in labels.split("\n"))
        if label
    ]
    # Final positional argument enables mask refinement (polygon round-trip).
    image_array, detections = grounded_segmentation(image, label_list, threshold, True)
    return plot_detections_plotly(image_array, detections, bounding_box_selection)
|
|
|
if __name__=="__main__":
    title = "Grounding SAM - Text-to-Segmentation Model"

    # Build the Gradio UI. NOTE: component-creation order inside gr.Blocks
    # determines the on-page layout, so the statement order below matters.
    with gr.Blocks(title=title) as demo:
        gr.Markdown(f"# {title}")
        gr.Markdown(
            """
            Grounded SAM is a text-to-segmentation model that generates segmentation masks from natural language descriptions.
            This demo uses Grounding DINO in tandem with SAM to generate segmentation masks from text.
            The workflow is as follows:
            1. Select text labels to generate bounding boxes with Grounding DINO.
            2. Prompt the SAM model to generate segmentation masks from the bounding boxes.
            3. Refine the masks if needed.
            4. Visualize the segmentation masks.


            ### Notes
            - To pass multiple labels, separate them by a new line.
            - The model may take a few seconds to generate the segmentation masks as we need to run through two models.
            - The refinement is done by default by converting the mask to a polygon and back to a mask with openCV.
            - I use in here a concise implementation, but you can find the full code at [GitHub](https://github.com/EduardoPach/grounded-sam)
            """
        )
        # Controls row: detection threshold, newline-separated labels,
        # and a toggle for interactive box selection in the output plot.
        with gr.Row():
            threshold = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.05, label="Box Threshold")
            labels = gr.Textbox(lines=2, max_lines=5, label="Labels")
            bounding_box_selection = gr.Checkbox(label="Allow Box Selection")
        # Default-labelled button ("Run") that triggers inference.
        btn = gr.Button()
        # Input image next to the resulting segmentation plot.
        with gr.Row():
            img = gr.Image(type="pil")
            fig = gr.Plot(label="Segmentation Mask")

        # Wire the button to the inference function: inputs must match
        # app_fn's parameter order (image, labels, threshold, selection).
        btn.click(fn=app_fn, inputs=[img, labels, threshold, bounding_box_selection], outputs=[fig])

        # Pre-baked example; "lazy" caching computes the example output on
        # first request instead of at startup.
        gr.Examples(
            [
                ["input_image.jpeg", "a person.\na mountain.", 0.3, False],
            ],
            inputs = [img, labels, threshold, bounding_box_selection],
            outputs = [fig],
            fn=app_fn,
            cache_examples="lazy",
            label='Try this example input!'
        )

    demo.launch()