juewang commited on
Commit
b2703de
1 Parent(s): 7f5cbab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -10
app.py CHANGED
@@ -9,7 +9,8 @@ def infer(prompt,
9
  temperature=0.0,
10
  top_p=1.0,
11
  num_completions=1,
12
- seed=42,):
 
13
 
14
  model_name_map = {
15
  "GPT-JT-6B-v1": "Together-gpt-JT-6B-v1",
@@ -25,7 +26,7 @@ def infer(prompt,
25
  "model": model_name_map[model_name],
26
  "prompt": [prompt],
27
  "request_type": "language-model-inference",
28
- "stop": None,
29
  "best_of": 1,
30
  "echo": False,
31
  "seed": int(seed),
@@ -59,7 +60,8 @@ with col1:
59
  max_new_tokens = st.text_input('Max new tokens', "10")
60
  temperature = st.text_input('temperature', "0.0")
61
  top_p = st.text_input('top_p', "1.0")
62
- num_completions = st.text_input('num_completions', "1")
 
63
  seed = st.text_input('seed', "42")
64
 
65
  with col2:
@@ -77,10 +79,9 @@ with col2:
77
  button_submit = st.button("Submit")
78
 
79
  if button_submit:
80
- with st.spinner(text="In progress.."):
81
- generated_area.markdown("...")
82
- report_text = infer(
83
- prompt, model_name=model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p,
84
- num_completions=num_completions, seed=seed,
85
- )
86
- generated_area.markdown(report_text)
 
9
  temperature=0.0,
10
  top_p=1.0,
11
  num_completions=1,
12
+ seed=42,
13
+ stop="\n"):
14
 
15
  model_name_map = {
16
  "GPT-JT-6B-v1": "Together-gpt-JT-6B-v1",
 
26
  "model": model_name_map[model_name],
27
  "prompt": [prompt],
28
  "request_type": "language-model-inference",
29
+ "stop": stop.split(";"),
30
  "best_of": 1,
31
  "echo": False,
32
  "seed": int(seed),
 
60
  max_new_tokens = st.text_input('Max new tokens', "10")
61
  temperature = st.text_input('temperature', "0.0")
62
  top_p = st.text_input('top_p', "1.0")
63
+ num_completions = st.text_input('num_completions (only the best one will be returend)', "1")
64
+ stop = st.text_input('stop, split by;', "\n")
65
  seed = st.text_input('seed', "42")
66
 
67
  with col2:
 
79
  button_submit = st.button("Submit")
80
 
81
  if button_submit:
82
+ generated_area.markdown(prompt)
83
+ report_text = infer(
84
+ prompt, model_name=model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p,
85
+ num_completions=num_completions, seed=seed, stop=stop,
86
+ )
87
+ generated_area.markdown(prompt + "**" + report_text + "**")