Younes Belkada commited on
Commit
4da5c39
1 Parent(s): ab3c911

Add num beams

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -35,11 +35,12 @@ def query(payload):
35
  print(response)
36
  return json.loads(response.content.decode("utf-8"))
37
 
38
- def inference(input_sentence, max_length, no_repeat_ngram_size, temperature,top_k, top_p, greedy_decoding, seed=42):
39
  top_k = None if top_k == 0 else top_k
 
40
  payload = {"inputs": input_sentence,
41
  "parameters": {"max_new_tokens": max_length, "top_k": top_k, "top_p": top_p, "temperature": temperature,
42
- "do_sample": not greedy_decoding, "seed": seed, "early_stopping":no_repeat_ngram_size > 1, "no_repeat_ngram_size":no_repeat_ngram_size}}
43
  data = query(
44
  payload
45
  )
@@ -53,6 +54,7 @@ gr.Interface(
53
  gr.inputs.Textbox(label="Input"),
54
  gr.inputs.Slider(1, 64, default=8, label="Tokens to generate"),
55
  gr.inputs.Slider(1, 10, default=2, step=1, label="No repeat N gram"),
 
56
  gr.inputs.Slider(0.0, 1.0, default=0.1, step=0.05, label="Temperature"),
57
  gr.inputs.Slider(0, 64, default=0, step=1, label="Top K"),
58
  gr.inputs.Slider(0.0, 10, default=0.9, step=0.05, label="Top P"),
 
35
  print(response)
36
  return json.loads(response.content.decode("utf-8"))
37
 
38
+ def inference(input_sentence, max_length, no_repeat_ngram_size, num_beams, temperature,top_k, top_p, greedy_decoding, seed=42):
39
  top_k = None if top_k == 0 else top_k
40
+ greedy = False if num_beams > 0
41
  payload = {"inputs": input_sentence,
42
  "parameters": {"max_new_tokens": max_length, "top_k": top_k, "top_p": top_p, "temperature": temperature,
43
+ "do_sample": not greedy_decoding, "seed": seed, "early_stopping":num_beams > 0, "no_repeat_ngram_size":no_repeat_ngram_size}}
44
  data = query(
45
  payload
46
  )
 
54
  gr.inputs.Textbox(label="Input"),
55
  gr.inputs.Slider(1, 64, default=8, label="Tokens to generate"),
56
  gr.inputs.Slider(1, 10, default=2, step=1, label="No repeat N gram"),
57
+ gr.inputs.Slider(0, 10, default=5, step=1, label="Num beams"),
58
  gr.inputs.Slider(0.0, 1.0, default=0.1, step=0.05, label="Temperature"),
59
  gr.inputs.Slider(0, 64, default=0, step=1, label="Top K"),
60
  gr.inputs.Slider(0.0, 10, default=0.9, step=0.05, label="Top P"),