juewang commited on
Commit
a531b86
1 Parent(s): 9e2c039

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -19
app.py CHANGED
@@ -44,26 +44,30 @@ def infer(prompt, model_name, max_new_tokens=10, temperature=0.0, top_p=1.0):
44
 
45
 
46
  st.title("TOMA Application")
47
-
48
- s_example = "Please answer the following question:\n\nQuestion: Where is Zurich?\nAnswer:"
49
- prompt = st.text_area(
50
- "Prompt",
51
- value=s_example,
52
- max_chars=4096,
53
- height=400,
54
- )
55
-
56
- generated_area = st.empty()
57
- generated_area.markdown("(Generate here)")
58
 
59
- button_submit = st.button("Submit")
60
-
61
- model_name = st.selectbox("Model", ["GPT-JT-6B-v1"])
62
- max_new_tokens = st.text_input('Max new tokens', "10")
63
- temperature = st.text_input('temperature', "0.0")
64
- top_p = st.text_input('top_p', "1.0")
65
 
66
- if button_submit:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  with st.spinner(text="In progress.."):
68
  report_text = infer(prompt, model_name=model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p)
69
- generated_area.markdown(report_text)
 
 
 
 
 
 
 
44
 
45
 
46
  st.title("TOMA Application")
 
 
 
 
 
 
 
 
 
 
 
47
 
48
+ col1, col2 = st.columns([1, 3])
 
 
 
 
 
49
 
50
+ with col2:
51
+ s_example = "Please answer the following question:\n\nQuestion: Where is Zurich?\nAnswer:"
52
+ prompt = st.text_area(
53
+ "Prompt",
54
+ value=s_example,
55
+ max_chars=4096,
56
+ height=400,
57
+ )
58
+
59
+ generated_area = st.empty()
60
+ generated_area.markdown("(Generate here)")
61
+
62
+ button_submit = st.button("Submit")
63
+
64
+ if button_submit:
65
  with st.spinner(text="In progress.."):
66
  report_text = infer(prompt, model_name=model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p)
67
+ generated_area.markdown(report_text)
68
+
69
+ with col1:
70
+ model_name = st.selectbox("Model", ["GPT-JT-6B-v1"])
71
+ max_new_tokens = st.text_input('Max new tokens', "10")
72
+ temperature = st.text_input('temperature', "0.0")
73
+ top_p = st.text_input('top_p', "1.0")