SauravMaheshkar's picture
feat: drop redundant image box
630e69b unverified
raw
history blame
2.62 kB
import functools
from typing import Dict, Any, List

import cv2
import gradio as gr
import numpy as np
import torch
from gradio_image_annotation import image_annotator
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

from src.plot_utils import show_masks
# Dropdown choice -> [model config file name, checkpoint path] passed to build_sam2.
choice_mapping: Dict[str, List[str]] = dict(
    tiny=["sam2_hiera_t.yaml", "assets/checkpoints/sam2_hiera_tiny.pt"],
    small=["sam2_hiera_s.yaml", "assets/checkpoints/sam2_hiera_small.pt"],
    base_plus=["sam2_hiera_b+.yaml", "assets/checkpoints/sam2_hiera_base_plus.pt"],
    large=["sam2_hiera_l.yaml", "assets/checkpoints/sam2_hiera_large.pt"],
)
@functools.lru_cache(maxsize=1)
def _load_sam2_model(model_choice: str):
    """Build (once) and return the SAM 2 model for a ``choice_mapping`` key.

    Building the model reads the checkpoint from disk, so the result is cached.
    ``maxsize=1`` keeps only the most recently used checkpoint, so switching
    models in the dropdown does not accumulate several models in memory.
    An unknown key raises ``KeyError`` (exceptions are not cached).
    """
    config_file, ckpt_path = choice_mapping[model_choice]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return build_sam2(config_file, ckpt_path, device=device)


def predict(model_choice, annotations: Dict[str, Any]):
    """Segment the object inside the first drawn bounding box.

    Args:
        model_choice: Key of ``choice_mapping`` selecting the checkpoint
            (converted with ``str`` before lookup).
        annotations: Value of the ``image_annotator`` component. Its
            ``"image"`` entry is fed to the predictor. The ``xmin``, ``ymin``,
            ``xmax`` and ``ymax`` entries of ``annotations["boxes"][0]`` form
            the box prompt.

    Returns:
        A two-element list: the figure from ``show_masks`` and a visible
        ``gr.DownloadButton`` pointing at the freshly written ``mask.png``.

    Raises:
        gr.Error: If no image is loaded or no bounding box has been drawn.
    """
    if not annotations or annotations.get("image") is None:
        raise gr.Error("Please upload an image first.")
    boxes = annotations.get("boxes") or []
    if not boxes:
        raise gr.Error("Please draw a bounding box on the image first.")
    box = boxes[0]  # only the first box is used as the prompt

    # The heavy model is cached; the lightweight predictor wrapper (which
    # holds per-image state set by set_image) is created per request.
    predictor = SAM2ImagePredictor(_load_sam2_model(str(model_choice)))
    predictor.set_image(annotations["image"])

    coordinates = np.array(
        [int(box[key]) for key in ("xmin", "ymin", "xmax", "ymax")]
    )
    masks, scores, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=coordinates[None, :],  # add a leading batch axis: shape (1, 4)
        multimask_output=False,
    )

    # Move axis 0 last so cv2 can write the mask as an image; presumably
    # masks is (C, H, W) as documented for SAM2ImagePredictor.predict.
    mask = masks.transpose(1, 2, 0)
    mask_image = (mask * 255).astype(np.uint8)  # Convert to uint8 format
    cv2.imwrite("mask.png", mask_image)
    return [
        show_masks(annotations["image"], masks, scores, box_coords=coordinates),
        gr.DownloadButton("Download Mask", value="mask.png", visible=True),
    ]
# Bundled example shown in the annotator on startup. cv2.imread decodes to BGR
# channel order, while the annotator (like gradio's image components) and
# SAM2ImagePredictor.set_image work with RGB arrays, so convert it. imread
# returns None when the file cannot be read; that None is passed through.
_example_bgr = cv2.imread("assets/example.png")
example_image = (
    cv2.cvtColor(_example_bgr, cv2.COLOR_BGR2RGB)
    if _example_bgr is not None
    else None
)

with gr.Blocks(delete_cache=(30, 30)) as demo:
    gr.Markdown("# 1. Choose Model Checkpoint")
    with gr.Row():
        model = gr.Dropdown(
            # Derived from choice_mapping so the UI and predict() never disagree.
            choices=list(choice_mapping),
            value="tiny",
            label="Model Checkpoint",
            info="Which model checkpoint to load?",
        )

    gr.Markdown("# 2. Upload your Image and draw a bounding box")
    annotator = image_annotator(
        value={"image": example_image},
        disable_edit_boxes=True,
        label="Draw a bounding box",
    )

    btn = gr.Button("Get Segmentation Mask")
    # Hidden until predict() returns a visible button pointing at mask.png.
    download_btn = gr.DownloadButton("Download Mask", value="mask.png", visible=False)
    btn.click(fn=predict, inputs=[model, annotator], outputs=[gr.Plot(), download_btn])

demo.launch()