Ubuntu commited on
Commit
7e6795e
1 Parent(s): 87df952
Files changed (1) hide show
  1. app.py +144 -68
app.py CHANGED
@@ -3,86 +3,162 @@ import requests
3
  import time
4
  from ast import literal_eval
5
 
6
- @st.cache
7
- def infer(prompt,
8
- model_name,
9
- max_new_tokens=10,
10
- temperature=0.0,
11
- top_p=1.0,
12
- num_completions=1,
13
- seed=42,
14
- stop="\n"):
 
 
 
15
 
16
  model_name_map = {
17
  "GPT-JT-6B-v1": "Together-gpt-JT-6B-v1",
18
  }
19
-
20
- my_post_dict = {
21
- "type": "general",
22
- "payload": {
23
- "max_tokens": int(max_new_tokens),
24
- "n": int(num_completions),
25
- "temperature": float(temperature),
26
- "top_p": float(top_p),
27
- "model": model_name_map[model_name],
28
- "prompt": [prompt],
29
- "request_type": "language-model-inference",
30
- "stop": stop.split(";"),
31
- "best_of": 1,
32
- "echo": False,
33
- "seed": int(seed),
34
- "prompt_embedding": False,
35
- },
36
- "returned_payload": {},
37
- "status": "submitted",
38
- "source": "dalle",
39
- }
40
 
41
- job_id = requests.post("https://planetd.shift.ml/jobs", json=my_post_dict).json()['id']
 
 
 
 
 
42
 
43
- for i in range(100):
 
 
 
44
 
45
- time.sleep(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- ret = requests.get(f"https://planetd.shift.ml/job/{job_id}", json={'id': job_id}).json()
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- if ret['status'] == 'finished':
50
- break
51
 
52
- return ret['returned_payload']['result']['inference_result'][0]['choices'][0]['text']
 
 
 
 
 
53
 
 
 
54
 
55
- st.title("GPT-JT")
56
-
57
- col1, col2 = st.columns([1, 3])
58
-
59
- with col1:
60
- model_name = st.selectbox("Model", ["GPT-JT-6B-v1"])
61
- max_new_tokens = st.text_input('Max new tokens', "10")
62
- temperature = st.text_input('temperature', "0.0")
63
- top_p = st.text_input('top_p', "1.0")
64
- num_completions = st.text_input('num_completions (only the best one will be returend)', "1")
65
- stop = st.text_input('stop, split by;', r'\n')
66
- seed = st.text_input('seed', "42")
67
-
68
- with col2:
69
- s_example = "Please answer the following question:\n\nQuestion: Where is Zurich?\nAnswer:"
70
- prompt = st.text_area(
71
- "Prompt",
72
- value=s_example,
73
- max_chars=4096,
74
- height=400,
75
- )
76
-
77
- generated_area = st.empty()
78
- generated_area.text("(Generate here)")
79
 
80
- button_submit = st.button("Submit")
 
 
 
 
 
 
 
 
 
81
 
82
- if button_submit:
83
- generated_area.text(prompt)
84
- report_text = infer(
85
- prompt, model_name=model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p,
86
- num_completions=num_completions, seed=seed, stop=literal_eval("'''"+stop+"'''"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  )
88
- generated_area.text(prompt + report_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import time
4
  from ast import literal_eval
5
 
6
+
7
+ def infer(
8
+ prompt,
9
+ model_name,
10
+ max_new_tokens=10,
11
+ temperature=0.0,
12
+ top_p=1.0,
13
+ top_k=40,
14
+ num_completions=1,
15
+ seed=42,
16
+ stop="\n"
17
+ ):
18
 
19
  model_name_map = {
20
  "GPT-JT-6B-v1": "Together-gpt-JT-6B-v1",
21
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ max_new_tokens = int(max_new_tokens)
24
+ num_completions = int(num_completions)
25
+ temperature = float(temperature)
26
+ top_p = float(top_p)
27
+ stop = stop.split(";")
28
+ seed = seed
29
 
30
+ assert 0 <= max_new_tokens <= 256
31
+ assert 1 <= num_completions <= 5
32
+ assert 0.0 <= temperature <= 10.0
33
+ assert 0.0 <= top_p <= 1.0
34
 
35
+ if temperature == 0.0:
36
+ temperature = 1.0
37
+ top_k = 1
38
+
39
+ result = await st.session_state.together_web3.language_model_inference(
40
+ from_dict(
41
+ data_class=LanguageModelInferenceRequest,
42
+ data={
43
+ "model": model_name_map[model_name],
44
+ "max_tokens": max_new_tokens,
45
+ "prompt": prompt,
46
+ "n": num_completions,
47
+ "temperature": temperature,
48
+ "top_k": top_k,
49
+ "top_p": top_p,
50
+ "stop": stop,
51
+ "seed": seed,
52
+ "echo": False,
53
+ }
54
+ ),
55
+ )
56
+
57
+ generated_text = result.choices[0].text
58
+
59
+ for stop_word in stop:
60
+ if stop_word in result:
61
+ generated_text = generated_text[:generated_text.find(stop_word)]
62
+
63
+ return generated_text
64
+
65
+ def set_preset():
66
+ if st.session_state.preset == "Classification":
67
 
68
+ st.session_state.prompt = '''Please classify the given sentence.
69
+ Possible labels:
70
+ 1. <label_0>
71
+ 2. <label_1>
72
+
73
+ Input: <sentence_0>
74
+ Label: <label_0>
75
+
76
+ Input: <sentence_1>
77
+ Label:'''
78
+ st.session_state.temperature = "0.0"
79
+ st.session_state.top_p = "1.0"
80
 
81
+ elif st.session_state.preset == "Generation":
 
82
 
83
+ st.session_state.prompt = '''Please write a story given keywords.
84
+
85
+ Input: bear, honey
86
+ Story:'''
87
+ st.session_state.temperature = "1.0"
88
+ st.session_state.top_p = "0.5"
89
 
90
+ else:
91
+ pass
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
+ def main():
95
+
96
+ if 'preset' not in st.session_state:
97
+ st.session_state.preset = "Classification"
98
+
99
+ if 'prompt' not in st.session_state:
100
+ st.session_state.prompt = "Please answer the following question:\n\nQuestion: Where is Zurich?\nAnswer:"
101
+
102
+ if 'temperature' not in st.session_state:
103
+ st.session_state.temperature = "0.0"
104
 
105
+ if 'top_p' not in st.session_state:
106
+ st.session_state.top_p = "1.0"
107
+
108
+ if 'top_k' not in st.session_state:
109
+ st.session_state.top_k = "40"
110
+
111
+ if 'together_web3' not in st.session_state:
112
+ st.session_state.together_web3 = TogetherWeb3()
113
+
114
+
115
+ st.title("GPT-JT")
116
+
117
+ col1, col2 = st.columns([1, 3])
118
+
119
+ with col1:
120
+ model_name = st.selectbox("Model", ["GPT-JT-6B-v1"])
121
+ max_new_tokens = st.text_input('Max new tokens', "10")
122
+ temperature = st.text_input('temperature', st.session_state.temperature)
123
+ top_k = st.text_input('top_k', st.session_state.top_k)
124
+ top_p = st.text_input('top_p', st.session_state.top_p)
125
+ # num_completions = st.text_input('num_completions (only the best one will be returend)', "1")
126
+ num_completions = "1"
127
+ stop = st.text_input('stop, split by;', r'\n')
128
+ # seed = st.text_input('seed', "42")
129
+ seed = "42"
130
+
131
+ with col2:
132
+
133
+ preset = st.radio(
134
+ "Recommended Configurations",
135
+ ('Classification', 'Generation'),
136
+ on_change=set_preset,
137
+ key="preset",
138
+ horizontal=True
139
  )
140
+
141
+ prompt = st.text_area(
142
+ "Prompt",
143
+ value=st.session_state.prompt,
144
+ max_chars=4096,
145
+ height=400,
146
+ )
147
+
148
+ generated_area = st.empty()
149
+ generated_area.text("(Generate here)")
150
+
151
+ button_submit = st.button("Submit")
152
+
153
+ if button_submit:
154
+ generated_area.text(prompt)
155
+ report_text = infer(
156
+ prompt, model_name=model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k,
157
+ num_completions=num_completions, seed=seed, stop=literal_eval("'''"+stop+"'''"),
158
+ )
159
+ generated_area.text(prompt + report_text)
160
+
161
+
162
+
163
+ if __name__ == '__main__':
164
+ main()