Mithu96 commited on
Commit
7acb9bb
·
verified ·
1 Parent(s): e875712

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import cv2
3
+ import gradio as gr
4
+ import numpy as np
5
+ import requests
6
+ from PIL import Image
7
+ from io import BytesIO
8
+ from transformers import OwlViTProcessor, OwlViTForObjectDetection
9
+
10
+
11
+ # Use GPU if available
12
+ if torch.cuda.is_available():
13
+ device = torch.device("cuda")
14
+ else:
15
+ device = torch.device("cpu")
16
+
17
+ model = OwlViTForObjectDetection.from_pretrained("google/owlvit-large-patch14").to(device)
18
+ model.eval()
19
+ processor = OwlViTProcessor.from_pretrained("google/owlvit-large-patch14")
20
+
21
+
22
+ def query_image(img_url, text_queries, score_threshold):
23
+ text_queries = text_queries.split(",")
24
+
25
+ response = requests.get(img_url)
26
+ img = Image.open(BytesIO(response.content))
27
+ img = np.array(img)
28
+
29
+ target_sizes = torch.Tensor([img.shape[:2]])
30
+ inputs = processor(text=text_queries, images=img, return_tensors="pt").to(device)
31
+
32
+ with torch.no_grad():
33
+ outputs = model(**inputs)
34
+
35
+ outputs.logits = outputs.logits.cpu()
36
+ outputs.pred_boxes = outputs.pred_boxes.cpu()
37
+ results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
38
+ boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
39
+
40
+ font = cv2.FONT_HERSHEY_SIMPLEX
41
+
42
+ for box, score, label in zip(boxes, scores, labels):
43
+ box = [int(i) for i in box.tolist()]
44
+
45
+ if score >= score_threshold:
46
+ img = cv2.rectangle(img, box[:2], box[2:], (255,0,0), 5)
47
+ if box[3] + 25 > 768:
48
+ y = box[3] - 10
49
+ else:
50
+ y = box[3] + 25
51
+
52
+ img = cv2.putText(
53
+ img, text_queries[label], (box[0], y), font, 1, (255,0,0), 2, cv2.LINE_AA
54
+ )
55
+ return img
56
+
57
+
58
+ description = """
59
+ DEMO
60
+ """
61
+ demo = gr.Interface(
62
+ query_image,
63
+ inputs=["text", "text", gr.Slider(0, 1, value=0.1)],
64
+ outputs="image",
65
+ title="Zero-Shot Object Detection with OWL-ViT",
66
+ description=description,
67
+ examples=[],
68
+ )
69
+ demo.launch()