jykoh committed
Commit 12a8812
Parent: 2bf4a87

Add concurrency

Files changed (2)
  1. app.py +8 -6
  2. fromage/models.py +0 -2
app.py CHANGED
```diff
@@ -73,14 +73,14 @@ def generate_for_prompt(input_text, state, ret_scale_factor, max_nm_rets, num_wo
     elif type(output) == list:
       for image in output:
         filename = save_image_to_local(image)
-        response += f'<img src="/file={filename}">'
+        response += f'<br/><img src="/file={filename}">'
     elif type(output) == Image.Image:
-      filename = save_image_to_local(output)
-      response += f'<img src="/file={filename}">'
+      filename = save_image_to_local(output)
+      response += f'<br/><img src="/file={filename}">'
 
   # TODO(jykoh): Persist image inputs.
   chat_history = model_inputs + [' '.join([s for s in model_outputs if type(s) == str]) + '\n']
-  conversation.append((input_text, response))
+  conversation.append((input_text, response.replace('[RET]', '')))  # Remove [RET] from outputs.
 
   # Set input image to None.
   print('state', state, flush=True)
@@ -98,7 +98,7 @@ with gr.Blocks() as demo:
 
   with gr.Row():
     with gr.Column(scale=0.3, min_width=0):
-      ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True, label="Multiplier for returning images (higher means more frequent)")
+      ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.3, step=0.1, interactive=True, label="Multiplier for returning images (higher means more frequent)")
       max_ret_images = gr.Number(minimum=0, maximum=3, value=1, precision=1, interactive=True, label="Max images to return")
       gr_max_len = gr.Number(value=32, precision=1, label="Max # of words returned", interactive=True)
       gr_temperature = gr.Number(value=0.0, label="Temperature", interactive=True)
@@ -113,4 +113,6 @@ with gr.Blocks() as demo:
   image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
   clear_btn.click(reset, [], [gr_state, chatbot])
 
-demo.launch(share=False, debug=True, server_name="0.0.0.0")
+# demo.launch(share=False, debug=True, server_name="0.0.0.0")
+demo.queue(concurrency_count=5)
+demo.launch(debug=True)
```
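The substance of this commit is the switch from a bare `demo.launch(...)` to `demo.queue(concurrency_count=5)` followed by `launch()`. As a rough illustration of what that buys, here is a minimal standalone sketch (not the FROMAGe demo itself) of Gradio's request queue; it assumes the Gradio 3.x API of early 2023, where `queue()` accepts `concurrency_count`.

```python
# Minimal sketch, not the FROMAGe app: shows how enabling Gradio's queue lets
# several slow requests be served in parallel, mirroring the
# demo.queue(concurrency_count=5) call added in this commit (Gradio 3.x API).
import time

import gradio as gr


def slow_echo(message):
  time.sleep(2)  # Stand-in for a slow model call.
  return message


with gr.Blocks() as demo:
  box = gr.Textbox(label="Input")
  out = gr.Textbox(label="Output")
  box.submit(slow_echo, [box], [out])

# Without queue(), long-running requests can block one another or time out;
# with it, up to `concurrency_count` requests are processed concurrently.
demo.queue(concurrency_count=5)
demo.launch(debug=True)
```

Calling `queue()` before `launch()` is what actually routes requests through the worker pool; `launch()` alone would serve them without the queue.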
fromage/models.py CHANGED
```diff
@@ -634,8 +634,6 @@ def load_fromage(embeddings_dir: str, model_args_path: str, model_ckpt_path: str
   ret_token_idx = tokenizer('[RET]', add_special_tokens=False).input_ids
   assert len(ret_token_idx) == 1, ret_token_idx
   model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
-  # model_kwargs['opt_version'] = 'facebook/opt-125m'
-  # model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'
   args = namedtuple('args', model_kwargs)(**model_kwargs)
 
   # Initialize model for inference.
```
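This hunk only drops two commented-out debug overrides; the surrounding context looks up the id of the special `[RET]` token and passes it along as `retrieval_token_idx`. For reference, a minimal sketch of that lookup with a Hugging Face tokenizer, assuming `facebook/opt-125m` (the checkpoint named in the removed comment) and that `[RET]` is added to the vocabulary; it is an illustration, not the actual `load_fromage` code, which builds its tokenizer from the saved model args.

```python
# Sketch only: obtaining a single-token id for '[RET]' with a Hugging Face
# tokenizer. 'facebook/opt-125m' is the checkpoint named in the comment
# removed by this commit; the real load_fromage configures this elsewhere.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')
tokenizer.add_tokens(['[RET]'])  # Ensure '[RET]' maps to exactly one token.

ret_token_idx = tokenizer('[RET]', add_special_tokens=False).input_ids
assert len(ret_token_idx) == 1, ret_token_idx
print('retrieval_token_idx:', ret_token_idx[0])
```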