ngthanhtinqn commited on
Commit
127eb07
β€’
1 Parent(s): 0c76662
Files changed (3) hide show
  1. .gitattributes +1 -0
  2. app.py +2 -48
  3. demo.py +178 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ sam_vit_h_4b8939.pth filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,53 +1,7 @@
1
- import torch
2
- import cv2
3
- import gradio as gr
4
- import numpy as np
5
- from transformers import OwlViTProcessor, OwlViTForObjectDetection
6
-
7
-
8
- # Use GPU if available
9
- if torch.cuda.is_available():
10
- device = torch.device("cuda:4")
11
- else:
12
- device = torch.device("cpu")
13
-
14
- model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32").to(device)
15
- model.eval()
16
- processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
17
-
18
-
19
- def query_image(img, text_queries, score_threshold):
20
- text_queries = text_queries
21
- text_queries = text_queries.split(",")
22
 
23
- target_sizes = torch.Tensor([img.shape[:2]])
24
- inputs = processor(text=text_queries, images=img, return_tensors="pt").to(device)
25
-
26
- with torch.no_grad():
27
- outputs = model(**inputs)
28
-
29
- outputs.logits = outputs.logits.cpu()
30
- outputs.pred_boxes = outputs.pred_boxes.cpu()
31
- results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
32
- boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
33
-
34
- font = cv2.FONT_HERSHEY_SIMPLEX
35
-
36
- for box, score, label in zip(boxes, scores, labels):
37
- box = [int(i) for i in box.tolist()]
38
-
39
- if score >= score_threshold:
40
- img = cv2.rectangle(img, box[:2], box[2:], (255,0,0), 5)
41
- if box[3] + 25 > 768:
42
- y = box[3] - 10
43
- else:
44
- y = box[3] + 25
45
-
46
- img = cv2.putText(
47
- img, text_queries[label], (box[0], y), font, 1, (255,0,0), 2, cv2.LINE_AA
48
- )
49
- return img
50
 
 
51
 
52
  description = """
53
  Gradio demo for <a href="https://huggingface.co/docs/transformers/main/en/model_doc/owlvit">OWL-ViT</a>,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
+ import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
+ from demo import query_image
5
 
6
  description = """
7
  Gradio demo for <a href="https://huggingface.co/docs/transformers/main/en/model_doc/owlvit">OWL-ViT</a>,
demo.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import copy
4
+
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image, ImageDraw, ImageFont
8
+ import PIL
9
+ # OwlViT Detection
10
+ from transformers import OwlViTProcessor, OwlViTForObjectDetection
11
+
12
+ # segment anything
13
+ from segment_anything import build_sam, SamPredictor
14
+ import cv2
15
+ import numpy as np
16
+ import matplotlib.pyplot as plt
17
+
18
+ import gc
19
+
20
+ def show_mask(mask, ax, random_color=False):
21
+ if random_color:
22
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
23
+ else:
24
+ color = np.array([30/255, 144/255, 255/255, 0.6])
25
+ h, w = mask.shape[-2:]
26
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
27
+ ax.imshow(mask_image)
28
+
29
+
30
+ def show_box(box, ax):
31
+ x0, y0 = box[0], box[1]
32
+ w, h = box[2] - box[0], box[3] - box[1]
33
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
34
+
35
+ def plot_boxes_to_image(image_pil, tgt):
36
+ H, W = tgt["size"]
37
+ boxes = tgt["boxes"]
38
+ labels = tgt["labels"]
39
+ assert len(boxes) == len(labels), "boxes and labels must have same length"
40
+
41
+ draw = ImageDraw.Draw(image_pil)
42
+ mask = Image.new("L", image_pil.size, 0)
43
+ mask_draw = ImageDraw.Draw(mask)
44
+
45
+ # draw boxes and masks
46
+ for box, label in zip(boxes, labels):
47
+ # random color
48
+ color = tuple(np.random.randint(0, 255, size=3).tolist())
49
+ # draw
50
+ x0, y0, x1, y1 = box
51
+ x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
52
+
53
+ draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
54
+ draw.text((x0, y0), str(label), fill=color)
55
+
56
+ font = ImageFont.load_default()
57
+ if hasattr(font, "getbbox"):
58
+ bbox = draw.textbbox((x0, y0), str(label), font)
59
+ else:
60
+ w, h = draw.textsize(str(label), font)
61
+ bbox = (x0, y0, w + x0, y0 + h)
62
+ # bbox = draw.textbbox((x0, y0), str(label))
63
+ draw.rectangle(bbox, fill=color)
64
+ draw.text((x0, y0), str(label), fill="white")
65
+
66
+ mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)
67
+
68
+ return image_pil, mask
69
+
70
+ # Use GPU if available
71
+ if torch.cuda.is_available():
72
+ device = torch.device("cuda:4")
73
+ else:
74
+ device = torch.device("cpu")
75
+
76
+ # load OWL-ViT model
77
+ owlvit_model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32").to(device)
78
+ owlvit_model.eval()
79
+ owlvit_processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
80
+
81
+ # run segment anything (SAM)
82
+ sam_predictor = SamPredictor(build_sam(checkpoint="./sam_vit_h_4b8939.pth"))
83
+
84
+ def query_image(img, text_prompt, box_threshold):
85
+ # load image
86
+ if not isinstance(img, PIL.Image.Image):
87
+ pil_img = Image.fromarray(np.uint8(img)).convert('RGB')
88
+
89
+ text_prompt = text_prompt
90
+ texts = text_prompt.split(",")
91
+
92
+ box_threshold = box_threshold
93
+
94
+ # run object detection model
95
+ with torch.no_grad():
96
+ inputs = owlvit_processor(text=texts, images=pil_img, return_tensors="pt").to(device)
97
+ outputs = owlvit_model(**inputs)
98
+
99
+ # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
100
+ target_sizes = torch.Tensor([pil_img.size[::-1]])
101
+ # Convert outputs (bounding boxes and class logits) to COCO API
102
+ results = owlvit_processor.post_process_object_detection(outputs=outputs, threshold=box_threshold, target_sizes=target_sizes.to(device))
103
+ scores = torch.sigmoid(outputs.logits)
104
+ topk_scores, topk_idxs = torch.topk(scores, k=1, dim=1)
105
+
106
+ i = 0 # Retrieve predictions for the first image for the corresponding text queries
107
+ text = texts[i]
108
+
109
+ topk_idxs = topk_idxs.squeeze(1).tolist()
110
+ topk_boxes = results[i]['boxes'][topk_idxs]
111
+ topk_scores = topk_scores.view(len(text), -1)
112
+ topk_labels = results[i]["labels"][topk_idxs]
113
+ boxes, scores, labels = topk_boxes, topk_scores, topk_labels
114
+
115
+ # boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
116
+
117
+
118
+ # Print detected objects and rescaled box coordinates
119
+ # for box, score, label in zip(boxes, scores, labels):
120
+ # box = [round(i, 2) for i in box.tolist()]
121
+ # print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
122
+
123
+ boxes = boxes.cpu().detach().numpy()
124
+ normalized_boxes = copy.deepcopy(boxes)
125
+
126
+ # # visualize pred
127
+ size = pil_img.size
128
+ pred_dict = {
129
+ "boxes": normalized_boxes,
130
+ "size": [size[1], size[0]], # H, W
131
+ "labels": [text[idx] for idx in labels]
132
+ }
133
+
134
+ # release the OWL-ViT
135
+ # owlvit_model.cpu()
136
+ # del owlvit_model
137
+ gc.collect()
138
+ torch.cuda.empty_cache()
139
+
140
+ # run segment anything (SAM)
141
+ open_cv_image = np.array(pil_img)
142
+ image = cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB)
143
+ sam_predictor.set_image(image)
144
+
145
+ H, W = size[1], size[0]
146
+
147
+ for i in range(boxes.shape[0]):
148
+ boxes[i] = torch.Tensor(boxes[i])
149
+
150
+ boxes = torch.tensor(boxes, device=sam_predictor.device)
151
+
152
+ transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes, image.shape[:2])
153
+
154
+ masks, _, _ = sam_predictor.predict_torch(
155
+ point_coords = None,
156
+ point_labels = None,
157
+ boxes = transformed_boxes,
158
+ multimask_output = False,
159
+ )
160
+ plt.figure(figsize=(10, 10))
161
+ plt.imshow(image)
162
+ for mask in masks:
163
+ show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
164
+ for box in boxes:
165
+ show_box(box.numpy(), plt.gca())
166
+ plt.axis('off')
167
+
168
+ import io
169
+ buf = io.BytesIO()
170
+ plt.savefig(buf)
171
+ buf.seek(0)
172
+ owlvit_segment_image = Image.open(buf).convert('RGB')
173
+
174
+ # grounded results
175
+ image_with_box = plot_boxes_to_image(pil_img, pred_dict)[0]
176
+
177
+ # return owlvit_segment_image, image_with_box
178
+ return owlvit_segment_image