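"""Gradio demo for Set-of-Mark (SoM) visual prompting: SAM generates marked masks on the input image, and GPT-4V chats about the annotated result."""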
import os
import cv2
import torch
import gradio as gr
import numpy as np
import supervision as sv
from typing import List
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from utils import postprocess_masks, Visualizer
from gpt4v import prompt_image
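
# SAM checkpoint location and compute device (GPU when available).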
HOME = os.getenv("HOME")
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
SAM_CHECKPOINT = os.path.join(HOME, "app/weights/sam_vit_h_4b8939.pth")
# SAM_CHECKPOINT = "weights/sam_vit_h_4b8939.pth"
SAM_MODEL_TYPE = "vit_h"
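
# Page header and feature roadmap shown at the top of the demo.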
MARKDOWN = """
<h1 style='text-align: center'>
<img
src='https://som-gpt4v.github.io/website/img/som_logo.png'
style='height:50px; display:inline-block'
/>
Set-of-Mark (SoM) Prompting Unleashes Extraordinary Visual Grounding in GPT-4V
</h1>
## 🚧 Roadmap
- [ ] Support for alphabetic labels
- [ ] Support for Semantic-SAM (multi-level)
- [ ] Support for interactive mode
- [ ] Support for result highlighting
"""
SAM = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT).to(device=DEVICE)


def inference(
image: np.ndarray,
annotation_mode: List[str],
mask_alpha: float
) -> np.ndarray:
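    """Generate SAM masks for the image and draw the selected annotation types (Mark, Polygon, Mask, Box)."""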
visualizer = Visualizer(mask_opacity=mask_alpha)
mask_generator = SamAutomaticMaskGenerator(SAM)
result = mask_generator.generate(image=image)
detections = sv.Detections.from_sam(result)
detections = postprocess_masks(
detections=detections)
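    # Draw annotations on a BGR copy, then convert back to RGB for the Gradio output.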
bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
annotated_image = visualizer.visualize(
image=bgr_image,
detections=detections,
with_box="Box" in annotation_mode,
with_mask="Mask" in annotation_mode,
with_polygon="Polygon" in annotation_mode,
with_label="Mark" in annotation_mode)
return cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)


def prompt(message, history, image: np.ndarray, api_key: str) -> str:
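    """Send the chat message together with the SoM-annotated image to GPT-4V."""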
if api_key == "":
return "⚠️ Please set your OpenAI API key first"
if image is None:
return "⚠️ Please generate SoM visual prompt first"
return prompt_image(
api_key=api_key,
image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
prompt=message
)
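

# Gradio components; rendered inside the Blocks layout below.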
image_input = gr.Image(
label="Input",
type="numpy",
height=512)
checkbox_annotation_mode = gr.CheckboxGroup(
choices=["Mark", "Polygon", "Mask", "Box"],
value=['Mark'],
label="Annotation Mode")
slider_mask_alpha = gr.Slider(
minimum=0,
maximum=1,
value=0.05,
label="Mask Alpha")
image_output = gr.Image(
label="SoM Visual Prompt",
type="numpy",
height=512)
openai_api_key = gr.Textbox(
show_label=False,
placeholder="Before you start chatting, set your OpenAI API key here",
lines=1,
type="password")
chatbot = gr.Chatbot(
label="GPT-4V + SoM",
height=256)
run_button = gr.Button("Run")
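
# Layout: image input and prompt settings on the left, SoM output on the right, GPT-4V chat below.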
with gr.Blocks() as demo:
gr.Markdown(MARKDOWN)
with gr.Row():
with gr.Column():
image_input.render()
with gr.Accordion(
label="Detailed prompt settings (e.g., mark type)",
open=False):
with gr.Row():
checkbox_annotation_mode.render()
with gr.Row():
slider_mask_alpha.render()
with gr.Column():
image_output.render()
run_button.render()
with gr.Row():
openai_api_key.render()
with gr.Row():
gr.ChatInterface(
chatbot=chatbot,
fn=prompt,
additional_inputs=[image_output, openai_api_key])
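
    # Generate the SoM visual prompt when Run is clicked.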
run_button.click(
fn=inference,
inputs=[image_input, checkbox_annotation_mode, slider_mask_alpha],
outputs=image_output)
demo.queue().launch(debug=False, show_error=True)