Ubuntu commited on
Commit
c5556d8
1 Parent(s): fa25719
Files changed (1) hide show
  1. app.py +8 -1
app.py CHANGED
@@ -17,6 +17,7 @@ def infer(
17
  top_p=1.0,
18
  top_k=40,
19
  num_completions=1,
 
20
  seed=42,
21
  stop="\n"
22
  ):
@@ -28,6 +29,7 @@ def infer(
28
  temperature = float(temperature)
29
  top_p = float(top_p)
30
  top_k = int(top_k)
 
31
  stop = stop.split(";")
32
  seed = seed
33
 
@@ -36,6 +38,7 @@ def infer(
36
  assert 0.0 <= temperature <= 10.0
37
  assert 0.0 <= top_p <= 1.0
38
  assert 1 <= top_k <= 1000
 
39
 
40
  if temperature == 0.0:
41
  temperature = 0.01
@@ -48,6 +51,7 @@ def infer(
48
  "top_k": top_k,
49
  "temperature": temperature,
50
  "max_tokens": max_new_tokens,
 
51
  "stop": stop,
52
  }
53
  print(f"send: {datetime.now()}")
@@ -223,6 +227,7 @@ def main():
223
  if 'preset' not in st.session_state:
224
  st.session_state.preset = "Sentiment Analysis"
225
  st.session_state.top_k = "40"
 
226
  st.session_state.stop = r'\n'
227
  set_preset()
228
 
@@ -252,6 +257,7 @@ def main():
252
  top_p = st.text_input('top_p', st.session_state.top_p)
253
  # num_completions = st.text_input('num_completions (only the best one will be returend)', "1")
254
  num_completions = "1"
 
255
  stop = st.text_input('stop, split by;', st.session_state.stop)
256
  # seed = st.text_input('seed', "42")
257
  seed = "42"
@@ -275,7 +281,8 @@ def main():
275
  generated_area.markdown("<b>" + to_md(prompt) + "</b>", unsafe_allow_html=True)
276
  report_text = infer(
277
  prompt, model_name=model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k,
278
- num_completions=num_completions, seed=seed, stop=literal_eval("'''"+stop+"'''"),
 
279
  )
280
  generated_area.markdown("<b>" + to_md(prompt) + "</b><mark style='background-color: #cbeacd'>" + to_md(report_text)+"</mark>", unsafe_allow_html=True)
281
 
 
17
  top_p=1.0,
18
  top_k=40,
19
  num_completions=1,
20
+ repetition_penalty=1.0,
21
  seed=42,
22
  stop="\n"
23
  ):
 
29
  temperature = float(temperature)
30
  top_p = float(top_p)
31
  top_k = int(top_k)
32
+ repetition_penalty = float(repetition_penalty)
33
  stop = stop.split(";")
34
  seed = seed
35
 
 
38
  assert 0.0 <= temperature <= 10.0
39
  assert 0.0 <= top_p <= 1.0
40
  assert 1 <= top_k <= 1000
41
+ assert 0.9 <= repetition_penalty <= 3.0
42
 
43
  if temperature == 0.0:
44
  temperature = 0.01
 
51
  "top_k": top_k,
52
  "temperature": temperature,
53
  "max_tokens": max_new_tokens,
54
+ "repetition_penalty": repetition_penalty,
55
  "stop": stop,
56
  }
57
  print(f"send: {datetime.now()}")
 
227
  if 'preset' not in st.session_state:
228
  st.session_state.preset = "Sentiment Analysis"
229
  st.session_state.top_k = "40"
230
+ st.session_state.repetition_penalty = "1.0"
231
  st.session_state.stop = r'\n'
232
  set_preset()
233
 
 
257
  top_p = st.text_input('top_p', st.session_state.top_p)
258
  # num_completions = st.text_input('num_completions (only the best one will be returend)', "1")
259
  num_completions = "1"
260
+ repetition_penalty = st.text_input('repetition_penalty', st.session_state.repetition_penalty)
261
  stop = st.text_input('stop, split by;', st.session_state.stop)
262
  # seed = st.text_input('seed', "42")
263
  seed = "42"
 
281
  generated_area.markdown("<b>" + to_md(prompt) + "</b>", unsafe_allow_html=True)
282
  report_text = infer(
283
  prompt, model_name=model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k,
284
+ num_completions=num_completions, repetition_penalty=repetition_penalty,
285
+ seed=seed, stop=literal_eval("'''"+stop+"'''"),
286
  )
287
  generated_area.markdown("<b>" + to_md(prompt) + "</b><mark style='background-color: #cbeacd'>" + to_md(report_text)+"</mark>", unsafe_allow_html=True)
288