tsi-org committed on
Commit
544e7f1
1 Parent(s): 6421de5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -31
app.py CHANGED
@@ -128,47 +128,48 @@ def flag_last_response(state, model_selector, request: gr.Request):
128
  return ("",) + (disable_btn,) * 3
129
 
130
 
131
- def regenerate(state, image_process_mode1, image_process_mode2, request: gr.Request):
132
  logger.info(f"regenerate. ip: {request.client.host}")
133
  state.messages[-1][-1] = None
134
  prev_human_msg = state.messages[-2]
135
  if type(prev_human_msg[1]) in (tuple, list):
136
- prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode1, image_process_mode2)
137
  state.skip_next = False
138
- return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
139
 
140
 
141
  def clear_history(request: gr.Request):
142
  logger.info(f"clear_history. ip: {request.client.host}")
143
  state = default_conversation.copy()
144
- return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
145
 
146
 
147
- def add_text(state, text, image1, image2, image_process_mode1, image_process_mode2, request: gr.Request):
148
  logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
149
- if len(text) <= 0 and image1 is None and image2 is None:
150
  state.skip_next = True
151
- return (state, state.to_gradio_chatbot(), "", None, None) + (no_change_btn,) * 5
152
  if args.moderate:
153
  flagged = violates_moderation(text)
154
  if flagged:
155
  state.skip_next = True
156
- return (state, state.to_gradio_chatbot(), moderation_msg, None, None) + (
157
  no_change_btn,
158
  ) * 5
159
 
160
  text = text[:1536] # Hard cut-off
161
- if image1 is not None or image2 is not None:
162
  text = text[:1200] # Hard cut-off for images
163
  if "<image>" not in text:
 
164
  text = text + "\n<image>"
165
- text = (text, image1, image2, image_process_mode1, image_process_mode2)
166
  if len(state.get_images(return_pil=True)) > 0:
167
  state = default_conversation.copy()
168
  state.append_message(state.roles[0], text)
169
  state.append_message(state.roles[1], None)
170
  state.skip_next = False
171
- return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
172
 
173
 
174
  def http_bot(
@@ -179,10 +180,12 @@ def http_bot(
179
  model_name = model_selector
180
 
181
  if state.skip_next:
 
182
  yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
183
  return
184
 
185
  if len(state.messages) == state.offset + 2:
 
186
  if "llava" in model_name.lower():
187
  if "llama-2" in model_name.lower():
188
  template_name = "llava_llama_2"
@@ -219,6 +222,7 @@ def http_bot(
219
  new_state.append_message(new_state.roles[1], None)
220
  state = new_state
221
 
 
222
  controller_url = args.controller_url
223
  ret = requests.post(
224
  controller_url + "/get_worker_address", json={"model": model_name}
@@ -226,6 +230,7 @@ def http_bot(
226
  worker_addr = ret.json()["address"]
227
  logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
228
 
 
229
  if worker_addr == "":
230
  state.messages[-1][-1] = server_error_msg
231
  yield (
@@ -239,6 +244,7 @@ def http_bot(
239
  )
240
  return
241
 
 
242
  prompt = state.get_prompt()
243
 
244
  all_images = state.get_images(return_pil=True)
@@ -252,6 +258,7 @@ def http_bot(
252
  os.makedirs(os.path.dirname(filename), exist_ok=True)
253
  image.save(filename)
254
 
 
255
  pload = {
256
  "model": model_name,
257
  "prompt": prompt,
@@ -271,6 +278,7 @@ def http_bot(
271
  yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
272
 
273
  try:
 
274
  response = requests.post(
275
  worker_addr + "/worker_generate_stream",
276
  headers=headers,
@@ -331,13 +339,17 @@ def http_bot(
331
  title_markdown = """
332
  # 🌋 AI Tutor Vision: Large Language and Vision Assistant
333
  [[website]](https://myapps.ai) [[Paper]](https://arxiv.org/abs/2304.08485) [[Model]](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)
 
334
  ONLY WORKS WITH GPU!
 
335
  You can load the model with 4-bit or 8-bit quantization to make it fit in smaller hardwares. Setting the environment variable `bits` to control the quantization.
336
  *Note: 8-bit seems to be slower than both 4-bit/16-bit. Although it has enough VRAM to support 8-bit, until we figure out the inference speed issue, we recommend 4-bit for A10G for the best efficiency.*
 
337
  Recommended configurations:
338
  | Hardware | T4-Small (16G) | A10G-Small (24G) | A100-Large (40G) |
339
  |-------------------|-----------------|------------------|------------------|
340
  | **Bits** | 4 (default) | 4 | 16 |
 
341
  """
342
 
343
  tos_markdown = """
@@ -355,9 +367,11 @@ The service is a research preview intended for non-commercial use only, subject
355
  """
356
 
357
  block_css = """
 
358
  #buttons button {
359
  min-width: min(120px,100%);
360
  }
 
361
  """
362
 
363
 
@@ -384,15 +398,8 @@ def build_demo(embed_mode):
384
  container=False,
385
  )
386
 
387
- imagebox1 = gr.Image(type="pil")
388
- imagebox2 = gr.Image(type="pil")
389
- image_process_mode1 = gr.Radio(
390
- ["Crop", "Resize", "Pad", "Default"],
391
- value="Default",
392
- label="Preprocess for non-square image",
393
- visible=False,
394
- )
395
- image_process_mode2 = gr.Radio(
396
  ["Crop", "Resize", "Pad", "Default"],
397
  value="Default",
398
  label="Preprocess for non-square image",
@@ -411,7 +418,7 @@ def build_demo(embed_mode):
411
  "What are the things I should be cautious about when I visit here?",
412
  ],
413
  ],
414
- inputs=[imagebox1, textbox, imagebox2],
415
  )
416
 
417
  with gr.Accordion("Parameters", open=False) as parameter_row:
@@ -442,18 +449,20 @@ def build_demo(embed_mode):
442
 
443
  with gr.Column(scale=8):
444
  chatbot = gr.Chatbot(
445
- elem_id="chatbot", label="AI Tutor Vision Chatbot", height=550
446
  )
447
  with gr.Row():
448
  with gr.Column(scale=8):
449
  textbox.render()
450
  with gr.Column(scale=1, min_width=50):
451
  submit_btn = gr.Button(
452
- value="Send", variant="primary", interactive=False)
 
453
  with gr.Row(elem_id="buttons") as button_row:
454
  upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
455
  downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
456
  flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
 
457
  regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
458
  clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
459
 
@@ -462,6 +471,7 @@ def build_demo(embed_mode):
462
  gr.Markdown(learn_more_markdown)
463
  url_params = gr.JSON(visible=False)
464
 
 
465
  btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
466
  upvote_btn.click(
467
  upvote_last_response,
@@ -480,21 +490,21 @@ def build_demo(embed_mode):
480
  )
481
  regenerate_btn.click(
482
  regenerate,
483
- [state, image_process_mode1, image_process_mode2],
484
- [state, chatbot, textbox, imagebox1, imagebox2] + btn_list,
485
  ).then(
486
  http_bot,
487
  [state, model_selector, temperature, top_p, max_output_tokens],
488
  [state, chatbot] + btn_list,
489
  )
490
  clear_btn.click(
491
- clear_history, None, [state, chatbot, textbox, imagebox1, imagebox2] + btn_list
492
  )
493
 
494
  textbox.submit(
495
  add_text,
496
- [state, textbox, imagebox1, imagebox2, image_process_mode1, image_process_mode2],
497
- [state, chatbot, textbox, imagebox1, imagebox2] + btn_list,
498
  ).then(
499
  http_bot,
500
  [state, model_selector, temperature, top_p, max_output_tokens],
@@ -502,8 +512,8 @@ def build_demo(embed_mode):
502
  )
503
  submit_btn.click(
504
  add_text,
505
- [state, textbox, imagebox1, imagebox2, image_process_mode1, image_process_mode2],
506
- [state, chatbot, textbox, imagebox1, imagebox2] + btn_list,
507
  ).then(
508
  http_bot,
509
  [state, model_selector, temperature, top_p, max_output_tokens],
@@ -600,6 +610,7 @@ if __name__ == "__main__":
600
  controller_proc = start_controller()
601
  worker_proc = start_worker(model_path, bits=bits)
602
 
 
603
  time.sleep(10)
604
 
605
  exit_status = 0
@@ -612,4 +623,4 @@ if __name__ == "__main__":
612
  worker_proc.kill()
613
  controller_proc.kill()
614
 
615
- sys.exit(exit_status)
 
128
  return ("",) + (disable_btn,) * 3
129
 
130
 
131
def regenerate(state, image_process_mode, request: gr.Request):
    """Discard the bot's last reply and requeue the previous user turn.

    Returns the updated conversation state, the refreshed chatbot rendering,
    a cleared textbox, a cleared imagebox, and five disabled buttons.
    """
    logger.info(f"regenerate. ip: {request.client.host}")
    # Remove the assistant's last answer; http_bot will generate a new one.
    state.messages[-1][-1] = None
    prev_human_msg = state.messages[-2]
    # isinstance (not type() equality) so tuple/list subclasses are handled too.
    if isinstance(prev_human_msg[1], (tuple, list)):
        # Keep (text, image) and refresh the stored preprocess mode.
        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
139
 
140
 
141
def clear_history(request: gr.Request):
    """Reset the conversation and clear the chatbot, textbox, and imagebox."""
    logger.info(f"clear_history. ip: {request.client.host}")
    fresh_state = default_conversation.copy()
    return (fresh_state, fresh_state.to_gradio_chatbot(), "", None) + (
        disable_btn,
    ) * 5
145
 
146
 
147
def add_text(state, text, image, image_process_mode, request: gr.Request):
    """Append a user turn (text plus optional image) to the conversation state.

    Returns the state, chatbot rendering, textbox value, imagebox value, and
    five button-interactivity updates.
    """
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    # Nothing to submit: leave the UI untouched and skip the bot round-trip.
    if len(text) <= 0 and image is None:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
    # Flagged input: surface the moderation notice instead of forwarding it.
    if args.moderate and violates_moderation(text):
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
            no_change_btn,
        ) * 5

    text = text[:1536]  # Hard cut-off
    if image is not None:
        text = text[:1200]  # Hard cut-off for images
        if "<image>" not in text:
            text = text + "\n<image>"
        # Bundle the payload the conversation template expects.
        text = (text, image, image_process_mode)
        # A new image restarts the conversation — presumably so only a single
        # image is in context; confirm against the model's multi-image support.
        if len(state.get_images(return_pil=True)) > 0:
            state = default_conversation.copy()
    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
173
 
174
 
175
  def http_bot(
 
180
  model_name = model_selector
181
 
182
  if state.skip_next:
183
+ # This generate call is skipped due to invalid inputs
184
  yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
185
  return
186
 
187
  if len(state.messages) == state.offset + 2:
188
+ # First round of conversation
189
  if "llava" in model_name.lower():
190
  if "llama-2" in model_name.lower():
191
  template_name = "llava_llama_2"
 
222
  new_state.append_message(new_state.roles[1], None)
223
  state = new_state
224
 
225
+ # Query worker address
226
  controller_url = args.controller_url
227
  ret = requests.post(
228
  controller_url + "/get_worker_address", json={"model": model_name}
 
230
  worker_addr = ret.json()["address"]
231
  logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
232
 
233
+ # No available worker
234
  if worker_addr == "":
235
  state.messages[-1][-1] = server_error_msg
236
  yield (
 
244
  )
245
  return
246
 
247
+ # Construct prompt
248
  prompt = state.get_prompt()
249
 
250
  all_images = state.get_images(return_pil=True)
 
258
  os.makedirs(os.path.dirname(filename), exist_ok=True)
259
  image.save(filename)
260
 
261
+ # Make requests
262
  pload = {
263
  "model": model_name,
264
  "prompt": prompt,
 
278
  yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
279
 
280
  try:
281
+ # Stream output
282
  response = requests.post(
283
  worker_addr + "/worker_generate_stream",
284
  headers=headers,
 
339
  title_markdown = """
340
  # 🌋 AI Tutor Vision: Large Language and Vision Assistant
341
  [[website]](https://myapps.ai) [[Paper]](https://arxiv.org/abs/2304.08485) [[Model]](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)
342
+
343
  ONLY WORKS WITH GPU!
344
+
345
  You can load the model with 4-bit or 8-bit quantization to make it fit in smaller hardwares. Setting the environment variable `bits` to control the quantization.
346
  *Note: 8-bit seems to be slower than both 4-bit/16-bit. Although it has enough VRAM to support 8-bit, until we figure out the inference speed issue, we recommend 4-bit for A10G for the best efficiency.*
347
+
348
  Recommended configurations:
349
  | Hardware | T4-Small (16G) | A10G-Small (24G) | A100-Large (40G) |
350
  |-------------------|-----------------|------------------|------------------|
351
  | **Bits** | 4 (default) | 4 | 16 |
352
+
353
  """
354
 
355
  tos_markdown = """
 
367
  """
368
 
369
  block_css = """
370
+
371
  #buttons button {
372
  min-width: min(120px,100%);
373
  }
374
+
375
  """
376
 
377
 
 
398
  container=False,
399
  )
400
 
401
+ imagebox = gr.Image(type="pil")
402
+ image_process_mode = gr.Radio(
 
 
 
 
 
 
 
403
  ["Crop", "Resize", "Pad", "Default"],
404
  value="Default",
405
  label="Preprocess for non-square image",
 
418
  "What are the things I should be cautious about when I visit here?",
419
  ],
420
  ],
421
+ inputs=[imagebox, textbox],
422
  )
423
 
424
  with gr.Accordion("Parameters", open=False) as parameter_row:
 
449
 
450
  with gr.Column(scale=8):
451
  chatbot = gr.Chatbot(
452
+ elem_id="chatbot", label="AI Tutor Vision", height=550
453
  )
454
  with gr.Row():
455
  with gr.Column(scale=8):
456
  textbox.render()
457
  with gr.Column(scale=1, min_width=50):
458
  submit_btn = gr.Button(
459
+ value="Send", variant="primary", interactive=False
460
+ )
461
  with gr.Row(elem_id="buttons") as button_row:
462
  upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
463
  downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
464
  flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
465
+ # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
466
  regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
467
  clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
468
 
 
471
  gr.Markdown(learn_more_markdown)
472
  url_params = gr.JSON(visible=False)
473
 
474
+ # Register listeners
475
  btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
476
  upvote_btn.click(
477
  upvote_last_response,
 
490
  )
491
  regenerate_btn.click(
492
  regenerate,
493
+ [state, image_process_mode],
494
+ [state, chatbot, textbox, imagebox] + btn_list,
495
  ).then(
496
  http_bot,
497
  [state, model_selector, temperature, top_p, max_output_tokens],
498
  [state, chatbot] + btn_list,
499
  )
500
  clear_btn.click(
501
+ clear_history, None, [state, chatbot, textbox, imagebox] + btn_list
502
  )
503
 
504
  textbox.submit(
505
  add_text,
506
+ [state, textbox, imagebox, image_process_mode],
507
+ [state, chatbot, textbox, imagebox] + btn_list,
508
  ).then(
509
  http_bot,
510
  [state, model_selector, temperature, top_p, max_output_tokens],
 
512
  )
513
  submit_btn.click(
514
  add_text,
515
+ [state, textbox, imagebox, image_process_mode],
516
+ [state, chatbot, textbox, imagebox] + btn_list,
517
  ).then(
518
  http_bot,
519
  [state, model_selector, temperature, top_p, max_output_tokens],
 
610
  controller_proc = start_controller()
611
  worker_proc = start_worker(model_path, bits=bits)
612
 
613
+ # Wait for worker and controller to start
614
  time.sleep(10)
615
 
616
  exit_status = 0
 
623
  worker_proc.kill()
624
  controller_proc.kill()
625
 
626
+ sys.exit(exit_status)