LukasHug commited on
Commit
e0ca52a
1 Parent(s): 7771cfc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -13
app.py CHANGED
@@ -41,7 +41,7 @@ priority = {
41
 
42
 
43
  @spaces.GPU
44
- def run_llava(prompt, pil_image):
45
  image_size = pil_image.size
46
  image_tensor = image_processor.preprocess(pil_image, return_tensors='pt')['pixel_values'].half().cuda()
47
  # images_tensor = load_images(images, image_processor)
@@ -54,11 +54,11 @@ def run_llava(prompt, pil_image):
54
  images=image_tensor,
55
  image_sizes=[image_size],
56
  do_sample=True,
57
- temperature=0.2,
58
- top_p=0.95,
59
  top_k=50,
60
  num_beams=2,
61
- max_new_tokens=1024,
62
  use_cache=True,
63
  stopping_criteria=[KeywordsStoppingCriteria(['}'], tokenizer, input_ids)]
64
  )
@@ -84,11 +84,12 @@ def get_conv_log_filename():
84
  name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
85
  return name
86
 
 
87
  def get_model_list():
88
  models = [
89
- 'LukasHug/LlavaGuard-7B-hf',
90
- 'LukasHug/LlavaGuard-13B-hf',
91
- 'LukasHug/LlavaGuard-34B-hf', ]
92
  return models
93
 
94
 
@@ -249,7 +250,6 @@ def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request
249
  new_state.append_message(new_state.roles[1], None)
250
  state = new_state
251
 
252
-
253
  # Construct prompt
254
  prompt = state.get_prompt()
255
 
@@ -262,13 +262,12 @@ def llava_bot(state, model_selector, temperature, top_p, max_new_tokens, request
262
  os.makedirs(os.path.dirname(filename), exist_ok=True)
263
  image.save(filename)
264
 
265
- output = run_llava(prompt, all_images[0])
266
 
267
  state.messages[-1][-1] = output
268
 
269
  yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
270
 
271
-
272
  finish_tstamp = time.time()
273
  logger.info(f"{output}")
274
 
@@ -406,7 +405,10 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
406
  [textbox, upvote_btn, downvote_btn, flag_btn]
407
  )
408
 
409
- model_selector.change(load_selected_model)
 
 
 
410
 
411
  regenerate_btn.click(
412
  regenerate,
@@ -517,7 +519,6 @@ Set the environment variable `model` to change the model:
517
  tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, token=api_key)
518
  model.config.tokenizer_model_max_length = 2048 * 2
519
 
520
-
521
  exit_status = 0
522
  try:
523
  demo = build_demo(embed_mode=False, cur_dir='./', concurrency_count=concurrency_count)
@@ -534,4 +535,4 @@ Set the environment variable `model` to change the model:
534
  print(e)
535
  exit_status = 1
536
  finally:
537
- sys.exit(exit_status)
 
41
 
42
 
43
  @spaces.GPU
44
+ def run_llava(prompt, pil_image, temperature, top_p, max_new_tokens):
45
  image_size = pil_image.size
46
  image_tensor = image_processor.preprocess(pil_image, return_tensors='pt')['pixel_values'].half().cuda()
47
  # images_tensor = load_images(images, image_processor)
 
54
  images=image_tensor,
55
  image_sizes=[image_size],
56
  do_sample=True,
57
+ temperature=temperature,
58
+ top_p=top_p,
59
  top_k=50,
60
  num_beams=2,
61
+ max_new_tokens=max_new_tokens,
62
  use_cache=True,
63
  stopping_criteria=[KeywordsStoppingCriteria(['}'], tokenizer, input_ids)]
64
  )
 
84
  name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
85
  return name
86
 
87
+
88
  def get_model_list():
89
  models = [
90
+ 'LukasHug/LlavaGuard-7B-hf',
91
+ 'LukasHug/LlavaGuard-13B-hf',
92
+ 'LukasHug/LlavaGuard-34B-hf', ][:2]
93
  return models
94
 
95
 
 
250
  new_state.append_message(new_state.roles[1], None)
251
  state = new_state
252
 
 
253
  # Construct prompt
254
  prompt = state.get_prompt()
255
 
 
262
  os.makedirs(os.path.dirname(filename), exist_ok=True)
263
  image.save(filename)
264
 
265
+ output = run_llava(prompt, all_images[0], temperature, top_p, max_new_tokens)
266
 
267
  state.messages[-1][-1] = output
268
 
269
  yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
270
 
 
271
  finish_tstamp = time.time()
272
  logger.info(f"{output}")
273
 
 
405
  [textbox, upvote_btn, downvote_btn, flag_btn]
406
  )
407
 
408
+ model_selector.change(
409
+ load_selected_model,
410
+ [model_selector],
411
+ )
412
 
413
  regenerate_btn.click(
414
  regenerate,
 
519
  tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, token=api_key)
520
  model.config.tokenizer_model_max_length = 2048 * 2
521
 
 
522
  exit_status = 0
523
  try:
524
  demo = build_demo(embed_mode=False, cur_dir='./', concurrency_count=concurrency_count)
 
535
  print(e)
536
  exit_status = 1
537
  finally:
538
+ sys.exit(exit_status)