Deleted 'gpt4v.py' and replaced its functionality (and the overlapping-mask helpers in 'utils.py') with the 'setofmark' (som) package, now used from 'app.py' and 'utils.py'.
Dockerfile
CHANGED
@@ -32,7 +32,7 @@ RUN pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 -f https://download
 
 # Install dependencies
 RUN pip install --no-cache-dir gradio==3.50.2 opencv-python supervision==0.17.0rc4 \
-    pillow requests
+    pillow requests setofmark==0.1.0rc3
 
 # Install SAM and Detectron2
 RUN pip install 'git+https://github.com/facebookresearch/segment-anything.git'
@@ -44,7 +44,6 @@ RUN wget -c -O $HOME/app/weights/sam_vit_h_4b8939.pth https://dl.fbaipublicfiles
 
 COPY app.py .
 COPY utils.py .
-COPY gpt4v.py .
 COPY sam_utils.py .
 
 RUN find $HOME/app
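The new setofmark pin is what the rest of this commit builds on: app.py and utils.py below start importing it as som and calling som.prompt_image, som.extract_relevant_masks, and som.mask_non_max_suppression. As a rough, hypothetical smoke test (not part of the commit) one could run in the built image:

import som

# The three entry points this commit starts relying on; the names are taken
# from the app.py and utils.py hunks below, everything else is illustrative.
assert callable(som.prompt_image)               # GPT-4V call, replaces gpt4v.prompt_image
assert callable(som.extract_relevant_masks)     # maps GPT-4V text back to marked masks
assert callable(som.mask_non_max_suppression)   # replaces filter_highly_overlapping_masks
print("setofmark import OK")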
app.py
CHANGED
@@ -4,13 +4,13 @@ from typing import List, Dict, Tuple, Any, Optional
 import cv2
 import gradio as gr
 import numpy as np
+import som
 import supervision as sv
 import torch
 from segment_anything import sam_model_registry
 
-from gpt4v import prompt_image
 from sam_utils import sam_interactive_inference, sam_inference
 from utils import postprocess_masks, Visualizer
 
 HOME = os.getenv("HOME")
 DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
@@ -21,17 +21,21 @@ SAM_MODEL_TYPE = "vit_h"
 
 ANNOTATED_IMAGE_KEY = "annotated_image"
 DETECTIONS_KEY = "detections"
-
 MARKDOWN = """
-[9 lines of the previous MARKDOWN header, not captured in this view]
+<div align='center'>
+    <h1>
+        <img
+            src='https://som-gpt4v.github.io/website/img/som_logo.png'
+            style='height:50px; display:inline-block'
+        />
+        Set-of-Mark (SoM) Prompting Unleashes Extraordinary Visual Grounding in GPT-4V
+    </h1>
+    <br>
+    [<a href="https://arxiv.org/abs/2109.07529"> arXiv paper </a>]
+    [<a href="https://som-gpt4v.github.io"> project page </a>]
+    [<a href="https://github.com/roboflow/set-of-mark"> python package </a>]
+    [<a href="https://github.com/microsoft/SoM"> code </a>]
+</div>
 
 ## 🚧 Roadmap
 
@@ -90,7 +94,7 @@ def prompt(
         return "⚠️ Please set your OpenAI API key first"
     if state is None or ANNOTATED_IMAGE_KEY not in state:
         return "⚠️ Please generate SoM visual prompt first"
-    return prompt_image(
+    return som.prompt_image(
        api_key=api_key,
        image=cv2.cvtColor(state[ANNOTATED_IMAGE_KEY], cv2.COLOR_BGR2RGB),
        prompt=message
@@ -114,15 +118,17 @@ def highlight(
     if len(history) == 0:
         return None
 
-    [6 lines of the previous mark-extraction code, not captured in this view]
+    text = history[-1][-1]
+    relevant_masks = som.extract_relevant_masks(
+        text=text,
+        detections=detections
+    )
+    relevant_masks = [
+        (mask, mark)
+        for mark, mask
+        in relevant_masks.items()
     ]
-
-    return annotated_image, highlighted_detections
+    return annotated_image, relevant_masks
 
 
 image_input = gr.Image(
@@ -131,7 +137,8 @@ image_input = gr.Image(
     tool="sketch",
     interactive=True,
     brush_radius=20.0,
-    brush_color="#FFFFFF"
+    brush_color="#FFFFFF",
+    height=512
 )
 checkbox_annotation_mode = gr.CheckboxGroup(
     choices=["Mark", "Polygon", "Mask", "Box"],
@@ -147,7 +154,8 @@ image_output = gr.AnnotatedImage(
     color_map={
         str(i): sv.ColorPalette.default().by_idx(i).as_hex()
         for i in range(64)
-    }
+    },
+    height=512
 )
 openai_api_key = gr.Textbox(
     show_label=False,
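The reworked highlight handler above now delegates mark extraction to the som package: it takes the last GPT-4V reply, asks som.extract_relevant_masks which marked regions that reply references, and converts the result into the (mask, label) pairs that gr.AnnotatedImage expects. A minimal sketch of that flow, with the surrounding Gradio wiring assumed:

import som

def highlight_relevant_masks(annotated_image, detections, history):
    # No conversation yet -> nothing to highlight.
    if len(history) == 0:
        return None

    text = history[-1][-1]  # last GPT-4V reply in the chat history
    relevant_masks = som.extract_relevant_masks(
        text=text,
        detections=detections
    )  # mapping of mark label -> mask, as used in the hunk above
    # gr.AnnotatedImage wants (mask, label) tuples.
    pairs = [(mask, mark) for mark, mask in relevant_masks.items()]
    return annotated_image, pairs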
gpt4v.py
DELETED
@@ -1,81 +0,0 @@
-import cv2
-import base64
-import requests
-
-import numpy as np
-
-
-META_PROMPT = '''
-For any labels or markings on an image that you reference in your response, please
-enclose them in square brackets ([]) and list them explicitly. Do not use ranges; for
-example, instead of '1 - 4', list as '[1], [2], [3], [4]'. These labels could be
-numbers or letters and typically correspond to specific segments or parts of the image.
-'''
-API_URL = "https://api.openai.com/v1/chat/completions"
-
-
-def encode_image_to_base64(image: np.ndarray) -> str:
-    """
-    Encodes an image into a base64-encoded string in JPEG format.
-
-    Parameters:
-        image (np.ndarray): The image to be encoded. This should be a numpy array as
-            typically used in OpenCV.
-
-    Returns:
-        str: A base64-encoded string representing the image in JPEG format.
-    """
-    success, buffer = cv2.imencode('.jpg', image)
-    if not success:
-        raise ValueError("Could not encode image to JPEG format.")
-
-    encoded_image = base64.b64encode(buffer).decode('utf-8')
-    return encoded_image
-
-
-def compose_headers(api_key: str) -> dict:
-    return {
-        "Content-Type": "application/json",
-        "Authorization": f"Bearer {api_key}"
-    }
-
-
-def compose_payload(image: np.ndarray, prompt: str) -> dict:
-    base64_image = encode_image_to_base64(image)
-    return {
-        "model": "gpt-4-vision-preview",
-        "messages": [
-            {
-                "role": "system",
-                "content": [
-                    META_PROMPT
-                ]
-            },
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "text",
-                        "text": prompt
-                    },
-                    {
-                        "type": "image_url",
-                        "image_url": {
-                            "url": f"data:image/jpeg;base64,{base64_image}"
-                        }
-                    }
-                ]
-            }
-        ],
-        "max_tokens": 800
-    }
-
-
-def prompt_image(api_key: str, image: np.ndarray, prompt: str) -> str:
-    headers = compose_headers(api_key=api_key)
-    payload = compose_payload(image=image, prompt=prompt)
-    response = requests.post(url=API_URL, headers=headers, json=payload).json()
-
-    if 'error' in response:
-        raise ValueError(response['error']['message'])
-    return response['choices'][0]['message']['content']
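The module deleted above had a single public entry point, prompt_image(api_key, image, prompt), which base64-encoded the annotated frame and posted it to the gpt-4-vision-preview chat endpoint. That role is now filled by som.prompt_image (see the app.py hunk), called with the same keyword arguments. A minimal sketch of the new call site, assuming a BGR frame as used elsewhere in this app:

import cv2
import som

def ask_gpt4v(api_key: str, annotated_bgr_image, message: str) -> str:
    # Convert BGR -> RGB before the call, matching the cvtColor step in app.py.
    rgb_image = cv2.cvtColor(annotated_bgr_image, cv2.COLOR_BGR2RGB)
    return som.prompt_image(
        api_key=api_key,
        image=rgb_image,
        prompt=message
    )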
utils.py
CHANGED
@@ -1,7 +1,5 @@
-import re
-from typing import List
-
 import cv2
+import som
 
 import numpy as np
 import supervision as sv
@@ -13,7 +11,7 @@ class Visualizer:
         self,
         line_thickness: int = 2,
         mask_opacity: float = 0.1,
-        text_scale: float = 0.
+        text_scale: float = 0.6
     ) -> None:
         self.box_annotator = sv.BoundingBoxAnnotator(
             color_lookup=sv.ColorLookup.INDEX,
@@ -25,6 +23,8 @@ class Visualizer:
             color_lookup=sv.ColorLookup.INDEX,
             thickness=line_thickness)
         self.label_annotator = sv.LabelAnnotator(
+            color=sv.Color.black(),
+            text_color=sv.Color.white(),
             color_lookup=sv.ColorLookup.INDEX,
             text_position=sv.Position.CENTER_OF_MASS,
             text_scale=text_scale)
@@ -85,7 +85,11 @@ def refine_mask(
         relative_area = area / total_area
         if relative_area < area_threshold:
             cv2.drawContours(
-                mask,
+                image=mask,
+                contours=[contour],
+                contourIdx=-1,
+                color=(0 if mode == 'islands' else 255),
+                thickness=-1
             )
 
     return np.where(mask > 0, 1, 0).astype(bool)
@@ -116,52 +120,6 @@ def filter_masks_by_relative_area(
     return masks[min_area_filter & max_area_filter]
 
 
-def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
-    """
-    Computes the Intersection over Union (IoU) of two masks.
-
-    Parameters:
-        mask1, mask2 (np.ndarray): Two mask arrays.
-
-    Returns:
-        float: The IoU of the two masks.
-    """
-    intersection = np.logical_and(mask1, mask2).sum()
-    union = np.logical_or(mask1, mask2).sum()
-    return intersection / union if union != 0 else 0
-
-
-def filter_highly_overlapping_masks(
-    masks: np.ndarray,
-    iou_threshold: float
-) -> np.ndarray:
-    """
-    Removes masks with high overlap from a set of masks.
-
-    Parameters:
-        masks (np.ndarray): A 3D numpy array with shape (N, H, W), where N is the
-            number of masks, and H and W are the height and width of the masks.
-        iou_threshold (float): The IoU threshold above which masks will be considered as
-            overlapping.
-
-    Returns:
-        np.ndarray: A 3D numpy array of masks with highly overlapping masks removed.
-    """
-    num_masks = masks.shape[0]
-    keep_mask = np.ones(num_masks, dtype=bool)
-
-    for i in range(num_masks):
-        for j in range(i + 1, num_masks):
-            if not keep_mask[i] or not keep_mask[j]:
-                continue
-
-            iou = compute_iou(masks[i, :, :], masks[j, :, :])
-            if iou > iou_threshold:
-                keep_mask[j] = False
-
-    return masks[keep_mask]
-
-
 def postprocess_masks(
     detections: sv.Detections,
     area_threshold: float = 0.01,
@@ -200,7 +158,7 @@ def postprocess_masks(
         masks=masks,
         min_relative_area=min_relative_area,
         max_relative_area=max_relative_area)
-    masks = filter_highly_overlapping_masks(
+    masks = som.mask_non_max_suppression(
         masks=masks,
         iou_threshold=iou_threshold)
 
@@ -208,18 +166,3 @@ def postprocess_masks(
         xyxy=sv.mask_to_xyxy(masks),
         mask=masks
     )
-
-
-def extract_numbers_in_brackets(text: str) -> List[int]:
-    """
-    Extracts all numbers enclosed in square brackets from a given string.
-
-    Args:
-        text (str): The string to be searched.
-
-    Returns:
-        List[int]: A list of integers found within square brackets.
-    """
-    pattern = r'\[(\d+)\]'
-    numbers = [int(num) for num in re.findall(pattern, text)]
-    return numbers
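The compute_iou / filter_highly_overlapping_masks pair removed above is replaced by a single som.mask_non_max_suppression call, with iou_threshold keeping its old meaning (mask pairs whose IoU exceeds it are deduplicated) and an (N, H, W) mask array coming back, as the postprocess_masks hunk shows. A small usage sketch under those assumptions; the toy masks are purely illustrative:

import numpy as np
import som

# Three boolean masks; the second heavily overlaps the first.
masks = np.zeros((3, 64, 64), dtype=bool)
masks[0, :32, :32] = True
masks[1, :30, :30] = True
masks[2, 32:, 32:] = True

kept = som.mask_non_max_suppression(masks=masks, iou_threshold=0.5)
print(kept.shape)  # expect one of the two overlapping masks to be dropped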