jackyccl committed on
Commit 3fc0306 · 1 Parent(s): 890b333

Add app.py - with only grounding dino bounding box function

Files changed (1)
  1. app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
+ import argparse
+ import os
+ import warnings
+
+ import cv2
+ import gradio as gr
+ import numpy as np
+ import torch
+ from PIL import Image
+
+ warnings.filterwarnings("ignore")
+
+ # Grounding DINO
+ from groundingdino.models import build_model
+ from groundingdino.util.slconfig import SLConfig
+ from groundingdino.util.utils import clean_state_dict
+ from groundingdino.util.inference import annotate, predict
+ import groundingdino.datasets.transforms as T
+
+ from huggingface_hub import hf_hub_download
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+
+ # Config and checkpoint for the GroundingDINO SwinT (OGC) model
+ config_file = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
+ ckpt_repo_id = "ShilongLiu/GroundingDINO"
+ ckpt_filename = "groundingdino_swint_ogc.pth"
+ groundingdino_device = 'cpu'
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+ def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
+     args = SLConfig.fromfile(model_config_path)
+     model = build_model(args)
+     args.device = device
+
+     # Download the checkpoint from the Hugging Face Hub and load the weights
+     cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
+     checkpoint = torch.load(cache_file, map_location='cpu')
+     log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
+     print("Model loaded from {}\n => {}".format(cache_file, log))
+     model.eval()
+     return model
+
+ def image_transform_grounding(init_image):
+     # Resize and normalize the image into the tensor the model expects
+     transform = T.Compose([
+         T.RandomResize([800], max_size=1333),
+         T.ToTensor(),
+         T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+     ])
+     image, _ = transform(init_image, None)  # tensor of shape (3, h, w)
+     return init_image, image
+
+ def image_transform_grounding_for_vis(init_image):
+     # Resize only, keeping a PIL image for visualization
+     transform = T.Compose([
+         T.RandomResize([800], max_size=1333),
+     ])
+     image, _ = transform(init_image, None)
+     return image
+
+ # Load the model once at startup
+ model = load_model_hf(config_file, ckpt_repo_id, ckpt_filename, groundingdino_device)
+
+ def get_grounding_box(input_image, grounding_caption, box_threshold, text_threshold):
+     init_image = input_image.convert("RGB")
+
+     _, image_tensor = image_transform_grounding(init_image)
+     image_pil: Image.Image = image_transform_grounding_for_vis(init_image)
+
+     # Run grounding and draw the predicted boxes on the resized image
+     boxes, logits, phrases = predict(model, image_tensor, grounding_caption,
+                                      box_threshold, text_threshold, device=groundingdino_device)
+     annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes,
+                                logits=logits, phrases=phrases)
+     # annotate() returns a BGR array (OpenCV convention); convert back to RGB
+     image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
+
+     return image_with_box
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser("Grounding SAM demo", add_help=True)
+     parser.add_argument("--debug", action="store_true", help="use debug mode")
+     parser.add_argument("--share", action="store_true", help="share the app")
+     args = parser.parse_args()
+
+     print(f'args = {args}')
+
+     block = gr.Blocks().queue()
+     with block:
+         gr.Markdown("# Grounding SAM Playground")
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(source='upload', type="pil")
+                 grounding_caption = gr.Textbox(label="Detection Prompt")
+                 run_button = gr.Button("Run")
+                 with gr.Accordion("Advanced options", open=False):
+                     box_threshold = gr.Slider(
+                         label="Box Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
+                     )
+                     text_threshold = gr.Slider(
+                         label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
+                     )
+
+             with gr.Column():
+                 gallery = gr.Image(type="pil", label="Grounding results")
+
+         DESCRIPTION = '### This demo is adapted from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything); kudos to their excellent work. Welcome everyone to try it out and learn together!'
+         gr.Markdown(DESCRIPTION)
+         run_button.click(fn=get_grounding_box,
+                          inputs=[input_image, grounding_caption, box_threshold, text_threshold],
+                          outputs=[gallery])
+
+     block.launch(share=args.share, debug=args.debug, show_api=False, show_error=True)
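
The UI is started with `python app.py` (add `--share` for a public Gradio link). For a quick smoke test of the detection path without opening the browser, `get_grounding_box` can also be called directly. A minimal sketch, assuming the file above is saved as app.py and that a local image demo.jpg and the prompt "a dog" exist (the file name, prompt, and output path are hypothetical placeholders); note that importing app loads the GroundingDINO checkpoint at import time, so the first call is slow:

    # Minimal sketch: exercise get_grounding_box() outside the Gradio UI.
    # "demo.jpg", the prompt, and the output path are hypothetical;
    # thresholds mirror the slider defaults above (0.25 / 0.25).
    from PIL import Image

    from app import get_grounding_box  # loads the model at import time

    image = Image.open("demo.jpg")
    result = get_grounding_box(image, "a dog", box_threshold=0.25, text_threshold=0.25)
    result.save("demo_annotated.jpg")  # PIL image with boxes drawn

Since the demo runs the model on CPU (`groundingdino_device = 'cpu'`), expect a few seconds per image; pointing `groundingdino_device` at `device` instead would use the GPU when one is available.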