ngthanhtinqn committed on
Commit
accbb3c
β€’
1 Parent(s): a72b3f0
Files changed (1)
  1. app.py +76 -4
app.py CHANGED
@@ -1,7 +1,79 @@
 
 
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ import torch
+ import cv2
  import gradio as gr
+ import numpy as np
+ from transformers import OwlViTProcessor, OwlViTForObjectDetection
+
+
+ # Use GPU if available
+ if torch.cuda.is_available():
+     device = torch.device("cuda")
+ else:
+     device = torch.device("cpu")
+
+ model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32").to(device)
+ model.eval()
+ processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
+
+
+ def query_image(img, text_queries, score_threshold):
+     text_queries = text_queries
+     text_queries = text_queries.split(",")
+
+     target_sizes = torch.Tensor([img.shape[:2]])
+     inputs = processor(text=text_queries, images=img, return_tensors="pt").to(device)
+
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     outputs.logits = outputs.logits.cpu()
+     outputs.pred_boxes = outputs.pred_boxes.cpu()
+     results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
+     boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
+
+     font = cv2.FONT_HERSHEY_SIMPLEX
+
+     for box, score, label in zip(boxes, scores, labels):
+         box = [int(i) for i in box.tolist()]
+
+         if score >= score_threshold:
+             img = cv2.rectangle(img, box[:2], box[2:], (255,0,0), 5)
+             if box[3] + 25 > 768:
+                 y = box[3] - 10
+             else:
+                 y = box[3] + 25
+
+             img = cv2.putText(
+                 img, text_queries[label], (box[0], y), font, 1, (255,0,0), 2, cv2.LINE_AA
+             )
+     return img
+
+
+ description = """
+ Gradio demo for <a href="https://huggingface.co/docs/transformers/main/en/model_doc/owlvit">OWL-ViT</a>,
+ introduced in <a href="https://arxiv.org/abs/2205.06230">Simple Open-Vocabulary Object Detection
+ with Vision Transformers</a>.
+ \n\nYou can use OWL-ViT to query images with text descriptions of any object.
+ To use it, simply upload an image and enter comma separated text descriptions of objects you want to query the image for. You
+ can also use the score threshold slider to set a threshold to filter out low probability predictions.
+ \n\nOWL-ViT is trained on text templates,
+ hence you can get better predictions by querying the image with text templates used in training the original model: *"photo of a star-spangled banner"*,
+ *"image of a shoe"*. Refer to the <a href="https://arxiv.org/abs/2103.00020">CLIP</a> paper to see the full list of text templates used to augment the training data.
+ \n\n<a href="https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/zeroshot_object_detection_with_owlvit.ipynb">Colab demo</a>
+ """
+ demo = gr.Interface(
+     query_image,
+     inputs=[gr.Image(), "text", gr.Slider(0, 1, value=0.1)],
+     outputs="image",
+     title="Zero-Shot Object Detection with OWL-ViT",
+     description=description,
+     examples=[
+         ["./demo_images/cats.png", "cats,ears", 0.11],
+         ["./demo_images/demo1.jpg", "bear,soil,sea", 0.1],
+         ["./demo_images/demo2.jpg", "dog,ear,leg,eyes,tail", 0.1],
+         ["./demo_images/tanager.jpg", "wing,eyes,back,legs,tail", 0.01]
+     ],
+ )
+
+ # demo.launch()
+ demo.launch(server_name="0.0.0.0", debug=True)
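
For a quick sanity check of the committed code, here is a minimal sketch (not part of the commit) of calling query_image directly on one of the bundled example images. It assumes the ./demo_images files referenced in the examples above exist and that query_image is in scope, for instance by running app.py with the final blocking demo.launch call commented out; gr.Image hands the function an RGB numpy array, so the cv2-loaded image is converted from BGR first.

import cv2

# cv2 loads BGR; convert to RGB to match what gr.Image would pass in
img = cv2.cvtColor(cv2.imread("./demo_images/cats.png"), cv2.COLOR_BGR2RGB)

# Same queries and threshold as the first example row above
annotated = query_image(img, "cats,ears", 0.11)

# Convert back to BGR before writing the annotated result with OpenCV
cv2.imwrite("annotated.png", cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))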
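
A note on API drift, not from this commit: in more recent transformers releases, OwlViTProcessor.post_process is deprecated in favor of post_process_object_detection, which applies a score threshold itself. If the Space were bumped to a newer transformers version, the post-processing call above would be replaced by something like the following, keeping the manual filtering in the drawing loop optional:

# Hypothetical replacement for processor.post_process under a newer transformers release
results = processor.post_process_object_detection(
    outputs=outputs, threshold=score_threshold, target_sizes=target_sizes
)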