SkalskiP commited on
Commit
576e22a
·
1 Parent(s): d1212b2

open vocabulary detection with Florence2 + masks with SAM2

Browse files
Files changed (4) hide show
  1. app.py +107 -57
  2. utils/florence.py +3 -0
  3. utils/modes.py +7 -0
  4. utils/sam.py +22 -0
app.py CHANGED
@@ -1,15 +1,16 @@
1
  from typing import Tuple, Optional
2
 
3
  import gradio as gr
4
- import numpy as np
5
  import supervision as sv
6
  import torch
7
  from PIL import Image
8
 
9
  from utils.florence import load_florence_model, run_florence_inference, \
10
  FLORENCE_DETAILED_CAPTION_TASK, \
11
- FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK
12
- from utils.sam import load_sam_model
 
 
13
 
14
  MARKDOWN = """
15
  # Florence2 + SAM2 🔥
@@ -21,94 +22,122 @@ into masks.
21
  """
22
 
23
  EXAMPLES = [
24
- "https://media.roboflow.com/notebooks/examples/dog-2.jpeg",
25
- "https://media.roboflow.com/notebooks/examples/dog-3.jpeg",
26
- "https://media.roboflow.com/notebooks/examples/dog-4.jpeg"
 
 
27
  ]
28
 
29
  DEVICE = torch.device("cuda")
30
-
31
  FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
32
  SAM_MODEL = load_sam_model(device=DEVICE)
33
  BOX_ANNOTATOR = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
34
  LABEL_ANNOTATOR = sv.LabelAnnotator(
35
  color_lookup=sv.ColorLookup.INDEX,
36
  text_position=sv.Position.CENTER_OF_MASS,
37
- text_color=sv.Color.BLACK,
38
  border_radius=5
39
  )
40
  MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
41
 
42
 
43
- def process(
44
- image_input,
45
- ) -> Tuple[Optional[Image.Image], Optional[str]]:
46
- if image_input is None:
47
- return None, None
 
48
 
49
- _, result = run_florence_inference(
50
- model=FLORENCE_MODEL,
51
- processor=FLORENCE_PROCESSOR,
52
- device=DEVICE,
53
- image=image_input,
54
- task=FLORENCE_DETAILED_CAPTION_TASK
55
- )
56
- caption = result[FLORENCE_DETAILED_CAPTION_TASK]
57
- _, result = run_florence_inference(
58
- model=FLORENCE_MODEL,
59
- processor=FLORENCE_PROCESSOR,
60
- device=DEVICE,
61
- image=image_input,
62
- task=FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK,
63
- text=caption
64
- )
65
- detections = sv.Detections.from_lmm(
66
- lmm=sv.LMM.FLORENCE_2,
67
- result=result,
68
- resolution_wh=image_input.size
69
- )
70
- image = np.array(image_input.convert("RGB"))
71
- SAM_MODEL.set_image(image)
72
- mask, score, _ = SAM_MODEL.predict(box=detections.xyxy, multimask_output=False)
73
 
74
- # dirty fix; remove this later
75
- if len(mask.shape) == 4:
76
- mask = np.squeeze(mask)
 
 
77
 
78
- detections.mask = mask.astype(bool)
79
 
80
- output_image = image_input.copy()
81
- output_image = MASK_ANNOTATOR.annotate(output_image, detections)
82
- output_image = BOX_ANNOTATOR.annotate(output_image, detections)
83
- output_image = LABEL_ANNOTATOR.annotate(output_image, detections)
84
- return output_image, caption
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
 
87
  with gr.Blocks() as demo:
88
  gr.Markdown(MARKDOWN)
 
 
 
 
 
 
 
89
  with gr.Row():
90
  with gr.Column():
91
  image_input_component = gr.Image(
92
  type='pil', label='Upload image')
 
 
93
  submit_button_component = gr.Button(value='Submit', variant='primary')
94
-
95
  with gr.Column():
96
  image_output_component = gr.Image(type='pil', label='Image output')
97
- text_output_component = gr.Textbox(label='Caption output')
98
 
99
- submit_button_component.click(
100
- fn=process,
101
- inputs=[image_input_component],
102
- outputs=[
103
- image_output_component,
104
- text_output_component
105
- ]
106
- )
107
  with gr.Row():
108
  gr.Examples(
109
  fn=process,
110
  examples=EXAMPLES,
111
- inputs=[image_input_component],
 
 
 
 
112
  outputs=[
113
  image_output_component,
114
  text_output_component
@@ -116,4 +145,25 @@ with gr.Blocks() as demo:
116
  run_on_click=True
117
  )
118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  demo.launch(debug=False, show_error=True, max_threads=1)
 
1
  from typing import Tuple, Optional
2
 
3
  import gradio as gr
 
4
  import supervision as sv
5
  import torch
6
  from PIL import Image
7
 
8
  from utils.florence import load_florence_model, run_florence_inference, \
9
  FLORENCE_DETAILED_CAPTION_TASK, \
10
+ FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK, FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
11
+ from utils.modes import INFERENCE_MODES, OPEN_VOCABULARY_DETECTION, \
12
+ CAPTION_GROUNDING_MASKS
13
+ from utils.sam import load_sam_model, run_sam_inference
14
 
15
  MARKDOWN = """
16
  # Florence2 + SAM2 🔥
 
22
  """
23
 
24
  EXAMPLES = [
25
+ [OPEN_VOCABULARY_DETECTION, "https://media.roboflow.com/notebooks/examples/dog-2.jpeg", 'straw'],
26
+ [OPEN_VOCABULARY_DETECTION, "https://media.roboflow.com/notebooks/examples/dog-2.jpeg", 'napkin'],
27
+ [OPEN_VOCABULARY_DETECTION, "https://media.roboflow.com/notebooks/examples/dog-3.jpeg", 'tail'],
28
+ [CAPTION_GROUNDING_MASKS, "https://media.roboflow.com/notebooks/examples/dog-2.jpeg", None],
29
+ [CAPTION_GROUNDING_MASKS, "https://media.roboflow.com/notebooks/examples/dog-3.jpeg", None],
30
  ]
31
 
32
  DEVICE = torch.device("cuda")
 
33
  FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
34
  SAM_MODEL = load_sam_model(device=DEVICE)
35
  BOX_ANNOTATOR = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
36
  LABEL_ANNOTATOR = sv.LabelAnnotator(
37
  color_lookup=sv.ColorLookup.INDEX,
38
  text_position=sv.Position.CENTER_OF_MASS,
39
+ text_color=sv.Color.from_hex("#FFFFFF"),
40
  border_radius=5
41
  )
42
  MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
43
 
44
 
45
+ def annotate_image(image, detections):
46
+ output_image = image.copy()
47
+ output_image = MASK_ANNOTATOR.annotate(output_image, detections)
48
+ output_image = BOX_ANNOTATOR.annotate(output_image, detections)
49
+ output_image = LABEL_ANNOTATOR.annotate(output_image, detections)
50
+ return output_image
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ def on_mode_dropdown_change(text):
54
+ return [
55
+ gr.Textbox(visible=text == OPEN_VOCABULARY_DETECTION),
56
+ gr.Textbox(visible=text == CAPTION_GROUNDING_MASKS),
57
+ ]
58
 
 
59
 
60
+ def process(
61
+ mode_dropdown, image_input, text_input
62
+ ) -> Tuple[Optional[Image.Image], Optional[str]]:
63
+ if not image_input:
64
+ return None, None
65
+
66
+ if mode_dropdown == OPEN_VOCABULARY_DETECTION:
67
+ if not text_input:
68
+ return None, None
69
+
70
+ _, result = run_florence_inference(
71
+ model=FLORENCE_MODEL,
72
+ processor=FLORENCE_PROCESSOR,
73
+ device=DEVICE,
74
+ image=image_input,
75
+ task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
76
+ text=text_input
77
+ )
78
+ detections = sv.Detections.from_lmm(
79
+ lmm=sv.LMM.FLORENCE_2,
80
+ result=result,
81
+ resolution_wh=image_input.size
82
+ )
83
+ detections = run_sam_inference(SAM_MODEL, image_input, detections)
84
+ return annotate_image(image_input, detections), None
85
+
86
+ if mode_dropdown == CAPTION_GROUNDING_MASKS:
87
+ _, result = run_florence_inference(
88
+ model=FLORENCE_MODEL,
89
+ processor=FLORENCE_PROCESSOR,
90
+ device=DEVICE,
91
+ image=image_input,
92
+ task=FLORENCE_DETAILED_CAPTION_TASK
93
+ )
94
+ caption = result[FLORENCE_DETAILED_CAPTION_TASK]
95
+ _, result = run_florence_inference(
96
+ model=FLORENCE_MODEL,
97
+ processor=FLORENCE_PROCESSOR,
98
+ device=DEVICE,
99
+ image=image_input,
100
+ task=FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK,
101
+ text=caption
102
+ )
103
+ detections = sv.Detections.from_lmm(
104
+ lmm=sv.LMM.FLORENCE_2,
105
+ result=result,
106
+ resolution_wh=image_input.size
107
+ )
108
+ detections = run_sam_inference(SAM_MODEL, image_input, detections)
109
+ return annotate_image(image_input, detections), caption
110
 
111
 
112
  with gr.Blocks() as demo:
113
  gr.Markdown(MARKDOWN)
114
+ mode_dropdown_component = gr.Dropdown(
115
+ choices=INFERENCE_MODES,
116
+ value=INFERENCE_MODES[0],
117
+ label="Mode",
118
+ info="Select a mode to use.",
119
+ interactive=True
120
+ )
121
  with gr.Row():
122
  with gr.Column():
123
  image_input_component = gr.Image(
124
  type='pil', label='Upload image')
125
+ text_input_component = gr.Textbox(
126
+ label='Text prompt')
127
  submit_button_component = gr.Button(value='Submit', variant='primary')
 
128
  with gr.Column():
129
  image_output_component = gr.Image(type='pil', label='Image output')
130
+ text_output_component = gr.Textbox(label='Caption output', visible=False)
131
 
 
 
 
 
 
 
 
 
132
  with gr.Row():
133
  gr.Examples(
134
  fn=process,
135
  examples=EXAMPLES,
136
+ inputs=[
137
+ mode_dropdown_component,
138
+ image_input_component,
139
+ text_input_component
140
+ ],
141
  outputs=[
142
  image_output_component,
143
  text_output_component
 
145
  run_on_click=True
146
  )
147
 
148
+ submit_button_component.click(
149
+ fn=process,
150
+ inputs=[
151
+ mode_dropdown_component,
152
+ image_input_component,
153
+ text_input_component
154
+ ],
155
+ outputs=[
156
+ image_output_component,
157
+ text_output_component
158
+ ]
159
+ )
160
+ mode_dropdown_component.change(
161
+ on_mode_dropdown_change,
162
+ inputs=[mode_dropdown_component],
163
+ outputs=[
164
+ text_input_component,
165
+ text_output_component
166
+ ]
167
+ )
168
+
169
  demo.launch(debug=False, show_error=True, max_threads=1)
utils/florence.py CHANGED
@@ -8,8 +8,11 @@ from transformers import AutoModelForCausalLM, AutoProcessor
8
  from transformers.dynamic_module_utils import get_imports
9
 
10
  FLORENCE_CHECKPOINT = "microsoft/Florence-2-large"
 
11
  FLORENCE_DETAILED_CAPTION_TASK = '<MORE_DETAILED_CAPTION>'
12
  FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK = '<CAPTION_TO_PHRASE_GROUNDING>'
 
 
13
 
14
 
15
  def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
 
8
  from transformers.dynamic_module_utils import get_imports
9
 
10
  FLORENCE_CHECKPOINT = "microsoft/Florence-2-large"
11
+ FLORENCE_OBJECT_DETECTION_TASK = '<OD>'
12
  FLORENCE_DETAILED_CAPTION_TASK = '<MORE_DETAILED_CAPTION>'
13
  FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK = '<CAPTION_TO_PHRASE_GROUNDING>'
14
+ FLORENCE_OPEN_VOCABULARY_DETECTION_TASK = '<OPEN_VOCABULARY_DETECTION>'
15
+ FLORENCE_DENSE_REGION_CAPTION_TASK = '<DENSE_REGION_CAPTION>'
16
 
17
 
18
  def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
utils/modes.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ OPEN_VOCABULARY_DETECTION = "open vocabulary detection + masks"
2
+ CAPTION_GROUNDING_MASKS = "caption + grounding + masks"
3
+
4
+ INFERENCE_MODES = [
5
+ OPEN_VOCABULARY_DETECTION,
6
+ CAPTION_GROUNDING_MASKS
7
+ ]
utils/sam.py CHANGED
@@ -1,4 +1,9 @@
 
 
 
 
1
  import torch
 
2
  from sam2.build_sam import build_sam2
3
  from sam2.sam2_image_predictor import SAM2ImagePredictor
4
 
@@ -13,3 +18,20 @@ def load_sam_model(
13
  ) -> SAM2ImagePredictor:
14
  model = build_sam2(config, checkpoint, device=device)
15
  return SAM2ImagePredictor(sam_model=model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ import numpy as np
4
+ import supervision as sv
5
  import torch
6
+ from PIL import Image
7
  from sam2.build_sam import build_sam2
8
  from sam2.sam2_image_predictor import SAM2ImagePredictor
9
 
 
18
  ) -> SAM2ImagePredictor:
19
  model = build_sam2(config, checkpoint, device=device)
20
  return SAM2ImagePredictor(sam_model=model)
21
+
22
+
23
+ def run_sam_inference(
24
+ model: Any,
25
+ image: Image,
26
+ detections: sv.Detections
27
+ ) -> sv.Detections:
28
+ image = np.array(image.convert("RGB"))
29
+ model.set_image(image)
30
+ mask, score, _ = model.predict(box=detections.xyxy, multimask_output=False)
31
+
32
+ # dirty fix; remove this later
33
+ if len(mask.shape) == 4:
34
+ mask = np.squeeze(mask)
35
+
36
+ detections.mask = mask.astype(bool)
37
+ return detections