Ubuntu commited on
Commit
5fcdcad
2 Parent(s): 7e6795e f79758c
Files changed (2) hide show
  1. app.py +49 -48
  2. requirements.txt +1 -0
app.py CHANGED
@@ -3,7 +3,8 @@ import requests
3
  import time
4
  from ast import literal_eval
5
 
6
-
 
7
  def infer(
8
  prompt,
9
  model_name,
@@ -15,7 +16,7 @@ def infer(
15
  seed=42,
16
  stop="\n"
17
  ):
18
-
19
  model_name_map = {
20
  "GPT-JT-6B-v1": "Together-gpt-JT-6B-v1",
21
  }
@@ -33,39 +34,34 @@ def infer(
33
  assert 0.0 <= top_p <= 1.0
34
 
35
  if temperature == 0.0:
36
- temperature = 1.0
37
- top_k = 1
38
-
39
- result = await st.session_state.together_web3.language_model_inference(
40
- from_dict(
41
- data_class=LanguageModelInferenceRequest,
42
- data={
43
- "model": model_name_map[model_name],
44
- "max_tokens": max_new_tokens,
45
- "prompt": prompt,
46
- "n": num_completions,
47
- "temperature": temperature,
48
- "top_k": top_k,
49
- "top_p": top_p,
50
- "stop": stop,
51
- "seed": seed,
52
- "echo": False,
53
- }
54
- ),
55
- )
56
-
57
- generated_text = result.choices[0].text
58
 
59
  for stop_word in stop:
60
- if stop_word in result:
61
  generated_text = generated_text[:generated_text.find(stop_word)]
 
 
62
 
63
  return generated_text
64
 
 
65
  def set_preset():
66
  if st.session_state.preset == "Classification":
67
 
68
- st.session_state.prompt = '''Please classify the given sentence.
 
69
  Possible labels:
70
  1. <label_0>
71
  2. <label_1>
@@ -77,15 +73,18 @@ Input: <sentence_1>
77
  Label:'''
78
  st.session_state.temperature = "0.0"
79
  st.session_state.top_p = "1.0"
 
80
 
81
  elif st.session_state.preset == "Generation":
82
 
83
- st.session_state.prompt = '''Please write a story given keywords.
 
84
 
85
  Input: bear, honey
86
- Story:'''
87
- st.session_state.temperature = "1.0"
88
- st.session_state.top_p = "0.5"
 
89
 
90
  else:
91
  pass
@@ -97,10 +96,10 @@ def main():
97
  st.session_state.preset = "Classification"
98
 
99
  if 'prompt' not in st.session_state:
100
- st.session_state.prompt = "Please answer the following question:\n\nQuestion: Where is Zurich?\nAnswer:"
101
 
102
  if 'temperature' not in st.session_state:
103
- st.session_state.temperature = "0.0"
104
 
105
  if 'top_p' not in st.session_state:
106
  st.session_state.top_p = "1.0"
@@ -108,8 +107,11 @@ def main():
108
  if 'top_k' not in st.session_state:
109
  st.session_state.top_k = "40"
110
 
111
- if 'together_web3' not in st.session_state:
112
- st.session_state.together_web3 = TogetherWeb3()
 
 
 
113
 
114
 
115
  st.title("GPT-JT")
@@ -118,7 +120,7 @@ def main():
118
 
119
  with col1:
120
  model_name = st.selectbox("Model", ["GPT-JT-6B-v1"])
121
- max_new_tokens = st.text_input('Max new tokens', "10")
122
  temperature = st.text_input('temperature', st.session_state.temperature)
123
  top_k = st.text_input('top_k', st.session_state.top_k)
124
  top_p = st.text_input('top_p', st.session_state.top_p)
@@ -130,19 +132,20 @@ def main():
130
 
131
  with col2:
132
 
133
- preset = st.radio(
134
- "Recommended Configurations",
135
- ('Classification', 'Generation'),
136
- on_change=set_preset,
137
- key="preset",
138
- horizontal=True
139
- )
140
 
141
- prompt = st.text_area(
 
142
  "Prompt",
143
  value=st.session_state.prompt,
144
  max_chars=4096,
145
- height=400,
146
  )
147
 
148
  generated_area = st.empty()
@@ -153,12 +156,10 @@ def main():
153
  if button_submit:
154
  generated_area.text(prompt)
155
  report_text = infer(
156
- prompt, model_name=model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k,
157
  num_completions=num_completions, seed=seed, stop=literal_eval("'''"+stop+"'''"),
158
  )
159
  generated_area.text(prompt + report_text)
160
-
161
-
162
 
163
  if __name__ == '__main__':
164
- main()
 
3
  import time
4
  from ast import literal_eval
5
 
6
+
7
+ @st.cache
8
  def infer(
9
  prompt,
10
  model_name,
 
16
  seed=42,
17
  stop="\n"
18
  ):
19
+
20
  model_name_map = {
21
  "GPT-JT-6B-v1": "Together-gpt-JT-6B-v1",
22
  }
 
34
  assert 0.0 <= top_p <= 1.0
35
 
36
  if temperature == 0.0:
37
+ temperature = 0.01
38
+
39
+ my_post_dict = {
40
+ "model": "Together-gpt-JT-6B-v1",
41
+ "prompt": prompt,
42
+ "top_p": top_p,
43
+ "top_k": top_k,
44
+ "temperature": temperature,
45
+ "max_tokens": max_new_tokens,
46
+ "stop": stop,
47
+ }
48
+ response = requests.get("https://staging.together.xyz/api/inference", params=my_post_dict).json()
49
+ generated_text = response['output']['choices'][0]['text']
 
 
 
 
 
 
 
 
 
50
 
51
  for stop_word in stop:
52
+ if stop_word in generated_text:
53
  generated_text = generated_text[:generated_text.find(stop_word)]
54
+
55
+ st.session_state.updated = True
56
 
57
  return generated_text
58
 
59
+
60
  def set_preset():
61
  if st.session_state.preset == "Classification":
62
 
63
+ if not st.session_state.updated:
64
+ st.session_state.prompt = '''Please classify the given sentence.
65
  Possible labels:
66
  1. <label_0>
67
  2. <label_1>
 
73
  Label:'''
74
  st.session_state.temperature = "0.0"
75
  st.session_state.top_p = "1.0"
76
+ st.session_state.max_new_tokens = "10"
77
 
78
  elif st.session_state.preset == "Generation":
79
 
80
+ if not st.session_state.updated:
81
+ st.session_state.prompt = '''Please write a story given keywords.
82
 
83
  Input: bear, honey
84
+ Story: Once upon a time,'''
85
+ st.session_state.temperature = "0.8"
86
+ st.session_state.top_p = "0.9"
87
+ st.session_state.max_new_tokens = "100"
88
 
89
  else:
90
  pass
 
96
  st.session_state.preset = "Classification"
97
 
98
  if 'prompt' not in st.session_state:
99
+ st.session_state.prompt = "Please answer the following question:\n\nQuestion: In which country is Zurich located?\nAnswer:"
100
 
101
  if 'temperature' not in st.session_state:
102
+ st.session_state.temperature = "0.8"
103
 
104
  if 'top_p' not in st.session_state:
105
  st.session_state.top_p = "1.0"
 
107
  if 'top_k' not in st.session_state:
108
  st.session_state.top_k = "40"
109
 
110
+ if 'max_new_tokens' not in st.session_state:
111
+ st.session_state.max_new_tokens = "10"
112
+
113
+ if 'updated' not in st.session_state:
114
+ st.session_state.updated = False
115
 
116
 
117
  st.title("GPT-JT")
 
120
 
121
  with col1:
122
  model_name = st.selectbox("Model", ["GPT-JT-6B-v1"])
123
+ max_new_tokens = st.text_input('Max new tokens', st.session_state.max_new_tokens)
124
  temperature = st.text_input('temperature', st.session_state.temperature)
125
  top_k = st.text_input('top_k', st.session_state.top_k)
126
  top_p = st.text_input('top_p', st.session_state.top_p)
 
132
 
133
  with col2:
134
 
135
+ # preset = st.radio(
136
+ # "Recommended Templates",
137
+ # ('Classification', 'Generation'),
138
+ # on_change=set_preset,
139
+ # key="preset",
140
+ # horizontal=True
141
+ # )
142
 
143
+ prompt_area = st.empty()
144
+ prompt = prompt_area.text_area(
145
  "Prompt",
146
  value=st.session_state.prompt,
147
  max_chars=4096,
148
+ height=300,
149
  )
150
 
151
  generated_area = st.empty()
 
156
  if button_submit:
157
  generated_area.text(prompt)
158
  report_text = infer(
159
+ st.session_state.prompt, model_name=model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k,
160
  num_completions=num_completions, seed=seed, stop=literal_eval("'''"+stop+"'''"),
161
  )
162
  generated_area.text(prompt + report_text)
 
 
163
 
164
  if __name__ == '__main__':
165
+ main()
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ together_web3