Commit 13c1c2e by ttengwang
1 Parent(s): eeb5fe8

Assign API key and image embedding from different users to different sessions

Files changed (4)
  1. app.py +141 -60
  2. caption_anything.py +21 -10
  3. segmenter/__init__.py +4 -2
  4. segmenter/base_segmenter.py +7 -4
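
Note: this commit drops the single global CaptionAnything instance. The captioner and the SAM weights are loaded once and shared, while the OpenAI text refiner (built from the user's API key) and the SAM image embedding are kept in per-session gr.State values that are passed into each callback. A minimal sketch of that Gradio per-session pattern, for orientation only; the set_key, set_image, and describe helpers below are hypothetical stand-ins for the repo's init_openai_api_key, upload_callback, and inference_seg_cap:

import gradio as gr

with gr.Blocks() as demo:
    # Each gr.State is created per browser session, so one user's API key or
    # image embedding is never visible to another user's requests.
    text_refiner = gr.State(None)     # per-session object built from that user's key
    image_embedding = gr.State(None)  # per-session features for the uploaded image

    api_key_box = gr.Textbox(label="OpenAI API Key", type="password")
    image_box = gr.Image(type="pil")
    caption_box = gr.Textbox(label="Caption")

    def set_key(key):
        # Hypothetical stand-in for build_text_refiner(...): return the per-session object.
        return {"api_key": key} if key else None

    def set_image(img):
        # Hypothetical stand-in for segmenter.set_image(...): return the per-session embedding.
        return ("embedding-for", getattr(img, "size", None))

    def describe(refiner, embedding):
        # Callbacks receive the session's own state values as ordinary arguments.
        return f"refiner ready: {refiner is not None}, embedding: {embedding}"

    api_key_box.submit(set_key, [api_key_box], [text_refiner])
    image_box.upload(set_image, [image_box], [image_embedding])
    image_box.select(describe, [text_refiner, image_embedding], [caption_box])

demo.launch()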
app.py CHANGED
@@ -15,6 +15,10 @@ import copy
 from tools import mask_painter
 from PIL import Image
 import os
+from captioner import build_captioner
+from segment_anything import sam_model_registry
+from text_refiner import build_text_refiner
+from segmenter import build_segmenter
 
 def download_checkpoint(url, folder, filename):
     os.makedirs(folder, exist_ok=True)
@@ -50,37 +54,74 @@ examples = [
 ]
 
 args = parse_augment()
-args.disable_reuse_features = True
 # args.device = 'cuda:5'
-# args.disable_gpt = False
-# args.enable_reduce_tokens = True
+# args.disable_gpt = True
+# args.enable_reduce_tokens = False
 # args.port=20322
-model = CaptionAnything(args)
+# args.captioner = 'blip'
+# args.regular_box = True
+shared_captioner = build_captioner(args.captioner, args.device, args)
+shared_sam_model = sam_model_registry['vit_h'](checkpoint=args.segmenter_checkpoint).to(args.device)
 
-def init_openai_api_key(api_key):
-    # os.environ['OPENAI_API_KEY'] = api_key
-    model.init_refiner(api_key)
-    openai_available = model.text_refiner is not None
-    return gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = True), gr.update(visible = True)
 
-def get_prompt(chat_input, click_state):
-    points = click_state[0]
-    labels = click_state[1]
+def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, text_refiner=None, session_id=None):
+    segmenter = build_segmenter(args.segmenter, args.device, args, model=sam_model)
+    captioner = captioner
+    if session_id is not None:
+        print('Init caption anything for session {}'.format(session_id))
+    return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, text_refiner=text_refiner)
+
+
+def init_openai_api_key(api_key=""):
+    text_refiner = None
+    if api_key and len(api_key) > 30:
+        try:
+            text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
+            text_refiner.llm('hi') # test
+        except:
+            text_refiner = None
+    openai_available = text_refiner is not None
+    return gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = True), gr.update(visible = True), gr.update(visible = True), text_refiner
+
+
+def get_prompt(chat_input, click_state, click_mode):
     inputs = json.loads(chat_input)
-    for input in inputs:
-        points.append(input[:2])
-        labels.append(input[2])
+    if click_mode == 'Continuous':
+        points = click_state[0]
+        labels = click_state[1]
+        for input in inputs:
+            points.append(input[:2])
+            labels.append(input[2])
+    elif click_mode == 'Single':
+        points = []
+        labels = []
+        for input in inputs:
+            points.append(input[:2])
+            labels.append(input[2])
+        click_state[0] = points
+        click_state[1] = labels
+    else:
+        raise NotImplementedError
 
     prompt = {
         "prompt_type":["click"],
-        "input_point":points,
-        "input_label":labels,
+        "input_point":click_state[0],
+        "input_label":click_state[1],
        "multimask_output":"True",
     }
     return prompt
 
-def chat_with_points(chat_input, click_state, state):
-    if model.text_refiner is None:
+def update_click_state(click_state, caption, click_mode):
+    if click_mode == 'Continuous':
+        click_state[2].append(caption)
+    elif click_mode == 'Single':
+        click_state[2] = [caption]
+    else:
+        raise NotImplementedError
+
+
+def chat_with_points(chat_input, click_state, state, text_refiner):
+    if text_refiner is None:
         response = "Text refiner is not initilzed, please input openai api key."
         state = state + [(chat_input, response)]
         return state, state
@@ -96,11 +137,26 @@ def chat_with_points(chat_input, click_state, state):
     else:
         prev_visual_context = 'no point exists.'
     chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context, "chat_input": chat_input})
-    response = model.text_refiner.llm(chat_prompt)
+    response = text_refiner.llm(chat_prompt)
     state = state + [(chat_input, response)]
     return state, state
 
-def inference_seg_cap(image_input, point_prompt, language, sentiment, factuality, length, state, click_state, evt:gr.SelectData):
+def inference_seg_cap(image_input, point_prompt, click_mode, language, sentiment, factuality,
+                      length, image_embedding, state, click_state, original_size, input_size, text_refiner, evt:gr.SelectData):
+
+    model = build_caption_anything_with_models(
+        args,
+        api_key="",
+        captioner=shared_captioner,
+        sam_model=shared_sam_model,
+        text_refiner=text_refiner,
+        session_id=iface.app_id
+    )
+
+    model.segmenter.image_embedding = image_embedding
+    model.segmenter.predictor.original_size = original_size
+    model.segmenter.predictor.input_size = input_size
+    model.segmenter.predictor.is_image_set = True
 
     if point_prompt == 'Positive':
         coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
@@ -114,7 +170,7 @@ def inference_seg_cap(image_input, point_prompt, language, sentiment, factuality
 
     # click_coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
     # chat_input = click_coordinate
-    prompt = get_prompt(coordinate, click_state)
+    prompt = get_prompt(coordinate, click_state, click_mode)
     print('prompt: ', prompt, 'controls: ', controls)
 
     out = model.inference(image_input, prompt, controls)
@@ -123,12 +179,12 @@ def inference_seg_cap(image_input, point_prompt, language, sentiment, factuality
     # state = state + [(f'{k}: {v}', None)]
     state = state + [("caption: {}".format(out['generated_captions']['raw_caption']), None)]
     wiki = out['generated_captions'].get('wiki', "")
-    click_state[2].append(out['generated_captions']['raw_caption'])
-
+
+    update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
     text = out['generated_captions']['raw_caption']
     # draw = ImageDraw.Draw(image_input)
     # draw.text((evt.index[0], evt.index[1]), text, textcolor=(0,0,255), text_size=120)
-    input_mask = np.array(Image.open(out['mask_save_path']).convert('P'))
+    input_mask = np.array(out['mask'].convert('P'))
     image_input = mask_painter(np.array(image_input), input_mask)
     origin_image_input = image_input
     image_input = create_bubble_frame(image_input, text, (evt.index[0], evt.index[1]))
@@ -151,10 +207,19 @@ def upload_callback(image_input, state):
     if ratio < 1.0:
         image_input = image_input.resize((int(width * ratio), int(height * ratio)))
         print('Scaling input image to {}'.format(image_input.size))
-    model.segmenter.image = None
-    model.segmenter.image_embedding = None
+
+    model = build_caption_anything_with_models(
+        args,
+        api_key="",
+        captioner=shared_captioner,
+        sam_model=shared_sam_model,
+        session_id=iface.app_id
+    )
     model.segmenter.set_image(image_input)
-    return state, image_input, click_state, image_input
+    image_embedding = model.segmenter.image_embedding
+    original_size = model.segmenter.predictor.original_size
+    input_size = model.segmenter.predictor.input_size
+    return state, state, image_input, click_state, image_input, image_embedding, original_size, input_size
 
 with gr.Blocks(
     css='''
@@ -165,6 +230,10 @@ with gr.Blocks(
     state = gr.State([])
     click_state = gr.State([[],[],[]])
     origin_image = gr.State(None)
+    image_embedding = gr.State(None)
+    text_refiner = gr.State(None)
+    original_size = gr.State(None)
+    input_size = gr.State(None)
 
     gr.Markdown(title)
     gr.Markdown(description)
@@ -175,17 +244,24 @@ with gr.Blocks(
             image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
             example_image = gr.Image(type="pil", interactive=False, visible=False)
             with gr.Row(scale=1.0):
-                point_prompt = gr.Radio(
-                    choices=["Positive", "Negative"],
-                    value="Positive",
-                    label="Point Prompt",
-                    interactive=True)
-                clear_button_clike = gr.Button(value="Clear Clicks", interactive=True)
-                clear_button_image = gr.Button(value="Clear Image", interactive=True)
+                with gr.Row(scale=0.4):
+                    point_prompt = gr.Radio(
+                        choices=["Positive", "Negative"],
+                        value="Positive",
+                        label="Point Prompt",
+                        interactive=True)
+                    click_mode = gr.Radio(
+                        choices=["Continuous", "Single"],
+                        value="Continuous",
+                        label="Clicking Mode",
+                        interactive=True)
+                with gr.Row(scale=0.4):
+                    clear_button_clike = gr.Button(value="Clear Clicks", interactive=True)
+                    clear_button_image = gr.Button(value="Clear Image", interactive=True)
             with gr.Column(visible=False) as modules_need_gpt:
                with gr.Row(scale=1.0):
                     language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
-
+
                     sentiment = gr.Radio(
                         choices=["Positive", "Natural", "Negative"],
                         value="Natural",
@@ -206,27 +282,36 @@ with gr.Blocks(
                         step=1,
                         interactive=True,
                         label="Length",
-                    )
-
+                    )
+            with gr.Column(visible=True) as modules_not_need_gpt3:
+                gr.Examples(
+                    examples=examples,
+                    inputs=[example_image],
+                )
        with gr.Column(scale=0.5):
            openai_api_key = gr.Textbox(
-                placeholder="Input openAI API key and press Enter (Input blank will disable GPT)",
+                placeholder="Input openAI API key",
                show_label=False,
                label = "OpenAI API Key",
                lines=1,
-                type="password"
-            )
+                type="password")
+            with gr.Row(scale=0.5):
+                enable_chatGPT_button = gr.Button(value="Run with ChatGPT", interactive=True, variant='primary')
+                disable_chatGPT_button = gr.Button(value="Run without ChatGPT (Faster)", interactive=True, variant='primary')
            with gr.Column(visible=False) as modules_need_gpt2:
-                wiki_output = gr.Textbox(lines=6, label="Wiki")
+                wiki_output = gr.Textbox(lines=5, label="Wiki", max_lines=5)
            with gr.Column(visible=False) as modules_not_need_gpt2:
-                chatbot = gr.Chatbot(label="Chat about Selected Object",).style(height=450,scale=0.5)
+                chatbot = gr.Chatbot(label="Chat about Selected Object",).style(height=550,scale=0.5)
            with gr.Column(visible=False) as modules_need_gpt3:
                chat_input = gr.Textbox(lines=1, label="Chat Input")
                with gr.Row():
                    clear_button_text = gr.Button(value="Clear Text", interactive=True)
                    submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
-
-    openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key], outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2])
+
+    openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key], outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
+    enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key], outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
+    disable_chatGPT_button.click(init_openai_api_key, outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
+
    clear_button_clike.click(
        lambda x: ([[], [], []], x, ""),
        [origin_image],
@@ -256,33 +341,29 @@ with gr.Blocks(
        show_progress=False
    )
 
-    def example_callback(x):
-        model.image_embedding = None
-        return x
-
-    gr.Examples(
-        examples=examples,
-        inputs=[example_image],
-    )
-
-    image_input.upload(upload_callback,[image_input, state], [state, origin_image, click_state, image_input])
-    chat_input.submit(chat_with_points, [chat_input, click_state, state], [chatbot, state])
-    example_image.change(upload_callback,[example_image, state], [state, origin_image, click_state, image_input])
+    image_input.upload(upload_callback,[image_input, state], [chatbot, state, origin_image, click_state, image_input, image_embedding, original_size, input_size])
+    chat_input.submit(chat_with_points, [chat_input, click_state, state, text_refiner], [chatbot, state])
+    example_image.change(upload_callback,[example_image, state], [state, state, origin_image, click_state, image_input, image_embedding, original_size, input_size])
 
    # select coordinate
    image_input.select(inference_seg_cap,
                       inputs=[
                           origin_image,
                           point_prompt,
+                           click_mode,
                           language,
                           sentiment,
                           factuality,
                           length,
+                           image_embedding,
                           state,
-                           click_state
+                           click_state,
+                           original_size,
+                           input_size,
+                           text_refiner
                       ],
                       outputs=[chatbot, state, click_state, chat_input, image_input, wiki_output],
                       show_progress=False, queue=True)
 
-    iface.queue(concurrency_count=1, api_open=False)
-    iface.launch(server_name="0.0.0.0", enable_queue=True)
+    iface.queue(concurrency_count=1, api_open=False, max_size=10)
+    iface.launch(server_name="0.0.0.0", enable_queue=True, server_port=args.port, share=args.gradio_share)
caption_anything.py CHANGED
@@ -6,14 +6,17 @@ import argparse
 import pdb
 import time
 from PIL import Image
+import cv2
+import numpy as np
 
 class CaptionAnything():
-    def __init__(self, args, api_key=""):
+    def __init__(self, args, api_key="", captioner=None, segmenter=None, text_refiner=None):
         self.args = args
-        self.captioner = build_captioner(args.captioner, args.device, args)
-        self.segmenter = build_segmenter(args.segmenter, args.device, args)
+        self.captioner = build_captioner(args.captioner, args.device, args) if captioner is None else captioner
+        self.segmenter = build_segmenter(args.segmenter, args.device, args) if segmenter is None else segmenter
+
         self.text_refiner = None
-        if not args.disable_gpt:
+        if not args.disable_gpt and text_refiner is not None:
             self.init_refiner(api_key)
 
     def init_refiner(self, api_key):
@@ -22,19 +25,25 @@ class CaptionAnything():
             self.text_refiner.llm('hi') # test
         except:
             self.text_refiner = None
-            print('Openai api key is NOT given')
+            print('OpenAI GPT is not available')
 
     def inference(self, image, prompt, controls, disable_gpt=False):
         # segment with prompt
         print("CA prompt: ", prompt, "CA controls",controls)
         seg_mask = self.segmenter.inference(image, prompt)[0, ...]
+        if self.args.enable_morphologyex:
+            seg_mask = 255 * seg_mask.astype(np.uint8)
+            seg_mask = np.stack([seg_mask, seg_mask, seg_mask], axis = -1)
+            seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_OPEN, kernel = np.ones((6, 6), np.uint8))
+            seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_CLOSE, kernel = np.ones((6, 6), np.uint8))
+            seg_mask = seg_mask[:,:,0] > 0
         mask_save_path = f'result/mask_{time.time()}.png'
         if not os.path.exists(os.path.dirname(mask_save_path)):
             os.makedirs(os.path.dirname(mask_save_path))
-        new_p = Image.fromarray(seg_mask.astype('int') * 255.)
-        if new_p.mode != 'RGB':
-            new_p = new_p.convert('RGB')
-        new_p.save(mask_save_path)
+        seg_mask_img = Image.fromarray(seg_mask.astype('int') * 255.)
+        if seg_mask_img.mode != 'RGB':
+            seg_mask_img = seg_mask_img.convert('RGB')
+        seg_mask_img.save(mask_save_path)
         print('seg_mask path: ', mask_save_path)
         print("seg_mask.shape: ", seg_mask.shape)
         # captioning with mask
@@ -53,6 +62,7 @@ class CaptionAnything():
         out = {'generated_captions': refined_caption,
                'crop_save_path': crop_save_path,
                'mask_save_path': mask_save_path,
+               'mask': seg_mask_img,
                'context_captions': context_captions}
         return out
 
@@ -73,6 +83,7 @@ def parse_augment():
     parser.add_argument('--disable_gpt', action="store_true")
     parser.add_argument('--enable_reduce_tokens', action="store_true", default=False)
     parser.add_argument('--disable_reuse_features', action="store_true", default=False)
+    parser.add_argument('--enable_morphologyex', action="store_true", default=False)
     args = parser.parse_args()
 
     if args.debug:
@@ -115,4 +126,4 @@ if __name__ == "__main__":
     print('Language controls:\n', controls)
     out = model.inference(image_path, prompt, controls)
 
-
+
segmenter/__init__.py CHANGED
@@ -1,6 +1,8 @@
 from segmenter.base_segmenter import BaseSegmenter
 
 
-def build_segmenter(type, device, args=None):
+def build_segmenter(type, device, args=None, model=None):
     if type == 'base':
-        return BaseSegmenter(device, args.segmenter_checkpoint, reuse_feature=not args.disable_reuse_features)
+        return BaseSegmenter(device, args.segmenter_checkpoint, reuse_feature=not args.disable_reuse_features, model=model)
+    else:
+        raise NotImplementedError()
segmenter/base_segmenter.py CHANGED
@@ -9,15 +9,18 @@ import matplotlib.pyplot as plt
 import PIL
 
 class BaseSegmenter:
-    def __init__(self, device, checkpoint, model_type='vit_h', reuse_feature = True):
+    def __init__(self, device, checkpoint, model_type='vit_h', reuse_feature = True, model=None):
         print(f"Initializing BaseSegmenter to {device}")
         self.device = device
         self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
         self.processor = None
         self.model_type = model_type
-        self.checkpoint = checkpoint
-        self.model = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
-        self.model.to(device=self.device)
+        if model is None:
+            self.checkpoint = checkpoint
+            self.model = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
+            self.model.to(device=self.device)
+        else:
+            self.model = model
         self.reuse_feature = reuse_feature
         self.predictor = SamPredictor(self.model)
         self.mask_generator = SamAutomaticMaskGenerator(self.model)
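
For reference, a usage sketch of the new sharing path: one SAM backbone loaded once and handed to several BaseSegmenter instances, each keeping its own SamPredictor and cached features. The Args stub and checkpoint path are hypothetical; app.py does the equivalent via parse_augment() and shared_sam_model.

from segment_anything import sam_model_registry
from segmenter import build_segmenter

class Args:  # minimal stand-in for the namespace returned by parse_augment()
    segmenter = 'base'
    segmenter_checkpoint = './checkpoints/sam_vit_h_4b8939.pth'  # assumed local path
    disable_reuse_features = False

args = Args()
device = 'cuda:0'

# Load the SAM weights once and share them across per-session segmenters.
shared_sam = sam_model_registry['vit_h'](checkpoint=args.segmenter_checkpoint).to(device)

seg_a = build_segmenter(args.segmenter, device, args, model=shared_sam)  # session A
seg_b = build_segmenter(args.segmenter, device, args, model=shared_sam)  # session B
assert seg_a.model is seg_b.model               # same backbone weights
assert seg_a.predictor is not seg_b.predictor   # separate per-session predictor state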