Jose Benitez committed on
Commit
aa36c04
·
1 Parent(s): 7048651
.DS_Store ADDED
Binary file (6.15 kB). View file
 
.gitattributes ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ efficient_sam_s_gpu.jit filter=lfs diff=lfs merge=lfs -text
2
+ efficient_sam_s_cpu.jit filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: SAM Arena
3
+ emoji: 🐢
4
+ colorFrom: red
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 4.9.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Thanks to the following repos:
2
+ # https://huggingface.co/spaces/An-619/FastSAM/blob/main/app_gradio.py
3
+ # https://huggingface.co/spaces/SkalskiP/EfficientSAM
4
+ from typing import Tuple
5
+
6
+ from ultralytics import YOLO
7
+ from PIL import ImageDraw
8
+ from PIL import Image
9
+ import gradio as gr
10
+ import numpy as np
11
+ import torch
12
+
13
+ from transformers import SamModel, SamProcessor
14
+
15
+ import supervision as sv
16
+ from utils.tools_gradio import fast_process
17
+ from utils.tools import format_results, point_prompt
18
+ from utils.draw import draw_circle, calculate_dynamic_circle_radius
19
+ from utils.efficient_sam import load, inference_with_box, inference_with_point
20
+
21
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the pre-trained models once at import time so every request reuses them.
FASTSAM_MODEL = YOLO('FastSAM-s.pt')
SAM_MODEL = SamModel.from_pretrained("facebook/sam-vit-huge").to(DEVICE)
SAM_PROCESSOR = SamProcessor.from_pretrained("facebook/sam-vit-huge")
EFFICIENT_SAM_MODEL = load(device=DEVICE)

# Colours and the shared annotator used to paint masks and the prompt marker.
MASK_COLOR = sv.Color.from_hex("#FF0000")
PROMPT_COLOR = sv.Color.from_hex("#D3D3D3")
MASK_ANNOTATOR = sv.MaskAnnotator(
    color=MASK_COLOR,
    color_lookup=sv.ColorLookup.INDEX)

# Page texts. NOTE(review): none of these three strings is referenced by the
# UI code visible in this file - confirm whether they are still needed.
title = "<center><strong><font size='8'>🤗 Segment Anything Model Arena ⚔️</font></strong></center>"

description = "<center><font size='4'>This is a demo of the <strong>Segment Anything Model Arena</strong>, a collection of models for segmenting anything. "

css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"

# Disabled: local example images for a gr.Examples gallery.
#examples = [["examples/retail01.png"], ["examples/vend01.png"], ["examples/vend02.png"]]

# Remote example images, each followed by two integers (presumably the x, y
# pixel position of a prompt point - TODO confirm).
POINT_EXAMPLES = [
    ['https://media.roboflow.com/efficient-sam/corgi.jpg', 1291, 751],
    ['https://media.roboflow.com/efficient-sam/horses.jpg', 1168, 939],
    ['https://media.roboflow.com/efficient-sam/bears.jpg', 913, 1051]
]

# Disabled together with `examples` above.
#default_example = examples[0]
49
+
50
def annotate_image_with_point_prompt_result(
    image: np.ndarray,
    detections: sv.Detections,
    x: int,
    y: int
) -> np.ndarray:
    """Paint the detections' masks and the prompt point onto an image.

    The channel axis is reversed before drawing and reversed again on
    return, so the result has the same channel order as ``image``.

    Args:
        image: Three-dimensional image array (height, width, channels).
        detections: Detections whose masks are overlaid.
        x: Horizontal pixel position of the prompt point.
        y: Vertical pixel position of the prompt point.

    Returns:
        The annotated image.
    """
    height, width, _channels = image.shape
    marker_radius = calculate_dynamic_circle_radius(resolution_wh=(width, height))

    # The drawing helpers work on the channel-reversed view of the image.
    scene = MASK_ANNOTATOR.annotate(scene=image[:, :, ::-1], detections=detections)
    scene = draw_circle(
        scene=scene,
        center=sv.Point(x=x, y=y),
        radius=marker_radius,
        color=PROMPT_COLOR,
    )
    return scene[:, :, ::-1]
66
+
67
def SAM_points_inference(image: np.ndarray) -> np.ndarray:
    """Segment ``image`` with SAM using the points clicked so far.

    The prompt points are read from the module-level ``global_points`` list.
    Only the first point is drawn on the returned image.

    Args:
        image: Input image as an array accepted by ``Image.fromarray``.

    Returns:
        The input image with the predicted mask and the first prompt point
        drawn on it.

    Raises:
        gr.Error: If no image was provided or no point has been clicked.
    """
    if image is None:
        raise gr.Error("Please provide an input image before segmenting.")
    if not global_points:
        # Without this guard the indexing below fails with a bare IndexError.
        raise gr.Error("Click on the image to add at least one point before segmenting.")

    # One single-point prompt per clicked point: [[[x, y]], [[x, y]], ...].
    input_points = [[[float(num) for num in point]] for point in global_points]
    x = int(input_points[0][0][0])
    y = int(input_points[0][0][1])

    inputs = SAM_PROCESSOR(
        Image.fromarray(image),
        input_points=[input_points],
        return_tensors="pt"
    ).to(DEVICE)

    with torch.no_grad():
        outputs = SAM_MODEL(**inputs)

    # [0][0][0] presumably selects first image / first prompt / first
    # candidate mask - TODO confirm against the transformers SAM docs.
    mask = SAM_PROCESSOR.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu()
    )[0][0][0].numpy()
    mask = mask[np.newaxis, ...]
    detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)

    return annotate_image_with_point_prompt_result(
        image=image, detections=detections, x=x, y=y)
94
+
95
def FastSAM_points_inference(
    input,
    input_size=1024,
    iou_threshold=0.7,
    conf_threshold=0.25,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=True,
):
    """Segment ``input`` with FastSAM using the points clicked so far.

    Reads the module-level ``global_points`` / ``global_point_label`` lists
    and resets both after a successful run.

    Args:
        input: Input image as an array accepted by ``Image.fromarray``.
        input_size: Longest image side after resizing; also passed as ``imgsz``.
        iou_threshold: Passed to the model as ``iou``.
        conf_threshold: Passed to the model as ``conf``.
        better_quality: Forwarded to ``fast_process``.
        withContours: Forwarded to ``fast_process``.
        use_retina: Forwarded to ``fast_process``.
        mask_random_color: Forwarded to ``fast_process``.

    Returns:
        Whatever ``fast_process`` returns for the selected mask.

    Raises:
        gr.Error: If there is no image, no clicked point, or no mask found.
    """
    global global_points
    global global_point_label

    if input is None:
        raise gr.Error("Please provide an input image before segmenting.")

    # Work on a snapshot so later changes to the shared lists cannot
    # affect this run half-way through.
    points = [list(point) for point in global_points]
    labels = list(global_point_label)
    if not points:
        raise gr.Error("Click on the image to add at least one point before segmenting.")

    input = Image.fromarray(input)
    input_size = int(input_size)  # make sure imgsz is an integer
    # Thanks for the suggestion by hysts in HuggingFace.
    w, h = input.size
    scale = input_size / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    input = input.resize((new_w, new_h))

    scaled_points = [[int(coord * scale) for coord in point] for point in points]

    results = FASTSAM_MODEL(input,
                            device=DEVICE,
                            retina_masks=True,
                            iou=iou_threshold,
                            conf=conf_threshold,
                            imgsz=input_size,)

    # format_results dereferences .masks; presumably it is None when the
    # model found nothing - verify against the ultralytics Results docs.
    if not results or results[0].masks is None:
        raise gr.Error("FastSAM did not find any segment in this image.")

    results = format_results(results[0], 0)
    annotations, _ = point_prompt(results, scaled_points, labels, new_h, new_w)
    annotations = np.array([annotations])

    fig = fast_process(annotations=annotations,
                       image=input,
                       device=DEVICE,
                       scale=(1024 // input_size),
                       better_quality=better_quality,
                       mask_random_color=mask_random_color,
                       bbox=None,
                       use_retina=use_retina,
                       withContours=withContours,)

    # Start the next prompt from a clean slate.
    global_points = []
    global_point_label = []

    return fig
143
+
144
def EfficientSAM_points_inference(image: np.ndarray):
    """Segment ``image`` with EfficientSAM using the first clicked point.

    The prompt point is read from the module-level ``global_points`` list.

    Args:
        image: Input image array.

    Returns:
        The input image with the predicted mask and the prompt point drawn.

    Raises:
        gr.Error: If no image was provided or no point has been clicked.
    """
    if image is None:
        raise gr.Error("Please provide an input image before segmenting.")
    if not global_points:
        # Without this guard the indexing below fails with a bare IndexError.
        raise gr.Error("Click on the image to add at least one point before segmenting.")

    x, y = int(global_points[0][0]), int(global_points[0][1])
    point = np.array([[x, y]])
    mask = inference_with_point(image, point, EFFICIENT_SAM_MODEL, DEVICE)
    mask = mask[np.newaxis, ...]
    detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)

    return annotate_image_with_point_prompt_result(image=image, detections=detections, x=x, y=y)
152
+
153
def get_points_with_draw(image, label, evt: gr.SelectData):
    """Record a clicked point and draw it on the image.

    Appends the click position to ``global_points`` and its label
    (1 for 'Add Mask', 0 otherwise) to ``global_point_label``.

    Args:
        image: Current input image as an array accepted by ``Image.fromarray``.
        label: Selected radio value; 'Add Mask' marks a foreground point.
        evt: Select event; ``evt.index`` holds the clicked x and y.

    Returns:
        A PIL image with a filled circle drawn at the clicked position.
    """
    x, y = evt.index[0], evt.index[1]
    is_foreground = label == 'Add Mask'

    point_radius = 15
    # Red marks a foreground point, magenta a background point.
    point_color = (255, 0, 0) if is_foreground else (255, 0, 255)

    # The lists are mutated in place, so no `global` declaration is needed.
    global_points.append([x, y])
    global_point_label.append(1 if is_foreground else 0)

    image = Image.fromarray(image)
    draw = ImageDraw.Draw(image)
    draw.ellipse(
        [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
        fill=point_color,
    )
    return image
167
+
168
def clear(_: np.ndarray = None) -> Tuple[None, None, None, None]:
    """Return four ``None`` values, used to blank image components.

    Args:
        _: Ignored. It is optional so the function can also be wired to an
            event that passes no inputs (such as a button click).

    Returns:
        A tuple of four ``None`` values.
    """
    return None, None, None, None
170
+
171
# Components created up-front so they can be rendered inside the layout below.
gr_input_image = gr.Image(label="Input", value='examples/fruits.jpg')

fast_sam_segmented_image = gr.Image(label="Fast SAM", interactive=False, type='pil')

edge_sam_segmented_imaged = gr.Image(label="Edge SAM", interactive=False, type='pil')


# Prompt points clicked so far and their labels (1 = foreground, 0 = background).
# NOTE(review): module-level state is shared by every visitor of the demo.
global_points = []
global_point_label = []


def _reset_results():
    """Blank the three result images (one value per wired output)."""
    return None, None, None


def _reset_everything():
    """Forget all clicked points and blank the input and result images."""
    global_points.clear()
    global_point_label.clear()
    return None, None, None, None


with gr.Blocks() as demo:
    with gr.Tab("Points prompt"):
        # Input Image
        with gr.Row(variant="panel"):
            with gr.Column(scale=1, min_width=320, variant="compact"):
                gr_input_image.render()

        # Submit & Clear
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    add_or_remove = gr.Radio(["Add Mask", "Remove Area"], value="Add Mask", label="Point label (foreground/background)")
            with gr.Column():
                inference_point_button = gr.Button("Segment", variant='primary')
                clear_button = gr.Button("Clear points", variant='secondary')

        # Segment Results Grid
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                sam_segmented_image = gr.Image(label="SAM")
            with gr.Column(scale=1):
                efficient_sam_segmented_image = gr.Image(label="Efficient SAM")

        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                fast_sam_segmented_image.render()
            with gr.Column(scale=1):
                edge_sam_segmented_imaged.render()

        gr.Markdown("AI Generated Examples")

    gr_input_image.select(get_points_with_draw, [gr_input_image, add_or_remove], gr_input_image)

    # Run the three models one after another. FastSAM resets the shared
    # point lists when it finishes, so it must run last; separate click
    # listeners gave no such ordering guarantee.
    inference_point_button.click(
        SAM_points_inference,
        inputs=[gr_input_image],
        outputs=[sam_segmented_image]
    ).then(
        EfficientSAM_points_inference,
        inputs=[gr_input_image],
        outputs=[efficient_sam_segmented_image]
    ).then(
        FastSAM_points_inference,
        inputs=[gr_input_image],
        outputs=[fast_sam_segmented_image]
    )

    # Three outputs, so the handler must return exactly three values.
    gr_input_image.change(
        _reset_results,
        outputs=[efficient_sam_segmented_image, sam_segmented_image, fast_sam_segmented_image]
    )

    # "Clear points" also forgets the stored points, not just the images.
    clear_button.click(
        _reset_everything,
        outputs=[gr_input_image, efficient_sam_segmented_image, sam_segmented_image, fast_sam_segmented_image]
    )


demo.queue()
demo.launch(debug=True, show_error=True)
efficient_sam_s_cpu.jit ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b63ab268e9020b0fb7fc9f46e742644d4c9ea6e5d9caf56045f0afb6475db09
3
+ size 106006979
efficient_sam_s_gpu.jit ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e47c589ead2c6a80d38050ce63083a551e288db27113d534e0278270fc7cba26
3
+ size 106006979
examples/.DS_Store ADDED
Binary file (6.15 kB). View file
 
examples/fruits.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+
4
+ pillow
5
+ gradio==3.44.0
6
+ transformers
7
+ supervision
8
+ ultralytics
9
+ clip
10
+ opencv-python
utils/draw.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #https://huggingface.co/spaces/SkalskiP/EfficientSAM
2
+ from typing import Tuple
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import supervision as sv
7
+
8
+
9
def draw_circle(
    scene: np.ndarray, center: sv.Point, color: sv.Color, radius: int = 2
) -> np.ndarray:
    """Draw a filled circle on ``scene`` in place and return it.

    Args:
        scene: Image array to draw on; it is modified directly.
        center: Centre of the circle.
        color: Fill colour; converted with ``as_bgr()`` before drawing.
        radius: Circle radius.

    Returns:
        The same ``scene`` array that was passed in.
    """
    centre_xy = center.as_xy_int_tuple()
    fill_bgr = color.as_bgr()
    # cv2.FILLED (-1) as thickness fills the circle instead of outlining it.
    cv2.circle(scene, center=centre_xy, radius=radius, color=fill_bgr, thickness=cv2.FILLED)
    return scene
20
+
21
+
22
def calculate_dynamic_circle_radius(resolution_wh: Tuple[int, int]) -> int:
    """Pick a marker radius that grows with the image resolution.

    Args:
        resolution_wh: Image resolution as ``(width, height)``.

    Returns:
        4 when the smaller side is below 480, 8 when it is below 1080,
        otherwise 16.
    """
    min_dimension = min(resolution_wh)
    if min_dimension < 480:
        return 4
    # The former 720 and 1080 tiers both returned 8, as did the 2160 tier
    # and the fallback (16), so each pair is collapsed into one branch.
    if min_dimension < 1080:
        return 8
    return 16
utils/efficient_sam.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from torchvision.transforms import ToTensor
4
+
5
# TorchScript checkpoints, resolved relative to the current working directory.
GPU_EFFICIENT_SAM_CHECKPOINT = "efficient_sam_s_gpu.jit"
CPU_EFFICIENT_SAM_CHECKPOINT = "efficient_sam_s_cpu.jit"


def load(device: torch.device) -> torch.jit.ScriptModule:
    """Load the EfficientSAM TorchScript checkpoint matching ``device``.

    Args:
        device: Target device; a ``cuda`` device selects the GPU checkpoint,
            anything else the CPU checkpoint.

    Returns:
        The loaded module, switched to evaluation mode.
    """
    if device.type == "cuda":
        checkpoint = GPU_EFFICIENT_SAM_CHECKPOINT
    else:
        checkpoint = CPU_EFFICIENT_SAM_CHECKPOINT
    # map_location remaps the stored tensors onto the requested device
    # instead of whichever device the checkpoint was saved from.
    model = torch.jit.load(checkpoint, map_location=device)
    model.eval()
    return model
16
+
17
+
18
def _select_best_mask(predicted_logits, predicted_iou):
    """Return the candidate mask with the highest predicted IoU.

    Logits are passed through a sigmoid and thresholded at 0.5. Only the
    candidates at index ``[0, 0]`` of both tensors are considered. Ties go
    to the first candidate. Returns ``None`` when there is no candidate.
    """
    logits = predicted_logits.cpu()
    all_masks = torch.ge(torch.sigmoid(logits[0, 0, :, :, :]), 0.5).numpy()
    iou = predicted_iou[0, 0, ...].cpu().detach().numpy()
    if all_masks.shape[0] == 0:
        return None
    return all_masks[int(np.argmax(iou[: all_masks.shape[0]]))]


def inference_with_box(
    image: np.ndarray,
    box: np.ndarray,
    model: torch.jit.ScriptModule,
    device: torch.device
) -> np.ndarray:
    """Run the model with a box prompt and return the best mask.

    Args:
        image: Input image; converted to a tensor with ``ToTensor``.
        box: Box prompt with four values, reshaped to ``[1, 1, 2, 2]``
            (presumably two corner points - TODO confirm the order).
        model: Scripted model called as ``model(image, points, labels)``.
        device: Device the inputs are moved to.

    Returns:
        Boolean mask of the candidate with the highest predicted IoU.
    """
    bbox = torch.reshape(torch.tensor(box), [1, 1, 2, 2])
    # Labels 2 and 3 are attached to the two box points.
    bbox_labels = torch.reshape(torch.tensor([2, 3]), [1, 1, 2])
    img_tensor = ToTensor()(image)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        predicted_logits, predicted_iou = model(
            img_tensor[None, ...].to(device),
            bbox.to(device),
            bbox_labels.to(device),
        )
    return _select_best_mask(predicted_logits, predicted_iou)


def inference_with_point(
    image: np.ndarray,
    point: np.ndarray,
    model: torch.jit.ScriptModule,
    device: torch.device
) -> np.ndarray:
    """Run the model with point prompts and return the best mask.

    Args:
        image: Input image; converted to a tensor with ``ToTensor``.
        point: Point prompts, reshaped to ``[1, 1, -1, 2]``; every point
            receives label 1.
        model: Scripted model called as ``model(image, points, labels)``.
        device: Device the inputs are moved to.

    Returns:
        Boolean mask of the candidate with the highest predicted IoU.
    """
    pts_sampled = torch.reshape(torch.tensor(point), [1, 1, -1, 2])
    max_num_pts = pts_sampled.shape[2]
    pts_labels = torch.ones(1, 1, max_num_pts)
    img_tensor = ToTensor()(image)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        predicted_logits, predicted_iou = model(
            img_tensor[None, ...].to(device),
            pts_sampled.to(device),
            pts_labels.to(device),
        )
    return _select_best_mask(predicted_logits, predicted_iou)
utils/tools.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://huggingface.co/spaces/An-619/FastSAM/edit/main/utils/tools.py
2
+ import numpy as np
3
+ from PIL import Image
4
+ import matplotlib.pyplot as plt
5
+ import cv2
6
+ import torch
7
+ import os
8
+ import sys
9
+ import clip
10
+
11
+
12
def convert_box_xywh_to_xyxy(box):
    """Convert ``[x, y, w, h]`` boxes to ``[x1, y1, x2, y2]``.

    Args:
        box: One box of four scalars, or a (possibly nested) sequence of
            such boxes.

    Returns:
        A list for a single box, or a list of converted boxes mirroring
        the nesting of the input.
    """
    # A single box is four *scalars*. Checking the length alone would
    # mistake a batch of exactly four boxes for one box.
    if len(box) == 4 and np.ndim(box[0]) == 0:
        return [box[0], box[1], box[0] + box[2], box[1] + box[3]]
    return [convert_box_xywh_to_xyxy(b) for b in box]
21
+
22
+
23
def segment_image(image, bbox):
    """Return a white image with only the ``bbox`` region of ``image`` kept.

    Args:
        image: PIL image to cut the region from.
        bbox: ``(x1, y1, x2, y2)`` region to keep.

    Returns:
        A new RGB PIL image of the same size: white everywhere except
        inside ``bbox``, where the original pixels are pasted.
    """
    pixels = np.array(image)
    left, top, right, bottom = bbox

    # Copy of the image that is black outside the box.
    kept = np.zeros_like(pixels)
    kept[top:bottom, left:right] = pixels[top:bottom, left:right]

    # Paste mask: 255 inside the box, 0 elsewhere.
    alpha = np.zeros((pixels.shape[0], pixels.shape[1]), dtype=np.uint8)
    alpha[top:bottom, left:right] = 255

    canvas = Image.new("RGB", image.size, (255, 255, 255))
    canvas.paste(Image.fromarray(kept), mask=Image.fromarray(alpha, mode="L"))
    return canvas
38
+
39
+
40
def format_results(result, filter=0):
    """Convert a model result into a list of per-mask annotation dicts.

    Args:
        result: Object exposing ``masks.data``, ``boxes.data`` and
            ``boxes.conf``; ``masks`` may be ``None``.
        filter: Minimum number of pixels equal to 1.0 a mask needs to be kept.

    Returns:
        A list of dicts with keys ``id``, ``segmentation`` (boolean numpy
        array), ``bbox``, ``score`` and ``area``. Empty when ``result.masks``
        is ``None``.
    """
    annotations = []
    # Guard against a result without masks, which would otherwise raise
    # AttributeError on `.data`.
    if result.masks is None:
        return annotations

    for i, mask_data in enumerate(result.masks.data):
        mask = mask_data == 1.0
        if torch.sum(mask) < filter:
            continue
        segmentation = mask.cpu().numpy()
        annotations.append(
            {
                "id": i,
                "segmentation": segmentation,
                "bbox": result.boxes.data[i],
                "score": result.boxes.conf[i],
                "area": segmentation.sum(),
            }
        )
    return annotations
56
+
57
+
58
def filter_masks(annotations):  # filter the overlap mask
    """Drop masks that lie mostly inside a larger mask.

    ``annotations`` is sorted in place by descending ``area``. A mask is
    removed when more than 80% of its pixels are covered by a strictly
    larger mask.

    Args:
        annotations: List of dicts with ``area`` and a boolean
            ``segmentation`` array.

    Returns:
        ``(kept, to_remove)``: the surviving annotations and the set of
        removed indices (indices refer to the sorted list).
    """
    annotations.sort(key=lambda x: x["area"], reverse=True)
    to_remove = set()
    for i, a in enumerate(annotations):
        # j always starts after i, so no separate i != j test is needed.
        for j in range(i + 1, len(annotations)):
            if j in to_remove:
                continue
            b = annotations[j]
            if b["area"] >= a["area"]:
                continue
            b_pixels = b["segmentation"].sum()
            if b_pixels == 0:
                # An empty mask cannot overlap anything; skip the 0/0 division.
                continue
            overlap = (a["segmentation"] & b["segmentation"]).sum()
            if overlap / b_pixels > 0.8:
                to_remove.add(j)

    return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
74
+
75
+
76
def get_bbox_from_mask(mask):
    """Return the bounding box enclosing every non-zero pixel of ``mask``.

    Args:
        mask: Two-dimensional array; it is cast to ``uint8`` before use.

    Returns:
        ``[x1, y1, x2, y2]`` with exclusive ``x2`` / ``y2``, or
        ``[0, 0, 0, 0]`` when the mask has no non-zero pixel.
    """
    mask = mask.astype(np.uint8)
    # The union of the bounding rectangles of all external contours is the
    # extent of the non-zero pixels, so it is computed directly.
    rows, cols = np.nonzero(mask)
    if rows.size == 0:
        # Previously an empty mask raised IndexError on `contours[0]`.
        return [0, 0, 0, 0]
    return [
        int(cols.min()),
        int(rows.min()),
        int(cols.max()) + 1,
        int(rows.max()) + 1,
    ]
94
+
95
+
96
def fast_process(
    annotations, args, mask_random_color, bbox=None, points=None, edges=False
):
    """Render masks over ``args.img_path`` and save the result as an image.

    The picture is written to ``args.output`` under the base name of
    ``args.img_path``.

    Args:
        annotations: Masks, or dicts holding a ``segmentation`` mask.
        args: Options object; reads ``img_path``, ``better_quality``,
            ``device``, ``point_label``, ``retina``, ``withContours`` and
            ``output``.
        mask_random_color: Whether each mask gets a random colour.
        bbox: Optional ``(x1, y1, x2, y2)`` box to draw.
        points: Optional prompt points to draw.
        edges: Unused.
    """
    if isinstance(annotations[0], dict):
        annotations = [annotation["segmentation"] for annotation in annotations]
    result_name = os.path.basename(args.img_path)
    image = cv2.imread(args.img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    original_h = image.shape[0]
    original_w = image.shape[1]
    if sys.platform == "darwin":
        plt.switch_backend("TkAgg")
    plt.figure(figsize=(original_w / 100, original_h / 100))
    # Add subplot with no margin.
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    plt.imshow(image)
    # The `== True` / `== False` tests on `args` values are kept as-is so
    # non-boolean option values behave exactly as before.
    if args.better_quality == True:
        if isinstance(annotations[0], torch.Tensor):
            annotations = np.array(annotations.cpu())
        for i, mask in enumerate(annotations):
            # Close small holes, then remove small specks.
            mask = cv2.morphologyEx(
                mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)
            )
            annotations[i] = cv2.morphologyEx(
                mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)
            )
    if args.device == "cpu":
        annotations = np.array(annotations)
        fast_show_mask(
            annotations,
            plt.gca(),
            random_color=mask_random_color,
            bbox=bbox,
            points=points,
            point_label=args.point_label,
            retinamask=args.retina,
            target_height=original_h,
            target_width=original_w,
        )
    else:
        if isinstance(annotations[0], np.ndarray):
            annotations = torch.from_numpy(annotations)
        fast_show_mask_gpu(
            annotations,
            plt.gca(),
            # Use the explicit parameter, as the CPU branch does, rather
            # than reading the flag from `args` a second way.
            random_color=mask_random_color,
            bbox=bbox,
            points=points,
            point_label=args.point_label,
            retinamask=args.retina,
            target_height=original_h,
            target_width=original_w,
        )
    if isinstance(annotations, torch.Tensor):
        annotations = annotations.cpu().numpy()
    if args.withContours == True:
        contour_all = []
        temp = np.zeros((original_h, original_w, 1))
        for mask in annotations:
            if isinstance(mask, dict):
                mask = mask["segmentation"]
            annotation = mask.astype(np.uint8)
            if args.retina == False:
                annotation = cv2.resize(
                    annotation,
                    (original_w, original_h),
                    interpolation=cv2.INTER_NEAREST,
                )
            contours, _hierarchy = cv2.findContours(
                annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
            )
            contour_all.extend(contours)
        cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
        color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
        contour_mask = temp / 255 * color.reshape(1, 1, -1)
        plt.imshow(contour_mask)

    save_path = args.output
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_path, exist_ok=True)
    plt.axis("off")
    fig = plt.gcf()
    plt.draw()
    fig.canvas.draw()

    try:
        buf = fig.canvas.tostring_rgb()
        cols, rows = fig.canvas.get_width_height()
        # frombuffer is the supported way to read raw bytes.
        img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
    except AttributeError:
        # Canvas without tostring_rgb: read the RGBA buffer and drop alpha.
        img_array = np.ascontiguousarray(
            np.asarray(fig.canvas.buffer_rgba())[..., :3]
        )

    cv2.imwrite(
        os.path.join(save_path, result_name),
        cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR),
    )
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
193
+
194
+
195
+ # CPU post process
196
def fast_show_mask(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    points=None,
    point_label=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    """Compose the masks into one RGBA overlay and show it on ``ax`` (numpy).

    Where masks overlap, the mask with the smallest area is shown.

    Args:
        annotation: Array of masks indexed as (mask, row, column).
        ax: Axes-like object; ``add_patch`` and ``imshow`` are called on it.
        random_color: Random colour per mask instead of a fixed blue.
        bbox: Optional ``(x1, y1, x2, y2)`` rectangle to outline.
        points: Optional points to scatter (label 1 yellow, label 0 magenta).
        point_label: Labels matching ``points``.
        retinamask: When ``False`` the overlay is resized to the target size.
        target_height: Overlay height used when resizing.
        target_width: Overlay width used when resizing.
    """
    n_masks, mask_h, mask_w = (
        annotation.shape[0],
        annotation.shape[1],
        annotation.shape[2],
    )
    # Sort masks by ascending area.
    order = np.argsort(np.sum(annotation, axis=(1, 2)))
    annotation = annotation[order]

    # For every pixel, index of the first (smallest) mask covering it.
    top_mask = (annotation != 0).argmax(axis=0)
    if random_color == True:
        rgb = np.random.random((n_masks, 1, 1, 3))
    else:
        rgb = np.ones((n_masks, 1, 1, 3)) * np.array(
            [30 / 255, 144 / 255, 255 / 255]
        )
    alpha = np.ones((n_masks, 1, 1, 1)) * 0.6
    rgba = np.concatenate([rgb, alpha], axis=-1)
    layers = np.expand_dims(annotation, -1) * rgba

    overlay = np.zeros((mask_h, mask_w, 4))
    rows, cols = np.meshgrid(np.arange(mask_h), np.arange(mask_w), indexing="ij")
    # Vectorised gather: take each pixel from its selected mask layer.
    overlay[rows, cols, :] = layers[(top_mask[rows, cols], rows, cols, slice(None))]

    if bbox is not None:
        x1, y1, x2, y2 = bbox
        outline = plt.Rectangle(
            (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
        )
        ax.add_patch(outline)

    # draw point
    if points is not None:
        for wanted, colour in ((1, "y"), (0, "m")):
            chosen = [
                point for i, point in enumerate(points) if point_label[i] == wanted
            ]
            plt.scatter(
                [point[0] for point in chosen],
                [point[1] for point in chosen],
                s=20,
                c=colour,
            )

    if retinamask == False:
        overlay = cv2.resize(
            overlay, (target_width, target_height), interpolation=cv2.INTER_NEAREST
        )
    ax.imshow(overlay)
260
+
261
+
262
def fast_show_mask_gpu(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    points=None,
    point_label=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    """Compose the masks into one RGBA overlay and show it on ``ax`` (torch).

    Where masks overlap, the mask with the smallest area is shown. The
    overlay is built on ``annotation.device`` and copied to the CPU before
    it is displayed.

    Args:
        annotation: Tensor of masks indexed as (mask, row, column).
        ax: Axes-like object; ``add_patch`` and ``imshow`` are called on it.
        random_color: Random colour per mask instead of a fixed blue.
        bbox: Optional ``(x1, y1, x2, y2)`` rectangle to outline.
        points: Optional points to scatter (label 1 yellow, label 0 magenta).
        point_label: Labels matching ``points``.
        retinamask: When ``False`` the overlay is resized to the target size.
        target_height: Overlay height used when resizing.
        target_width: Overlay width used when resizing.
    """
    n_masks, mask_h, mask_w = (
        annotation.shape[0],
        annotation.shape[1],
        annotation.shape[2],
    )
    dev = annotation.device
    order = torch.argsort(torch.sum(annotation, dim=(1, 2)), descending=False)
    annotation = annotation[order]
    # For every position, index of the first non-zero mask.
    top_mask = (annotation != 0).to(torch.long).argmax(dim=0)
    if random_color == True:
        rgb = torch.rand((n_masks, 1, 1, 3)).to(dev)
    else:
        rgb = torch.ones((n_masks, 1, 1, 3)).to(dev) * torch.tensor(
            [30 / 255, 144 / 255, 255 / 255]
        ).to(dev)
    alpha = torch.ones((n_masks, 1, 1, 1)).to(dev) * 0.6
    rgba = torch.cat([rgb, alpha], dim=-1)
    layers = torch.unsqueeze(annotation, -1) * rgba
    # Gather by index: `top_mask` says which mask layer supplies each pixel.
    overlay = torch.zeros((mask_h, mask_w, 4)).to(dev)
    rows, cols = torch.meshgrid(
        torch.arange(mask_h), torch.arange(mask_w), indexing="ij"
    )
    # Vectorised indexing fills the whole overlay at once.
    overlay[rows, cols, :] = layers[(top_mask[rows, cols], rows, cols, slice(None))]
    overlay_np = overlay.cpu().numpy()

    if bbox is not None:
        x1, y1, x2, y2 = bbox
        outline = plt.Rectangle(
            (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
        )
        ax.add_patch(outline)

    # draw point
    if points is not None:
        for wanted, colour in ((1, "y"), (0, "m")):
            chosen = [
                point for i, point in enumerate(points) if point_label[i] == wanted
            ]
            plt.scatter(
                [point[0] for point in chosen],
                [point[1] for point in chosen],
                s=20,
                c=colour,
            )

    if retinamask == False:
        overlay_np = cv2.resize(
            overlay_np, (target_width, target_height), interpolation=cv2.INTER_NEAREST
        )
    ax.imshow(overlay_np)
325
+
326
+
327
+ # clip
328
@torch.no_grad()
def retriev(
    model, preprocess, elements: [Image.Image], search_text: str, device
):
    """Score each image against a text query.

    Args:
        model: Model exposing ``encode_image`` and ``encode_text``.
        preprocess: Callable turning one image into a tensor.
        elements: Images to score.
        search_text: Text query, tokenised with ``clip.tokenize``.
        device: Device the inputs are moved to.

    Returns:
        One value per image: a softmax, taken across the images, of 100
        times the dot product of the normalised image and text features.
    """
    image_tensors = [preprocess(image).to(device) for image in elements]
    text_tokens = clip.tokenize([search_text]).to(device)
    image_batch = torch.stack(image_tensors)

    image_features = model.encode_image(image_batch)
    text_features = model.encode_text(text_tokens)

    # Normalise both embeddings to unit length (in place).
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    similarity = 100.0 * image_features @ text_features.T
    return similarity[:, 0].softmax(dim=0)
341
+
342
+
343
def crop_image(annotations, image_like):
    """Cut out the bounding-box region of every sufficiently large mask.

    Args:
        annotations: List of dicts holding a ``segmentation`` mask.
        image_like: A PIL image, or a path that is opened with
            ``Image.open``.

    Returns:
        ``(cropped_boxes, cropped_images, not_crop, origin_id, annotations)``:
        the cut-out images, their bounding boxes, an always-empty list, the
        indices of the annotations that were cut out, and the unchanged
        ``annotations``. Masks with 100 pixels or fewer are skipped.
    """
    if isinstance(image_like, str):
        image = Image.open(image_like)
    else:
        image = image_like

    cropped_boxes = []
    cropped_images = []
    not_crop = []
    origin_id = []
    if not annotations:
        # Nothing to crop; avoids IndexError on annotations[0] below.
        return cropped_boxes, cropped_images, not_crop, origin_id, annotations

    ori_w, ori_h = image.size
    mask_h, mask_w = annotations[0]["segmentation"].shape
    if ori_w != mask_w or ori_h != mask_h:
        # Bring the image to the mask resolution so the boxes line up.
        image = image.resize((mask_w, mask_h))

    for idx, mask in enumerate(annotations):
        if np.sum(mask["segmentation"]) <= 100:
            continue
        origin_id.append(idx)
        bbox = get_bbox_from_mask(mask["segmentation"])  # bbox of the mask
        cropped_boxes.append(segment_image(image, bbox))  # the cut-out image
        cropped_images.append(bbox)  # the bbox of the cut-out image
    return cropped_boxes, cropped_images, not_crop, origin_id, annotations
365
+
366
+
367
def box_prompt(masks, bbox, target_height, target_width):
    """Pick the mask that best matches a box prompt by IoU.

    Args:
        masks: Tensor of masks indexed as (mask, row, column).
        bbox: ``[x1, y1, x2, y2]`` box in target-image coordinates; it is
            not modified.
        target_height: Height of the image the box refers to.
        target_width: Width of the image the box refers to.

    Returns:
        ``(mask, index)``: the best mask as a numpy array and its index as
        a tensor.
    """
    h = masks.shape[1]
    w = masks.shape[2]
    # Work on local values so the caller's list is never mutated.
    x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
    if h != target_height or w != target_width:
        # Map the box from target-image coordinates to mask coordinates.
        x1 = int(x1 * w / target_width)
        y1 = int(y1 * h / target_height)
        x2 = int(x2 * w / target_width)
        y2 = int(y2 * h / target_height)
    x1 = max(round(x1), 0)
    y1 = max(round(y1), 0)
    x2 = min(round(x2), w)
    y2 = min(round(y2), h)

    bbox_area = (y2 - y1) * (x2 - x1)

    masks_area = torch.sum(masks[:, y1:y2, x1:x2], dim=(1, 2))
    orig_masks_area = torch.sum(masks, dim=(1, 2))

    union = bbox_area + orig_masks_area - masks_area
    # A non-positive union would make the IoU NaN; divide by 1 there instead.
    safe_union = torch.where(union > 0, union, torch.ones_like(union))
    IoUs = masks_area / safe_union
    max_iou_index = torch.argmax(IoUs)

    return masks[max_iou_index].cpu().numpy(), max_iou_index
393
+
394
+
395
def point_prompt(masks, points, point_label, target_height, target_width):  # numpy processing
    """Merge the masks selected by labelled prompt points.

    Masks are visited from largest to smallest area. A mask containing a
    point with label 1 is added to the result; a mask containing a point
    with label 0 is removed from it.

    Args:
        masks: List of dicts with a boolean ``segmentation`` array and ``area``.
        points: ``[x, y]`` points in target-image coordinates.
        point_label: One label per point (1 = add, 0 = remove).
        target_height: Height of the image the points refer to.
        target_width: Width of the image the points refer to.

    Returns:
        ``(mask, 0)``: the merged boolean mask and a constant 0. With no
        masks the result is an all-False array of the target size.
    """
    if len(masks) == 0:
        # No candidates; avoids IndexError on masks[0] below.
        return np.zeros((target_height, target_width), dtype=bool), 0

    h = masks[0]["segmentation"].shape[0]
    w = masks[0]["segmentation"].shape[1]
    if h != target_height or w != target_width:
        points = [
            [int(point[0] * w / target_width), int(point[1] * h / target_height)]
            for point in points
        ]
    # Integer scaling can land exactly on the width or height, so clamp
    # every point to the last valid row and column.
    points = [
        [min(max(int(point[0]), 0), w - 1), min(max(int(point[1]), 0), h - 1)]
        for point in points
    ]

    onemask = np.zeros((h, w))
    masks = sorted(masks, key=lambda x: x['area'], reverse=True)
    for annotation in masks:
        if isinstance(annotation, dict):
            mask = annotation['segmentation']
        else:
            mask = annotation
        # A separate index name: the original reused `i` for both loops.
        for point_idx, point in enumerate(points):
            if mask[point[1], point[0]] == 1 and point_label[point_idx] == 1:
                onemask[mask] = 1
            if mask[point[1], point[0]] == 1 and point_label[point_idx] == 0:
                onemask[mask] = 0
    onemask = onemask >= 1
    return onemask, 0
417
+
418
+
419
def text_prompt(annotations, text, img_path, device, wider=False, threshold=0.9):
    """Pick the mask whose cut-out image best matches a text query.

    Args:
        annotations: List of dicts holding a ``segmentation`` mask.
        text: Text query.
        img_path: Image or image path, forwarded to ``crop_image``.
        device: Device used for the CLIP model.
        wider: When true, prefer the largest mask that is either the best
            match itself or covers more than ``threshold`` of it.
        threshold: Coverage ratio used when ``wider`` is true.

    Returns:
        ``(segmentation, index)`` of the selected annotation.
    """
    cropped_boxes, cropped_images, not_crop, origin_id, annotations_ = crop_image(
        annotations, img_path
    )
    clip_model, preprocess = clip.load("./weights/CLIP_ViT_B_32.pt", device=device)
    scores = retriev(clip_model, preprocess, cropped_boxes, text, device=device)

    # The last entry of the ascending argsort is the best-scoring cut-out;
    # map it back to its index in the annotation list.
    ranking = scores.argsort()
    best = origin_id[int(ranking[-1])]

    # find the biggest mask which contains the mask with max score
    if wider:
        reference = annotations_[best]["segmentation"]
        reference_area = np.sum(reference)
        candidates = sorted(
            [
                (i, np.sum(mask["segmentation"]))
                for i, mask in enumerate(annotations_)
                if i in origin_id
            ],
            key=lambda candidate: candidate[1],
            reverse=True,
        )
        for index, _area in candidates:
            if index == best or np.sum(annotations_[index]["segmentation"] & reference) / reference_area > threshold:
                best = index
                break

    return annotations_[best]["segmentation"], best
utils/tools_gradio.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://huggingface.co/spaces/An-619/FastSAM/edit/main/utils/tools_gradio.py
2
+ import numpy as np
3
+ from PIL import Image
4
+ import matplotlib.pyplot as plt
5
+ import cv2
6
+ import torch
7
+
8
+
9
def fast_process(
    annotations,
    image,
    device,
    scale,
    better_quality=False,
    mask_random_color=True,
    bbox=None,
    use_retina=True,
    withContours=True,
):
    """Overlay segmentation masks, and optionally their contours, on ``image``.

    Args:
        annotations: Non-empty collection of masks: a list of dicts carrying a
            ``'segmentation'`` entry, a list/array of NumPy masks, or a torch
            tensor of masks.
        image: PIL image to draw on; its ``height``/``width`` define the
            output size.
        device: ``'cpu'`` selects the NumPy renderer, anything else the torch
            renderer.
        scale: Divisor applied to the contour thickness (``2 // scale``).
        better_quality: Apply a morphological close then open to every mask.
        mask_random_color: Give each mask a random colour instead of one
            fixed colour.
        bbox: Optional ``(x1, y1, x2, y2)`` box forwarded to the renderer.
        use_retina: When false, masks and contours are resized to the image
            size with nearest-neighbour interpolation.
        withContours: Also draw mask contours.

    Returns:
        The RGBA PIL image with the overlays pasted on.
    """
    if isinstance(annotations[0], dict):
        annotations = [annotation['segmentation'] for annotation in annotations]

    original_h = image.height
    original_w = image.width

    if better_quality:
        if isinstance(annotations[0], torch.Tensor):
            annotations = np.array(annotations.cpu())
        close_kernel = np.ones((3, 3), np.uint8)
        open_kernel = np.ones((8, 8), np.uint8)
        for i, mask in enumerate(annotations):
            mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, close_kernel)
            annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, open_kernel)

    # The renderers only touch the axes when a bbox is drawn, so avoid creating
    # a pyplot figure as a side effect on every call.
    ax = plt.gca() if bbox is not None else None

    # NOTE(review): a torch.device instance may not compare equal to the string
    # 'cpu', in which case the torch renderer below is used on CPU as well —
    # confirm what callers pass.
    if device == 'cpu':
        annotations = np.array(annotations)
        inner_mask = fast_show_mask(
            annotations,
            ax,
            random_color=mask_random_color,
            bbox=bbox,
            retinamask=use_retina,
            target_height=original_h,
            target_width=original_w,
        )
    else:
        if isinstance(annotations[0], np.ndarray):
            # torch.from_numpy only accepts an ndarray; stack a list of masks
            # (e.g. the one built from dict annotations above) first.
            annotations = torch.from_numpy(np.asarray(annotations))
        inner_mask = fast_show_mask_gpu(
            annotations,
            ax,
            random_color=mask_random_color,
            bbox=bbox,
            retinamask=use_retina,
            target_height=original_h,
            target_width=original_w,
        )
    if isinstance(annotations, torch.Tensor):
        annotations = annotations.cpu().numpy()

    if withContours:
        contour_all = []
        temp = np.zeros((original_h, original_w, 1))
        for mask in annotations:
            if isinstance(mask, dict):
                mask = mask['segmentation']
            annotation = mask.astype(np.uint8)
            if not use_retina:
                annotation = cv2.resize(
                    annotation,
                    (original_w, original_h),
                    interpolation=cv2.INTER_NEAREST,
                )
            contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            contour_all.extend(contours)
        # NOTE(review): thickness is 2 // scale; presumably scale is a positive
        # integer no larger than 2 — confirm against callers.
        cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
        color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
        contour_mask = temp / 255 * color.reshape(1, 1, -1)

    image = image.convert('RGBA')
    # Each overlay doubles as its own paste mask, so its alpha channel controls
    # the blending.
    overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
    image.paste(overlay_inner, (0, 0), overlay_inner)

    if withContours:
        overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), 'RGBA')
        image.paste(overlay_contour, (0, 0), overlay_contour)

    return image
86
+
87
+
88
+ # CPU post process
89
def fast_show_mask(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    """Render a stack of masks into a single RGBA float image (NumPy path).

    Args:
        annotation: Array indexed as ``(mask, row, column)``.
        ax: Matplotlib axes; only used to draw ``bbox``.
        random_color: Random colour per mask instead of one fixed colour.
        bbox: Optional ``(x1, y1, x2, y2)`` rectangle added to ``ax``.
        retinamask: When false, the result is resized to the target size.
        target_height: Output height used when resizing.
        target_width: Output width used when resizing.

    Returns:
        A ``(rows, columns, 4)`` float array: colour plus alpha 0.6 where a
        mask covers the pixel, zeros elsewhere.
    """
    count = annotation.shape[0]
    n_rows = annotation.shape[1]
    n_cols = annotation.shape[2]

    # Sort the masks by area, smallest first.
    order = np.argsort(np.sum(annotation, axis=(1, 2)))
    stacked = annotation[order]

    # For every pixel, the first (i.e. smallest) mask that is non-zero there.
    top_layer = (stacked != 0).argmax(axis=0)

    if random_color:
        rgb = np.random.random((count, 1, 1, 3))
    else:
        rgb = np.ones((count, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
    alpha = np.full((count, 1, 1, 1), 0.6)
    rgba = np.concatenate([rgb, alpha], axis=-1)
    layers = np.expand_dims(stacked, -1) * rgba

    # Gather, per pixel, the RGBA value of the selected layer.
    yy, xx = np.indices((n_rows, n_cols))
    mask = np.zeros((n_rows, n_cols, 4))
    mask[yy, xx, :] = layers[top_layer[yy, xx], yy, xx, :]

    if bbox is not None:
        x1, y1, x2, y2 = bbox
        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))

    if retinamask:
        return mask
    return cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
129
+
130
+
131
def fast_show_mask_gpu(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    """Render a stack of masks into a single RGBA float image (torch path).

    Args:
        annotation: Tensor indexed as ``(mask, row, column)``; computation
            runs on ``annotation.device``.
        ax: Matplotlib axes; only used to draw ``bbox``.
        random_color: Random colour per mask instead of one fixed colour.
        bbox: Optional ``(x1, y1, x2, y2)`` rectangle added to ``ax``.
        retinamask: When false, the result is resized to the target size.
        target_height: Output height used when resizing.
        target_width: Output width used when resizing.

    Returns:
        A NumPy ``(rows, columns, 4)`` float array: colour plus alpha 0.6
        where a mask covers the pixel, zeros elsewhere.
    """
    device = annotation.device
    mask_sum = annotation.shape[0]
    height = annotation.shape[1]
    width = annotation.shape[2]

    # Sort the masks by area, smallest first.
    areas = torch.sum(annotation, dim=(1, 2))
    sorted_indices = torch.argsort(areas, descending=False)
    annotation = annotation[sorted_indices]

    # Index of the first non-zero mask at every pixel position.
    index = (annotation != 0).to(torch.long).argmax(dim=0)

    if random_color:
        color = torch.rand((mask_sum, 1, 1, 3)).to(device)
    else:
        color = torch.ones((mask_sum, 1, 1, 3), device=device) * torch.tensor(
            [30 / 255, 144 / 255, 255 / 255], device=device
        )
    transparency = torch.ones((mask_sum, 1, 1, 1), device=device) * 0.6
    visual = torch.cat([color, transparency], dim=-1)
    mask_image = torch.unsqueeze(annotation, -1) * visual

    # Pick values by index: ``index`` says which mask to take at each position,
    # collapsing mask_image into a single image.
    mask = torch.zeros((height, width, 4), device=device)
    # Pass indexing explicitly (calling meshgrid without it is deprecated) and
    # build the grids on the same device as the data they index.
    h_indices, w_indices = torch.meshgrid(
        torch.arange(height, device=device),
        torch.arange(width, device=device),
        indexing="ij",
    )
    # Vectorised gather of the selected layer for every pixel.
    mask[h_indices, w_indices, :] = mask_image[index[h_indices, w_indices], h_indices, w_indices, :]
    mask_cpu = mask.cpu().numpy()

    if bbox is not None:
        x1, y1, x2, y2 = bbox
        ax.add_patch(
            plt.Rectangle(
                (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
            )
        )
    if not retinamask:
        mask_cpu = cv2.resize(
            mask_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
        )
    return mask_cpu