Add IoU filter function and update Dockerfile & app.py
Added compute_iou and filter_highly_overlapping_masks to utils.py to compute Intersection over Union (IoU) between masks and to filter out highly overlapping ones. This improves mask quality by removing redundant masks. The postprocess_masks parameters were updated accordingly: the default area values changed and a new iou_threshold parameter (default 0.9) was added.
Further, the Dockerfile now installs two additional dependencies, pillow and requests, and copies a new file, gpt4v.py, which handles communication with the OpenAI API and encodes image data to base64 for the request payload (a minimal usage sketch follows its listing below).
Additionally, app.py gains OpenAI API key support and chat interface elements next to the Run button for interaction with GPT-4V; the chat handler is currently a stub, and a possible wiring is sketched after the app.py diff below. The GPT-4V module is added to improve interactive-mode support for image-based queries.
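To make the effect of the new overlap filter concrete, here is a small sketch, not part of the commit, that calls compute_iou and filter_highly_overlapping_masks exactly as they are added to utils.py below; the toy masks and the 0.5 threshold are illustrative only.

import numpy as np

from utils import compute_iou, filter_highly_overlapping_masks

# Two 3x3 masks that share two of their three foreground pixels.
mask_a = np.array([[1, 1, 1],
                   [0, 0, 0],
                   [0, 0, 0]], dtype=bool)
mask_b = np.array([[1, 1, 0],
                   [0, 0, 0],
                   [0, 0, 0]], dtype=bool)

# intersection = 2 pixels, union = 3 pixels, so IoU = 2 / 3 ≈ 0.67.
print(compute_iou(mask_a, mask_b))

# At iou_threshold=0.5 the later, near-duplicate mask is dropped; at the
# default of 0.9 used by postprocess_masks, both masks would survive.
masks = np.stack([mask_a, mask_b])
print(filter_highly_overlapping_masks(masks, iou_threshold=0.5).shape)  # (1, 3, 3)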
Dockerfile
CHANGED
@@ -31,7 +31,8 @@ WORKDIR $HOME/app
 RUN pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 -f https://download.pytorch.org/whl/torch_stable.html

 # Install dependencies
-RUN pip install --no-cache-dir gradio==3.50.2 opencv-python supervision==0.17.0rc3
+RUN pip install --no-cache-dir gradio==3.50.2 opencv-python supervision==0.17.0rc3 \
+    pillow requests

 # Install SAM and Detectron2
 RUN pip install 'git+https://github.com/facebookresearch/segment-anything.git'
@@ -43,6 +44,7 @@ RUN wget -c -O $HOME/app/weights/sam_vit_h_4b8939.pth https://dl.fbaipublicfiles

 COPY app.py .
 COPY utils.py .
+COPY gpt4v.py .

 RUN find $HOME/app

app.py
CHANGED
@@ -14,8 +14,8 @@ HOME = os.getenv("HOME")
 DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 MINIMUM_AREA_THRESHOLD = 0.01

-
-SAM_CHECKPOINT = "weights/sam_vit_h_4b8939.pth"
+SAM_CHECKPOINT = os.path.join(HOME, "app/weights/sam_vit_h_4b8939.pth")
+# SAM_CHECKPOINT = "weights/sam_vit_h_4b8939.pth"
 SAM_MODEL_TYPE = "vit_h"

 MARKDOWN = """
@@ -27,11 +27,19 @@ MARKDOWN = """
 Set-of-Mark (SoM) Prompting Unleashes Extraordinary Visual Grounding in GPT-4V
 </h1>

+## 🚀 How To
+
+- Upload an image.
+- Click the `Run` button to generate the image with marks.
+- Pass OpenAI API 🔑. You can get one [here](https://platform.openai.com/api-keys).
+- Ask GPT-4V questions about the image in the chatbot.
+
 ## 🚧 Roadmap

 - [ ] Support for alphabetic labels
 - [ ] Support for Semantic-SAM (multi-level)
 - [ ] Support for interactive mode
+- [ ] Support for result highlighting
 """

 SAM = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT).to(device=DEVICE)
@@ -60,6 +68,10 @@ def inference(
     return cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)


+def prompt(message, history):
+    return "response"
+
+
 image_input = gr.Image(
     label="Input",
     type="numpy",
@@ -77,6 +89,12 @@ image_output = gr.Image(
     label="SoM Visual Prompt",
     type="numpy",
     height=512)
+textbox_api_key = gr.Textbox(
+    label="OpenAI API KEY",
+    type="password")
+chatbot = gr.Chatbot(
+    label="GPT-4V + SoM",
+    height=256)
 run_button = gr.Button("Run")

 with gr.Blocks() as demo:
@@ -92,6 +110,9 @@ with gr.Blocks() as demo:
         with gr.Column():
             image_output.render()
             run_button.render()
+            textbox_api_key.render()
+    with gr.Row():
+        gr.ChatInterface(chatbot=chatbot, fn=prompt)

     run_button.click(
         fn=inference,
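The prompt handler above is a stub that always returns "response". One way it could later be connected to GPT-4V, purely an assumption for illustration rather than anything this commit does, is to hold the API key and the marked image in module-level state and forward chat messages to prompt_image from the new gpt4v.py:

# Hypothetical follow-up wiring for app.py; the state holders are illustrative only.
from gpt4v import prompt_image

api_key_holder = {"key": None}    # would be filled from textbox_api_key
image_holder = {"image": None}    # would be filled when run_button produces marks

def prompt(message, history):
    # gr.ChatInterface calls fn(message, history); history is unused here.
    if api_key_holder["key"] is None or image_holder["image"] is None:
        return "Run SoM on an image and provide an OpenAI API key first."
    return prompt_image(
        api_key=api_key_holder["key"],
        image=image_holder["image"],
        prompt=message)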
gpt4v.py
ADDED
@@ -0,0 +1,78 @@
+import cv2
+import base64
+import requests
+
+import numpy as np
+
+
+META_PROMPT = '''
+- For any marks mentioned in your answer, please highlight them with [].
+'''
+API_URL = "https://api.openai.com/v1/chat/completions"
+
+
+def encode_image_to_base64(image: np.ndarray) -> str:
+    """
+    Encodes an image into a base64-encoded string in JPEG format.
+
+    Parameters:
+        image (np.ndarray): The image to be encoded. This should be a numpy array as
+            typically used in OpenCV.
+
+    Returns:
+        str: A base64-encoded string representing the image in JPEG format.
+    """
+    success, buffer = cv2.imencode('.jpg', image)
+    if not success:
+        raise ValueError("Could not encode image to JPEG format.")
+
+    encoded_image = base64.b64encode(buffer).decode('utf-8')
+    return encoded_image
+
+
+def compose_headers(api_key: str) -> dict:
+    return {
+        "Content-Type": "application/json",
+        "Authorization": f"Bearer {api_key}"
+    }
+
+
+def compose_payload(image: np.ndarray, prompt: str) -> dict:
+    base64_image = encode_image_to_base64(image)
+    return {
+        "model": "gpt-4-vision-preview",
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "role": "system",
+                        "content": [
+                            META_PROMPT
+                        ]
+                    },
+                    {
+                        "type": "text",
+                        "text": prompt
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{base64_image}"
+                        }
+                    }
+                ]
+            }
+        ],
+        "max_tokens": 800
+    }
+
+
+def prompt_image(api_key: str, image: np.ndarray, prompt: str) -> str:
+    headers = compose_headers(api_key=api_key)
+    payload = compose_payload(image=image, prompt=prompt)
+    response = requests.post(url=API_URL, headers=headers, json=payload).json()
+
+    if 'error' in response:
+        raise ValueError(response['error']['message'])
+    return response['choices'][0]['message']['content']
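A minimal standalone usage sketch for the new module; the image path and the environment variable holding the API key are assumptions for illustration. prompt_image sends a single request to the chat/completions endpoint and raises ValueError with the API's error message if the response contains an error.

import os

import cv2

from gpt4v import encode_image_to_base64, prompt_image

# Any local image read with OpenCV (the path is illustrative).
image = cv2.imread("marked_image.jpg")

# Base64 JPEG string that ends up in the data URL of the payload.
encoded = encode_image_to_base64(image)
print(len(encoded))

# One round trip to the OpenAI API; returns the text reply.
answer = prompt_image(
    api_key=os.environ["OPENAI_API_KEY"],  # assumed to be set by the caller
    image=image,
    prompt="Which numbered marks contain text?")
print(answer)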
utils.py
CHANGED
@@ -113,11 +113,58 @@ def filter_masks_by_relative_area(
     return masks[min_area_filter & max_area_filter]


+def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
+    """
+    Computes the Intersection over Union (IoU) of two masks.
+
+    Parameters:
+        mask1, mask2 (np.ndarray): Two mask arrays.
+
+    Returns:
+        float: The IoU of the two masks.
+    """
+    intersection = np.logical_and(mask1, mask2).sum()
+    union = np.logical_or(mask1, mask2).sum()
+    return intersection / union if union != 0 else 0
+
+
+def filter_highly_overlapping_masks(
+    masks: np.ndarray,
+    iou_threshold: float
+) -> np.ndarray:
+    """
+    Removes masks with high overlap from a set of masks.
+
+    Parameters:
+        masks (np.ndarray): A 3D numpy array with shape (N, H, W), where N is the
+            number of masks, and H and W are the height and width of the masks.
+        iou_threshold (float): The IoU threshold above which masks will be considered as
+            overlapping.
+
+    Returns:
+        np.ndarray: A 3D numpy array of masks with highly overlapping masks removed.
+    """
+    num_masks = masks.shape[0]
+    keep_mask = np.ones(num_masks, dtype=bool)
+
+    for i in range(num_masks):
+        for j in range(i + 1, num_masks):
+            if not keep_mask[i] or not keep_mask[j]:
+                continue
+
+            iou = compute_iou(masks[i, :, :], masks[j, :, :])
+            if iou > iou_threshold:
+                keep_mask[j] = False
+
+    return masks[keep_mask]
+
+
 def postprocess_masks(
     detections: sv.Detections,
-    area_threshold: float = 0.,
-    min_relative_area: float = 0.,
-    max_relative_area: float = 1.0
+    area_threshold: float = 0.01,
+    min_relative_area: float = 0.01,
+    max_relative_area: float = 1.0,
+    iou_threshold: float = 0.9
 ) -> sv.Detections:
     """
     Post-processes the masks of detection objects by removing small islands and filling
@@ -128,6 +175,8 @@ def postprocess_masks(
         area_threshold (float): Threshold for relative area to remove or fill features.
         min_relative_area (float): Minimum relative area threshold for detections.
         max_relative_area (float): Maximum relative area threshold for detections.
+        iou_threshold (float): The IoU threshold above which masks will be considered as
+            overlapping.

     Returns:
         np.ndarray: Post-processed masks.
@@ -148,6 +197,9 @@ def postprocess_masks(
         masks=masks,
         min_relative_area=min_relative_area,
         max_relative_area=max_relative_area)
+    masks = filter_highly_overlapping_masks(
+        masks=masks,
+        iou_threshold=iou_threshold)

     return sv.Detections(
         xyxy=sv.mask_to_xyxy(masks),
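For callers, the overlap filter is applied by default through the updated postprocess_masks signature. Below is a hedged sketch of the call pattern on small synthetic detections; the toy masks and the 0.7 threshold are illustrative, and real inputs would come from SAM as in app.py.

import numpy as np
import supervision as sv

from utils import postprocess_masks

# Three 4x4 masks: the first two overlap heavily (IoU = 6/8 = 0.75), the third is disjoint.
masks = np.zeros((3, 4, 4), dtype=bool)
masks[0, :2, :] = True      # top half
masks[1, :2, :3] = True     # top half minus the last column
masks[2, 2:, :] = True      # bottom half

detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks), mask=masks)

# iou_threshold is new in this commit (default 0.9); with 0.7 the second,
# near-duplicate mask should be removed by filter_highly_overlapping_masks,
# while the relative-area defaults keep all three masks in play.
filtered = postprocess_masks(detections=detections, iou_threshold=0.7)
print(len(filtered))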