merve HF staff commited on
Commit
d2bf7f8
1 Parent(s): 09cf8f6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -0
app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
2
+ import torch
3
+ from transformers import SamModel, SamProcessor
4
+
5
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
6
+
7
+ sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to("cuda")
8
+ sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
9
+
10
+ model_id = "IDEA-Research/grounding-dino-base"
11
+
12
+ dino_processor = AutoProcessor.from_pretrained(model_id)
13
+ dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
14
+
15
+ def infer_dino(img, text_queries, score_threshold):
16
+ queries=""
17
+ for query in text_queries:
18
+ queries += f"{query}. "
19
+
20
+ width, height = img.shape[:2]
21
+
22
+ target_sizes=[(width, height)]
23
+ inputs = dino_processor(text=queries, images=img, return_tensors="pt").to(device)
24
+
25
+ with torch.no_grad():
26
+ outputs = dino_model(**inputs)
27
+ outputs.logits = outputs.logits.cpu()
28
+ outputs.pred_boxes = outputs.pred_boxes.cpu()
29
+ results = dino_processor.post_process_grounded_object_detection(outputs=outputs, input_ids=inputs.input_ids,
30
+ box_threshold=score_threshold,
31
+ target_sizes=target_sizes)
32
+ return results
33
+
34
+
35
+ import numpy as np
36
+ def query_image(img, text_queries, dino_threshold):
37
+ text_queries = text_queries
38
+ text_queries = text_queries.split(",")
39
+ dino_output = infer_dino(img, text_queries, dino_threshold)
40
+ result_labels=[]
41
+ for pred in dino_output:
42
+ boxes = pred["boxes"].cpu()
43
+ scores = pred["scores"].cpu()
44
+ labels = pred["labels"]
45
+ box = [torch.round(pred["boxes"][0], decimals=2), torch.round(pred["boxes"][1], decimals=2),
46
+ torch.round(pred["boxes"][2], decimals=2), torch.round(pred["boxes"][3], decimals=2)]
47
+ for box, score, label in zip(boxes, scores, labels):
48
+ if label != "":
49
+ inputs = sam_processor(
50
+ img,
51
+ input_boxes=[[[box]]],
52
+ return_tensors="pt"
53
+ ).to("cuda")
54
+
55
+ with torch.no_grad():
56
+ outputs = sam_model(**inputs)
57
+
58
+ mask = sam_processor.image_processor.post_process_masks(
59
+ outputs.pred_masks.cpu(),
60
+ inputs["original_sizes"].cpu(),
61
+ inputs["reshaped_input_sizes"].cpu()
62
+ )[0][0][0].numpy()
63
+ mask = mask[np.newaxis, ...]
64
+ result_labels.append((mask, label))
65
+ return img, result_labels
66
+
67
+ import gradio as gr
68
+
69
+ description = "This Space combines [GroundingDINO](https://huggingface.co/IDEA-Research/grounding-dino-base), a bleeding-edge zero-shot object detection model with [SAM](https://huggingface.co/facebook/sam-vit-base), the state-of-the-art mask generation model. SAM normally doesn't accept text input. Combining SAM with OWLv2 makes SAM text promptable. Try the example or input an image and comma separated candidate labels to segment."
70
+ demo = gr.Interface(
71
+ query_image,
72
+ inputs=[gr.Image(label="Image Input"), gr.Textbox(label = "Candidate Labels"), gr.Slider(0, 1, value=0.05, label="Confidence Threshold for GroundingDINO")],
73
+ outputs="annotatedimage",
74
+ title="GroundingDINO 🤝 SAM for Zero-shot Segmentation",
75
+ description=description,
76
+ examples=[
77
+ ["./cats.png", "cat, fishnet", 0.16],["./bee.jpg", "bee, flower", 0.16]
78
+ ],
79
+ cache_examples=True
80
+ )
81
+ demo.launch(debug=True)