SkalskiP committed
Commit e0f6bc4
Parent: 203e0e8

Add IoU filter function and update Dockerfile & app.py


Added functions to utils.py that compute Intersection over Union (IoU) and filter out highly overlapping masks, improving mask quality by removing redundant, near-duplicate masks. Also updated the parameters and default values of the mask post-processing function in utils.py.
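For intuition, IoU divides the overlapping area of two masks by their combined area. A quick worked example using the same numpy operations the new helper relies on (the mask values here are illustrative):

```python
import numpy as np

# Two 2x2 binary masks that share exactly one pixel.
mask1 = np.array([[1, 1], [0, 0]], dtype=bool)  # pixels (0,0) and (0,1)
mask2 = np.array([[0, 1], [0, 1]], dtype=bool)  # pixels (0,1) and (1,1)

intersection = np.logical_and(mask1, mask2).sum()  # 1 shared pixel
union = np.logical_or(mask1, mask2).sum()          # 3 pixels covered in total
print(intersection / union)                        # 1/3 ≈ 0.333
```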

Further, the Dockerfile is updated to install one more dependency, requests, and to copy a new file, gpt4v.py, which handles communication with the OpenAI API and encodes image data in base64 format.

Additionally, app.py gains OpenAI API key support and chat interface elements alongside the Run button for interaction with GPT-4V. The GPT-4V module lays the groundwork for interactive-mode support for image-based queries.

Files changed (4)
  1. Dockerfile +3 -1
  2. app.py +23 -2
  3. gpt4v.py +78 -0
  4. utils.py +55 -3
Dockerfile CHANGED
```diff
@@ -31,7 +31,8 @@ WORKDIR $HOME/app
 RUN pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 -f https://download.pytorch.org/whl/torch_stable.html
 
 # Install dependencies
-RUN pip install --no-cache-dir gradio==3.50.2 opencv-python supervision==0.17.0rc3 pillow
+RUN pip install --no-cache-dir gradio==3.50.2 opencv-python supervision==0.17.0rc3 \
+    pillow requests
 
 # Install SAM and Detectron2
 RUN pip install 'git+https://github.com/facebookresearch/segment-anything.git'
@@ -43,6 +44,7 @@ RUN wget -c -O $HOME/app/weights/sam_vit_h_4b8939.pth https://dl.fbaipublicfiles
 
 COPY app.py .
 COPY utils.py .
+COPY gpt4v.py .
 
 RUN find $HOME/app
```
app.py CHANGED
```diff
@@ -14,8 +14,8 @@ HOME = os.getenv("HOME")
 DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 MINIMUM_AREA_THRESHOLD = 0.01
 
-# SAM_CHECKPOINT = os.path.join(HOME, "app/weights/sam_vit_h_4b8939.pth")
-SAM_CHECKPOINT = "weights/sam_vit_h_4b8939.pth"
+SAM_CHECKPOINT = os.path.join(HOME, "app/weights/sam_vit_h_4b8939.pth")
+# SAM_CHECKPOINT = "weights/sam_vit_h_4b8939.pth"
 SAM_MODEL_TYPE = "vit_h"
 
 MARKDOWN = """
@@ -27,11 +27,19 @@ MARKDOWN = """
 Set-of-Mark (SoM) Prompting Unleashes Extraordinary Visual Grounding in GPT-4V
 </h1>
 
+## 🚀 How To
+
+- Upload an image.
+- Click the `Run` button to generate the image with marks.
+- Pass OpenAI API 🔑. You can get one [here](https://platform.openai.com/api-keys).
+- Ask GPT-4V questions about the image in the chatbot.
+
 ## 🚧 Roadmap
 
 - [ ] Support for alphabetic labels
 - [ ] Support for Semantic-SAM (multi-level)
 - [ ] Support for interactive mode
+- [ ] Support for result highlighting
 """
 
 SAM = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT).to(device=DEVICE)
@@ -60,6 +68,10 @@ def inference(
     return cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
 
 
+def prompt(message, history):
+    return "response"
+
+
 image_input = gr.Image(
     label="Input",
     type="numpy",
@@ -77,6 +89,12 @@ image_output = gr.Image(
     label="SoM Visual Prompt",
     type="numpy",
     height=512)
+textbox_api_key = gr.Textbox(
+    label="OpenAI API KEY",
+    type="password")
+chatbot = gr.Chatbot(
+    label="GPT-4V + SoM",
+    height=256)
 run_button = gr.Button("Run")
 
 with gr.Blocks() as demo:
@@ -92,6 +110,9 @@ with gr.Blocks() as demo:
         with gr.Column():
            image_output.render()
            run_button.render()
+           textbox_api_key.render()
+    with gr.Row():
+        gr.ChatInterface(chatbot=chatbot, fn=prompt)
 
     run_button.click(
         fn=inference,
```
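The `prompt` callback committed here is a stub that always returns `"response"`. A minimal sketch of how it could later be wired to the new gpt4v module, assuming the annotated image and API key are threaded through to the callback (the module-level cache and the extra `api_key` argument below are illustrative assumptions, not part of this commit):

```python
# Hypothetical wiring of the chat stub to gpt4v.prompt_image.
# `last_annotated_image` is an assumed cache set by inference(); passing
# `api_key` would require gr.ChatInterface(additional_inputs=[textbox_api_key]).
from gpt4v import prompt_image

last_annotated_image = None  # would be set by inference() when Run is clicked

def prompt(message, history, api_key):
    if not api_key:
        return "Please paste your OpenAI API key first."
    if last_annotated_image is None:
        return "Please click Run to generate the marked image first."
    return prompt_image(api_key=api_key, image=last_annotated_image, prompt=message)
```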
gpt4v.py ADDED
```diff
@@ -0,0 +1,78 @@
+import cv2
+import base64
+import requests
+
+import numpy as np
+
+
+META_PROMPT = '''
+- For any marks mentioned in your answer, please highlight them with [].
+'''
+API_URL = "https://api.openai.com/v1/chat/completions"
+
+
+def encode_image_to_base64(image: np.ndarray) -> str:
+    """
+    Encodes an image into a base64-encoded string in JPEG format.
+
+    Parameters:
+        image (np.ndarray): The image to be encoded. This should be a numpy array as
+            typically used in OpenCV.
+
+    Returns:
+        str: A base64-encoded string representing the image in JPEG format.
+    """
+    success, buffer = cv2.imencode('.jpg', image)
+    if not success:
+        raise ValueError("Could not encode image to JPEG format.")
+
+    encoded_image = base64.b64encode(buffer).decode('utf-8')
+    return encoded_image
+
+
+def compose_headers(api_key: str) -> dict:
+    return {
+        "Content-Type": "application/json",
+        "Authorization": f"Bearer {api_key}"
+    }
+
+
+def compose_payload(image: np.ndarray, prompt: str) -> dict:
+    base64_image = encode_image_to_base64(image)
+    return {
+        "model": "gpt-4-vision-preview",
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "role": "system",
+                        "content": [
+                            META_PROMPT
+                        ]
+                    },
+                    {
+                        "type": "text",
+                        "text": prompt
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{base64_image}"
+                        }
+                    }
+                ]
+            }
+        ],
+        "max_tokens": 800
+    }
+
+
+def prompt_image(api_key: str, image: np.ndarray, prompt: str) -> str:
+    headers = compose_headers(api_key=api_key)
+    payload = compose_payload(image=image, prompt=prompt)
+    response = requests.post(url=API_URL, headers=headers, json=payload).json()
+
+    if 'error' in response:
+        raise ValueError(response['error']['message'])
+    return response['choices'][0]['message']['content']
```
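One caveat worth flagging: the object with `"role": "system"` nested inside the user message's `content` array is not a standard content part for the chat completions endpoint (content parts are keyed by `"type"`), so the API is liable to reject or ignore the meta prompt. A hedged sketch of the more conventional shape, with META_PROMPT as a top-level system message (an editorial suggestion, not part of this commit):

```python
# Sketch: conventional chat-completions payload, with the meta prompt moved
# to a top-level system message; assumes the helpers defined in gpt4v.py.
def compose_payload(image: np.ndarray, prompt: str) -> dict:
    base64_image = encode_image_to_base64(image)
    return {
        "model": "gpt-4-vision-preview",
        "messages": [
            {"role": "system", "content": META_PROMPT},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}
                    }
                ]
            }
        ],
        "max_tokens": 800
    }
```

With a payload in that shape, `prompt_image(api_key, image, "What is marked as [3]?")` returns the model's text answer, or raises ValueError carrying the API's error message.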
utils.py CHANGED
```diff
@@ -113,11 +113,58 @@ def filter_masks_by_relative_area(
     return masks[min_area_filter & max_area_filter]
 
 
+def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
+    """
+    Computes the Intersection over Union (IoU) of two masks.
+
+    Parameters:
+        mask1, mask2 (np.ndarray): Two mask arrays.
+
+    Returns:
+        float: The IoU of the two masks.
+    """
+    intersection = np.logical_and(mask1, mask2).sum()
+    union = np.logical_or(mask1, mask2).sum()
+    return intersection / union if union != 0 else 0
+
+
+def filter_highly_overlapping_masks(
+    masks: np.ndarray,
+    iou_threshold: float
+) -> np.ndarray:
+    """
+    Removes masks with high overlap from a set of masks.
+
+    Parameters:
+        masks (np.ndarray): A 3D numpy array with shape (N, H, W), where N is the
+            number of masks, and H and W are the height and width of the masks.
+        iou_threshold (float): The IoU threshold above which masks will be considered
+            as overlapping.
+
+    Returns:
+        np.ndarray: A 3D numpy array of masks with highly overlapping masks removed.
+    """
+    num_masks = masks.shape[0]
+    keep_mask = np.ones(num_masks, dtype=bool)
+
+    for i in range(num_masks):
+        for j in range(i + 1, num_masks):
+            if not keep_mask[i] or not keep_mask[j]:
+                continue
+
+            iou = compute_iou(masks[i, :, :], masks[j, :, :])
+            if iou > iou_threshold:
+                keep_mask[j] = False
+
+    return masks[keep_mask]
+
+
 def postprocess_masks(
     detections: sv.Detections,
-    area_threshold: float = 0.02,
-    min_relative_area: float = 0.02,
-    max_relative_area: float = 1.0
+    area_threshold: float = 0.01,
+    min_relative_area: float = 0.01,
+    max_relative_area: float = 1.0,
+    iou_threshold: float = 0.9
 ) -> sv.Detections:
     """
     Post-processes the masks of detection objects by removing small islands and filling
@@ -128,6 +175,8 @@ def postprocess_masks(
         area_threshold (float): Threshold for relative area to remove or fill features.
         min_relative_area (float): Minimum relative area threshold for detections.
         max_relative_area (float): Maximum relative area threshold for detections.
+        iou_threshold (float): The IoU threshold above which masks will be considered
+            as overlapping.
 
     Returns:
         np.ndarray: Post-processed masks.
@@ -148,6 +197,9 @@ def postprocess_masks(
         masks=masks,
         min_relative_area=min_relative_area,
         max_relative_area=max_relative_area)
+    masks = filter_highly_overlapping_masks(
+        masks=masks,
+        iou_threshold=iou_threshold)
 
     return sv.Detections(
         xyxy=sv.mask_to_xyxy(masks),
```
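The filter is greedy: masks are compared pairwise in index order, and whenever a pair's IoU exceeds `iou_threshold`, the later mask is dropped and the earlier one kept. A small self-contained illustration with toy masks, assuming the two functions above are in scope:

```python
import numpy as np

# Three 2x2 toy masks: the first two are identical, the third is disjoint.
masks = np.array([
    [[1, 1], [0, 0]],
    [[1, 1], [0, 0]],
    [[0, 0], [1, 1]],
], dtype=bool)

print(compute_iou(masks[0], masks[1]))  # 1.0 -- full overlap
print(compute_iou(masks[0], masks[2]))  # 0.0 -- no overlap

filtered = filter_highly_overlapping_masks(masks, iou_threshold=0.9)
print(filtered.shape)  # (2, 2, 2): the duplicate second mask was removed
```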