DongfuJiang commited on
Commit
fcac78f
1 Parent(s): f33be86
Files changed (1) hide show
  1. app.py +53 -35
app.py CHANGED
@@ -81,7 +81,7 @@ def get_preprocess_examples(inst, input):
81
  def update_base_llm_dropdown_along_examples(dummy_text):
82
  candidates = CANDIDATE_EXAMPLES[dummy_text]
83
  ex_llm_outputs = {f"LLM-{i+1}": candidates[i]['text'] for i in range(len(candidates))}
84
- return ex_llm_outputs
85
 
86
  def check_save_ranker_inputs(inst, input, llm_outputs, blender_config):
87
  if not inst and not input:
@@ -97,7 +97,11 @@ def check_save_ranker_inputs(inst, input, llm_outputs, blender_config):
97
  }
98
 
99
  def check_fuser_inputs(blender_state, blender_config, ranks):
100
- pass
 
 
 
 
101
 
102
  def llms_rank(inst, input, llm_outputs, blender_config):
103
  candidates = list(llm_outputs.values())
@@ -115,9 +119,11 @@ def llms_fuse(blender_state, blender_config, ranks):
115
  candidates = blender_state['candidates']
116
  top_k_for_fuser = blender_config['top_k_for_fuser']
117
  fuse_params = blender_config.copy()
118
- del fuse_params["top_k_for_fuser"]
 
 
119
  top_k_candidates = get_topk_candidates_from_ranks([ranks], [candidates], top_k=top_k_for_fuser)[0]
120
- fuser_outputs = blender.fuse(instructions=[inst], inputs=[input], candidates=[top_k_candidates], **fuse_params)[0]
121
  return [fuser_outputs, fuser_outputs]
122
 
123
  def display_fuser_output(fuser_output):
@@ -157,7 +163,7 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
157
  saved_fuse_outputs = gr.State(value=[])
158
  gr.Markdown("## Blender Outputs")
159
  with gr.Group():
160
- rank_outputs = gr.Textbox(lines=1, label="Ranking outputs", placeholder="Ranking outputs", show_label=True)
161
  fuser_outputs = gr.Textbox(lines=4, label="Fusing outputs", placeholder="Fusing outputs", show_label=True)
162
  with gr.Row():
163
  rank_button = gr.Button('Rank LLM Outputs', variant='primary')
@@ -173,6 +179,13 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
173
  })
174
 
175
  with gr.Accordion(label='Advanced options', open=False):
 
 
 
 
 
 
 
176
  source_max_length = gr.Slider(
177
  label='Max length of Instruction + Input',
178
  minimum=1,
@@ -187,13 +200,6 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
187
  step=1,
188
  value=DEFAULT_CANDIDATE_MAX_LENGTH,
189
  )
190
- top_k_for_fuser = gr.Slider(
191
- label='Top-k ranked candidates to fuse',
192
- minimum=1,
193
- maximum=3,
194
- step=1,
195
- value=3,
196
- )
197
  max_new_tokens = gr.Slider(
198
  label='Max new tokens fuser can generate',
199
  minimum=1,
@@ -201,19 +207,26 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
201
  step=1,
202
  value=DEFAULT_FUSER_MAX_NEW_TOKENS,
203
  )
204
- temperature = gr.Slider(
205
- label='Temperature of fuser generation',
206
- minimum=0.1,
207
- maximum=2.0,
208
- step=0.1,
209
- value=0.7,
210
- )
211
- top_p = gr.Slider(
212
- label='Top-p of fuser generation',
213
- minimum=0.05,
214
- maximum=1.0,
215
- step=0.05,
216
- value=1.0,
 
 
 
 
 
 
 
217
  )
218
 
219
  examples_dummy_textbox = gr.Textbox(lines=1, label="", placeholder="", show_label=False, visible=False)
@@ -235,7 +248,7 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
235
  examples_dummy_textbox.change(
236
  fn=update_base_llm_dropdown_along_examples,
237
  inputs=[examples_dummy_textbox],
238
- outputs=saved_llm_outputs,
239
  ).then(
240
  fn=display_llm_output,
241
  inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
@@ -278,7 +291,7 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
278
  fuse_button.click(
279
  fn=check_fuser_inputs,
280
  inputs=[blender_state, blender_config, saved_rank_outputs],
281
- outputs=[],
282
  ).success(
283
  fn=llms_fuse,
284
  inputs=[blender_state, blender_config, saved_rank_outputs],
@@ -312,14 +325,19 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
312
  inputs=[max_new_tokens, blender_config],
313
  outputs=blender_config,
314
  )
315
- temperature.change(
316
- fn=lambda x, y: y.update({"temperature": x}) or y,
317
- inputs=[temperature, blender_config],
318
- outputs=blender_config,
319
- )
320
- top_p.change(
321
- fn=lambda x, y: y.update({"top_p": x}) or y,
322
- inputs=[top_p, blender_config],
 
 
 
 
 
323
  outputs=blender_config,
324
  )
325
 
 
81
  def update_base_llm_dropdown_along_examples(dummy_text):
82
  candidates = CANDIDATE_EXAMPLES[dummy_text]
83
  ex_llm_outputs = {f"LLM-{i+1}": candidates[i]['text'] for i in range(len(candidates))}
84
+ return ex_llm_outputs, "", ""
85
 
86
  def check_save_ranker_inputs(inst, input, llm_outputs, blender_config):
87
  if not inst and not input:
 
97
  }
98
 
99
  def check_fuser_inputs(blender_state, blender_config, ranks):
100
+ if not (blender_state.get("inst", None) or blender_state.get("input", None)):
101
+ raise gr.Error("Please enter instruction or input context")
102
+ if "candidates" not in blender_state or len(ranks)==0:
103
+ raise gr.Error("Please rank LLM outputs first")
104
+ return
105
 
106
  def llms_rank(inst, input, llm_outputs, blender_config):
107
  candidates = list(llm_outputs.values())
 
119
  candidates = blender_state['candidates']
120
  top_k_for_fuser = blender_config['top_k_for_fuser']
121
  fuse_params = blender_config.copy()
122
+ fuse_params.pop("top_k_for_fuser")
123
+ fuse_params.pop("source_max_length")
124
+ fuse_params['no_repeat_ngram_size'] = 3
125
  top_k_candidates = get_topk_candidates_from_ranks([ranks], [candidates], top_k=top_k_for_fuser)[0]
126
+ fuser_outputs = blender.fuse(instructions=[inst], inputs=[input], candidates=[top_k_candidates], **fuse_params, batch_size=1)[0]
127
  return [fuser_outputs, fuser_outputs]
128
 
129
  def display_fuser_output(fuser_output):
 
163
  saved_fuse_outputs = gr.State(value=[])
164
  gr.Markdown("## Blender Outputs")
165
  with gr.Group():
166
+ rank_outputs = gr.Textbox(lines=1, label="Ranks of each LLM's output", placeholder="Ranking outputs", show_label=True)
167
  fuser_outputs = gr.Textbox(lines=4, label="Fusing outputs", placeholder="Fusing outputs", show_label=True)
168
  with gr.Row():
169
  rank_button = gr.Button('Rank LLM Outputs', variant='primary')
 
179
  })
180
 
181
  with gr.Accordion(label='Advanced options', open=False):
182
+ top_k_for_fuser = gr.Slider(
183
+ label='Top-k ranked candidates to fuse',
184
+ minimum=1,
185
+ maximum=3,
186
+ step=1,
187
+ value=3,
188
+ )
189
  source_max_length = gr.Slider(
190
  label='Max length of Instruction + Input',
191
  minimum=1,
 
200
  step=1,
201
  value=DEFAULT_CANDIDATE_MAX_LENGTH,
202
  )
 
 
 
 
 
 
 
203
  max_new_tokens = gr.Slider(
204
  label='Max new tokens fuser can generate',
205
  minimum=1,
 
207
  step=1,
208
  value=DEFAULT_FUSER_MAX_NEW_TOKENS,
209
  )
210
+ # temperature = gr.Slider(
211
+ # label='Temperature of fuser generation',
212
+ # minimum=0.1,
213
+ # maximum=2.0,
214
+ # step=0.1,
215
+ # value=0.7,
216
+ # )
217
+ # top_p = gr.Slider(
218
+ # label='Top-p of fuser generation',
219
+ # minimum=0.05,
220
+ # maximum=1.0,
221
+ # step=0.05,
222
+ # value=1.0,
223
+ # )
224
+ beam_size = gr.Slider(
225
+ label='Beam size of fuser generation',
226
+ minimum=1,
227
+ maximum=10,
228
+ step=1,
229
+ value=4,
230
  )
231
 
232
  examples_dummy_textbox = gr.Textbox(lines=1, label="", placeholder="", show_label=False, visible=False)
 
248
  examples_dummy_textbox.change(
249
  fn=update_base_llm_dropdown_along_examples,
250
  inputs=[examples_dummy_textbox],
251
+ outputs=[saved_llm_outputs, rank_outputs, fuser_outputs],
252
  ).then(
253
  fn=display_llm_output,
254
  inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
 
291
  fuse_button.click(
292
  fn=check_fuser_inputs,
293
  inputs=[blender_state, blender_config, saved_rank_outputs],
294
+ outputs=fuser_outputs,
295
  ).success(
296
  fn=llms_fuse,
297
  inputs=[blender_state, blender_config, saved_rank_outputs],
 
325
  inputs=[max_new_tokens, blender_config],
326
  outputs=blender_config,
327
  )
328
+ # temperature.change(
329
+ # fn=lambda x, y: y.update({"temperature": x}) or y,
330
+ # inputs=[temperature, blender_config],
331
+ # outputs=blender_config,
332
+ # )
333
+ # top_p.change(
334
+ # fn=lambda x, y: y.update({"top_p": x}) or y,
335
+ # inputs=[top_p, blender_config],
336
+ # outputs=blender_config,
337
+ # )
338
+ beam_size.change(
339
+ fn=lambda x, y: y.update({"num_beams": x}) or y,
340
+ inputs=[beam_size, blender_config],
341
  outputs=blender_config,
342
  )
343