AAAAAAyq committed
Commit • 2f10180 • 1 Parent(s): 5350ba4
Update the examples
tools.py
CHANGED
@@ -3,7 +3,7 @@ from PIL import Image
 import matplotlib.pyplot as plt
 import cv2
 import torch
-import clip
+# import clip
 
 
 def convert_box_xywh_to_xyxy(box):
@@ -290,20 +290,20 @@ def fast_show_mask_gpu(
     return mask_cpu
 
 
-# clip
-@torch.no_grad()
-def retriev(
-    model, preprocess, elements, search_text: str, device
-) -> int:
-    preprocessed_images = [preprocess(image).to(device) for image in elements]
-    tokenized_text = clip.tokenize([search_text]).to(device)
-    stacked_images = torch.stack(preprocessed_images)
-    image_features = model.encode_image(stacked_images)
-    text_features = model.encode_text(tokenized_text)
-    image_features /= image_features.norm(dim=-1, keepdim=True)
-    text_features /= text_features.norm(dim=-1, keepdim=True)
-    probs = 100.0 * image_features @ text_features.T
-    return probs[:, 0].softmax(dim=0)
+# # clip
+# @torch.no_grad()
+# def retriev(
+#     model, preprocess, elements, search_text: str, device
+# ) -> int:
+#     preprocessed_images = [preprocess(image).to(device) for image in elements]
+#     tokenized_text = clip.tokenize([search_text]).to(device)
+#     stacked_images = torch.stack(preprocessed_images)
+#     image_features = model.encode_image(stacked_images)
+#     text_features = model.encode_text(tokenized_text)
+#     image_features /= image_features.norm(dim=-1, keepdim=True)
+#     text_features /= text_features.norm(dim=-1, keepdim=True)
+#     probs = 100.0 * image_features @ text_features.T
+#     return probs[:, 0].softmax(dim=0)
 
 
 def crop_image(annotations, image_path):
@@ -381,15 +381,15 @@ def point_prompt(masks, points, pointlabel, target_height, target_width): # num
     return onemask, 0
 
 
-def text_prompt(annotations, args):
-    cropped_boxes, cropped_images, not_crop, filter_id, annotaions = crop_image(
-        annotations, args.img_path
-    )
-    clip_model, preprocess = clip.load("ViT-B/32", device=args.device)
-    scores = retriev(
-        clip_model, preprocess, cropped_boxes, args.text_prompt, device=args.device
-    )
-    max_idx = scores.argsort()
-    max_idx = max_idx[-1]
-    max_idx += sum(np.array(filter_id) <= int(max_idx))
-    return annotaions[max_idx]["segmentation"], max_idx
+# def text_prompt(annotations, args):
+#     cropped_boxes, cropped_images, not_crop, filter_id, annotaions = crop_image(
+#         annotations, args.img_path
+#     )
+#     clip_model, preprocess = clip.load("ViT-B/32", device=args.device)
+#     scores = retriev(
+#         clip_model, preprocess, cropped_boxes, args.text_prompt, device=args.device
+#     )
+#     max_idx = scores.argsort()
+#     max_idx = max_idx[-1]
+#     max_idx += sum(np.array(filter_id) <= int(max_idx))
+#     return annotaions[max_idx]["segmentation"], max_idx