ttengwang commited on
Commit
eabdb1c
1 Parent(s): 12dc496

improve chat box; add a enable_wiki button

Browse files
Files changed (3) hide show
  1. app.py +64 -38
  2. caption_anything.py +2 -2
  3. text_refiner/text_refiner.py +8 -6
app.py CHANGED
@@ -120,28 +120,44 @@ def update_click_state(click_state, caption, click_mode):
120
  raise NotImplementedError
121
 
122
 
123
- def chat_with_points(chat_input, click_state, state, text_refiner):
124
  if text_refiner is None:
125
  response = "Text refiner is not initilzed, please input openai api key."
126
  state = state + [(chat_input, response)]
127
- return state, state
128
 
129
  points, labels, captions = click_state
130
- # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting! Human: {chat_input}\nAI: "
 
 
131
  # # "The image is of width {width} and height {height}."
132
- point_chat_prompt = "a) Revised prompt: I am an AI trained to chat with you about an image based on specific points (w, h) you provide, along with their visual descriptions. Please note that (0, 0) refers to the top-left corner of the image, w refers to the width, and h refers to the height. Here are the points and their descriptions you've given me: {points_with_caps}. Now, let's chat! Human: {chat_input} AI:"
133
  prev_visual_context = ""
134
- pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1]
135
- if len(captions):
136
- prev_visual_context = ', '.join(pos_points) + captions[-1] + '\n'
137
- else:
138
- prev_visual_context = 'no point exists.'
139
- chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context, "chat_input": chat_input})
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  response = text_refiner.llm(chat_prompt)
141
  state = state + [(chat_input, response)]
142
- return state, state
 
143
 
144
- def inference_seg_cap(image_input, point_prompt, click_mode, language, sentiment, factuality,
145
  length, image_embedding, state, click_state, original_size, input_size, text_refiner, evt:gr.SelectData):
146
 
147
  model = build_caption_anything_with_models(
@@ -173,11 +189,12 @@ def inference_seg_cap(image_input, point_prompt, click_mode, language, sentiment
173
  prompt = get_prompt(coordinate, click_state, click_mode)
174
  print('prompt: ', prompt, 'controls: ', controls)
175
 
176
- out = model.inference(image_input, prompt, controls, disable_gpt=True)
177
- state = state + [(None, "Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]))]
 
178
  # for k, v in out['generated_captions'].items():
179
  # state = state + [(f'{k}: {v}', None)]
180
- state = state + [("caption: {}".format(out['generated_captions']['raw_caption']), None)]
181
  wiki = out['generated_captions'].get('wiki', "")
182
 
183
  update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
@@ -191,15 +208,18 @@ def inference_seg_cap(image_input, point_prompt, click_mode, language, sentiment
191
 
192
  yield state, state, click_state, chat_input, image_input, wiki
193
  if not args.disable_gpt and model.text_refiner:
194
- refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'])
195
  # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
196
  new_cap = refined_caption['caption']
 
 
197
  refined_image_input = create_bubble_frame(origin_image_input, new_cap, (evt.index[0], evt.index[1]))
198
  yield state, state, click_state, chat_input, refined_image_input, wiki
199
 
200
 
201
  def upload_callback(image_input, state):
202
- state = [] + [('Image size: ' + str(image_input.size), None)]
 
203
  click_state = [[], [], []]
204
  res = 1024
205
  width, height = image_input.size
@@ -219,7 +239,7 @@ def upload_callback(image_input, state):
219
  image_embedding = model.segmenter.image_embedding
220
  original_size = model.segmenter.predictor.original_size
221
  input_size = model.segmenter.predictor.input_size
222
- return state, state, image_input, click_state, image_input, image_embedding, original_size, input_size
223
 
224
  with gr.Blocks(
225
  css='''
@@ -229,6 +249,7 @@ with gr.Blocks(
229
  ) as iface:
230
  state = gr.State([])
231
  click_state = gr.State([[],[],[]])
 
232
  origin_image = gr.State(None)
233
  image_embedding = gr.State(None)
234
  text_refiner = gr.State(None)
@@ -260,14 +281,13 @@ with gr.Blocks(
260
  clear_button_image = gr.Button(value="Clear Image", interactive=True)
261
  with gr.Column(visible=False) as modules_need_gpt:
262
  with gr.Row(scale=1.0):
263
- language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
264
-
265
- sentiment = gr.Radio(
266
- choices=["Positive", "Natural", "Negative"],
267
- value="Natural",
268
- label="Sentiment",
269
- interactive=True,
270
- )
271
  with gr.Row(scale=1.0):
272
  factuality = gr.Radio(
273
  choices=["Factual", "Imagination"],
@@ -281,8 +301,13 @@ with gr.Blocks(
281
  value=10,
282
  step=1,
283
  interactive=True,
284
- label="Length",
285
- )
 
 
 
 
 
286
  with gr.Column(visible=True) as modules_not_need_gpt3:
287
  gr.Examples(
288
  examples=examples,
@@ -303,7 +328,7 @@ with gr.Blocks(
303
  with gr.Column(visible=False) as modules_not_need_gpt2:
304
  chatbot = gr.Chatbot(label="Chat about Selected Object",).style(height=550,scale=0.5)
305
  with gr.Column(visible=False) as modules_need_gpt3:
306
- chat_input = gr.Textbox(lines=1, label="Chat Input")
307
  with gr.Row():
308
  clear_button_text = gr.Button(value="Clear Text", interactive=True)
309
  submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
@@ -320,30 +345,30 @@ with gr.Blocks(
320
  show_progress=False
321
  )
322
  clear_button_image.click(
323
- lambda: (None, [], [], [[], [], []], "", ""),
324
  [],
325
- [image_input, chatbot, state, click_state, wiki_output, origin_image],
326
  queue=False,
327
  show_progress=False
328
  )
329
  clear_button_text.click(
330
- lambda: ([], [], [[], [], []]),
331
  [],
332
- [chatbot, state, click_state],
333
  queue=False,
334
  show_progress=False
335
  )
336
  image_input.clear(
337
- lambda: (None, [], [], [[], [], []], "", ""),
338
  [],
339
- [image_input, chatbot, state, click_state, wiki_output, origin_image],
340
  queue=False,
341
  show_progress=False
342
  )
343
 
344
- image_input.upload(upload_callback,[image_input, state], [chatbot, state, origin_image, click_state, image_input, image_embedding, original_size, input_size])
345
- chat_input.submit(chat_with_points, [chat_input, click_state, state, text_refiner], [chatbot, state])
346
- example_image.change(upload_callback,[example_image, state], [state, state, origin_image, click_state, image_input, image_embedding, original_size, input_size])
347
 
348
  # select coordinate
349
  image_input.select(inference_seg_cap,
@@ -351,6 +376,7 @@ with gr.Blocks(
351
  origin_image,
352
  point_prompt,
353
  click_mode,
 
354
  language,
355
  sentiment,
356
  factuality,
 
120
  raise NotImplementedError
121
 
122
 
123
+ def chat_with_points(chat_input, click_state, chat_state, state, text_refiner):
124
  if text_refiner is None:
125
  response = "Text refiner is not initilzed, please input openai api key."
126
  state = state + [(chat_input, response)]
127
+ return state, state, chat_state
128
 
129
  points, labels, captions = click_state
130
+ # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting!"
131
+ suffix = '\nHuman: {chat_input}\nAI: '
132
+ qa_template = '\nHuman: {q}\nAI: {a}'
133
  # # "The image is of width {width} and height {height}."
134
+ point_chat_prompt = "I am an AI trained to chat with you about an image based on specific points (w, h) you provide, along with their visual descriptions. Please note that (0, 0) refers to the top-left corner of the image, w refers to the width, and h refers to the height. Here are the points and their descriptions you've given me: {points_with_caps} \n Now, let's chat!"
135
  prev_visual_context = ""
136
+ pos_points = []
137
+ pos_captions = []
138
+ for i in range(len(points)):
139
+ if labels[i] == 1:
140
+ pos_points.append(f"({points[i][0]}, {points[i][0]})")
141
+ pos_captions.append(captions[i])
142
+ prev_visual_context = prev_visual_context + '\n' + 'Points: ' +', '.join(pos_points) + '. Description: ' + pos_captions[-1]
143
+
144
+ context_length_thres = 500
145
+ prev_history = ""
146
+ for i in range(len(chat_state)):
147
+ q, a = chat_state[i]
148
+ if len(prev_history) < context_length_thres:
149
+ prev_history = prev_history + qa_template.format(**{"q": q, "a": a})
150
+ else:
151
+ break
152
+
153
+ chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context}) + prev_history + suffix.format(**{"chat_input": chat_input})
154
+ print('\nchat_prompt: ', chat_prompt)
155
  response = text_refiner.llm(chat_prompt)
156
  state = state + [(chat_input, response)]
157
+ chat_state = chat_state + [(chat_input, response)]
158
+ return state, state, chat_state
159
 
160
+ def inference_seg_cap(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
161
  length, image_embedding, state, click_state, original_size, input_size, text_refiner, evt:gr.SelectData):
162
 
163
  model = build_caption_anything_with_models(
 
189
  prompt = get_prompt(coordinate, click_state, click_mode)
190
  print('prompt: ', prompt, 'controls: ', controls)
191
 
192
+ enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
193
+ out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)
194
+ state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
195
  # for k, v in out['generated_captions'].items():
196
  # state = state + [(f'{k}: {v}', None)]
197
+ state = state + [(None, "raw_caption: {}".format(out['generated_captions']['raw_caption']))]
198
  wiki = out['generated_captions'].get('wiki', "")
199
 
200
  update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
 
208
 
209
  yield state, state, click_state, chat_input, image_input, wiki
210
  if not args.disable_gpt and model.text_refiner:
211
+ refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'], enable_wiki=enable_wiki)
212
  # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
213
  new_cap = refined_caption['caption']
214
+ wiki = refined_caption['wiki']
215
+ state = state + [(None, f"caption: {new_cap}")]
216
  refined_image_input = create_bubble_frame(origin_image_input, new_cap, (evt.index[0], evt.index[1]))
217
  yield state, state, click_state, chat_input, refined_image_input, wiki
218
 
219
 
220
  def upload_callback(image_input, state):
221
+ state = [] + [(None, 'Image size: ' + str(image_input.size))]
222
+ chat_state = []
223
  click_state = [[], [], []]
224
  res = 1024
225
  width, height = image_input.size
 
239
  image_embedding = model.segmenter.image_embedding
240
  original_size = model.segmenter.predictor.original_size
241
  input_size = model.segmenter.predictor.input_size
242
+ return state, state, chat_state, image_input, click_state, image_input, image_embedding, original_size, input_size
243
 
244
  with gr.Blocks(
245
  css='''
 
249
  ) as iface:
250
  state = gr.State([])
251
  click_state = gr.State([[],[],[]])
252
+ chat_state = gr.State([])
253
  origin_image = gr.State(None)
254
  image_embedding = gr.State(None)
255
  text_refiner = gr.State(None)
 
281
  clear_button_image = gr.Button(value="Clear Image", interactive=True)
282
  with gr.Column(visible=False) as modules_need_gpt:
283
  with gr.Row(scale=1.0):
284
+ language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
285
+ sentiment = gr.Radio(
286
+ choices=["Positive", "Natural", "Negative"],
287
+ value="Natural",
288
+ label="Sentiment",
289
+ interactive=True,
290
+ )
 
291
  with gr.Row(scale=1.0):
292
  factuality = gr.Radio(
293
  choices=["Factual", "Imagination"],
 
301
  value=10,
302
  step=1,
303
  interactive=True,
304
+ label="Generated Caption Length",
305
+ )
306
+ enable_wiki = gr.Radio(
307
+ choices=["Yes", "No"],
308
+ value="No",
309
+ label="Enable Wiki",
310
+ interactive=True)
311
  with gr.Column(visible=True) as modules_not_need_gpt3:
312
  gr.Examples(
313
  examples=examples,
 
328
  with gr.Column(visible=False) as modules_not_need_gpt2:
329
  chatbot = gr.Chatbot(label="Chat about Selected Object",).style(height=550,scale=0.5)
330
  with gr.Column(visible=False) as modules_need_gpt3:
331
+ chat_input = gr.Textbox(show_label=False, placeholder="Enter text and press Enter").style(container=False)
332
  with gr.Row():
333
  clear_button_text = gr.Button(value="Clear Text", interactive=True)
334
  submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
 
345
  show_progress=False
346
  )
347
  clear_button_image.click(
348
+ lambda: (None, [], [], [], [[], [], []], "", ""),
349
  [],
350
+ [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image],
351
  queue=False,
352
  show_progress=False
353
  )
354
  clear_button_text.click(
355
+ lambda: ([], [], [[], [], [], []], []),
356
  [],
357
+ [chatbot, state, click_state, chat_state],
358
  queue=False,
359
  show_progress=False
360
  )
361
  image_input.clear(
362
+ lambda: (None, [], [], [], [[], [], []], "", ""),
363
  [],
364
+ [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image],
365
  queue=False,
366
  show_progress=False
367
  )
368
 
369
+ image_input.upload(upload_callback,[image_input, state], [chatbot, state, chat_state, origin_image, click_state, image_input, image_embedding, original_size, input_size])
370
+ chat_input.submit(chat_with_points, [chat_input, click_state, chat_state, state, text_refiner], [chatbot, state, chat_state])
371
+ example_image.change(upload_callback,[example_image, state], [chatbot, state, chat_state, origin_image, click_state, image_input, image_embedding, original_size, input_size])
372
 
373
  # select coordinate
374
  image_input.select(inference_seg_cap,
 
376
  origin_image,
377
  point_prompt,
378
  click_mode,
379
+ enable_wiki,
380
  language,
381
  sentiment,
382
  factuality,
caption_anything.py CHANGED
@@ -30,7 +30,7 @@ class CaptionAnything():
30
  self.text_refiner = None
31
  print('OpenAI GPT is not available')
32
 
33
- def inference(self, image, prompt, controls, disable_gpt=False):
34
  # segment with prompt
35
  print("CA prompt: ", prompt, "CA controls",controls)
36
  seg_mask = self.segmenter.inference(image, prompt)[0, ...]
@@ -59,7 +59,7 @@ class CaptionAnything():
59
  if self.args.context_captions:
60
  context_captions.append(self.captioner.inference(image))
61
  if not disable_gpt and self.text_refiner is not None:
62
- refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions)
63
  else:
64
  refined_caption = {'raw_caption': caption}
65
  out = {'generated_captions': refined_caption,
 
30
  self.text_refiner = None
31
  print('OpenAI GPT is not available')
32
 
33
+ def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False):
34
  # segment with prompt
35
  print("CA prompt: ", prompt, "CA controls",controls)
36
  seg_mask = self.segmenter.inference(image, prompt)[0, ...]
 
59
  if self.args.context_captions:
60
  context_captions.append(self.captioner.inference(image))
61
  if not disable_gpt and self.text_refiner is not None:
62
+ refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions, enable_wiki=enable_wiki)
63
  else:
64
  refined_caption = {'raw_caption': caption}
65
  out = {'generated_captions': refined_caption,
text_refiner/text_refiner.py CHANGED
@@ -39,7 +39,7 @@ class TextRefiner:
39
  print('prompt: ', input)
40
  return input
41
 
42
- def inference(self, query: str, controls: dict, context: list=[]):
43
  """
44
  query: the caption of the region of interest, generated by captioner
45
  controls: a dict of control singals, e.g., {"length": 5, "sentiment": "positive"}
@@ -58,15 +58,17 @@ class TextRefiner:
58
  response = self.llm(input)
59
  response = self.parse(response)
60
 
61
- tmp_configs = {"query": query}
62
- prompt_wiki = self.wiki_prompts.format(**tmp_configs)
63
- response_wiki = self.llm(prompt_wiki)
64
- response_wiki = self.parse2(response_wiki)
 
 
65
  out = {
66
  'raw_caption': query,
67
  'caption': response,
68
  'wiki': response_wiki
69
- }
70
  print(out)
71
  return out
72
 
 
39
  print('prompt: ', input)
40
  return input
41
 
42
+ def inference(self, query: str, controls: dict, context: list=[], enable_wiki=False):
43
  """
44
  query: the caption of the region of interest, generated by captioner
45
  controls: a dict of control singals, e.g., {"length": 5, "sentiment": "positive"}
 
58
  response = self.llm(input)
59
  response = self.parse(response)
60
 
61
+ response_wiki = ""
62
+ if enable_wiki:
63
+ tmp_configs = {"query": query}
64
+ prompt_wiki = self.wiki_prompts.format(**tmp_configs)
65
+ response_wiki = self.llm(prompt_wiki)
66
+ response_wiki = self.parse2(response_wiki)
67
  out = {
68
  'raw_caption': query,
69
  'caption': response,
70
  'wiki': response_wiki
71
+ }
72
  print(out)
73
  return out
74