xzyao commited on
Commit
5e48162
1 Parent(s): c694998

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -15
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import streamlit as st
2
  import requests
3
  import time
@@ -18,11 +19,9 @@ def infer(
18
  seed=42,
19
  stop="\n"
20
  ):
21
-
22
  model_name_map = {
23
  "GPT-JT-6B-v1": "Together-gpt-JT-6B-v1",
24
  }
25
-
26
  max_new_tokens = int(max_new_tokens)
27
  num_completions = int(num_completions)
28
  temperature = float(temperature)
@@ -39,21 +38,41 @@ def infer(
39
  temperature = 0.01
40
 
41
  my_post_dict = {
42
- "model": "Together-gpt-JT-6B-v1",
43
- "prompt": prompt,
44
- "top_p": top_p,
45
- "top_k": top_k,
46
- "temperature": temperature,
47
- "max_tokens": max_new_tokens,
48
- "stop": stop,
 
 
 
 
 
 
 
 
 
 
 
49
  }
50
- response = requests.get("https://staging.together.xyz/api/inference", params=my_post_dict).json()
51
- generated_text = response['output']['choices'][0]['text']
 
 
 
 
 
 
 
 
 
52
 
53
  for stop_word in stop:
54
  if stop_word in generated_text:
55
  generated_text = generated_text[:generated_text.find(stop_word)]
56
-
57
  st.session_state.updated = True
58
 
59
  return generated_text
@@ -84,7 +103,7 @@ Label:'''
84
 
85
  Input: bear, honey
86
  Story: Once upon a time,'''
87
- st.session_state.temperature = "0.8"
88
  st.session_state.top_p = "0.9"
89
  st.session_state.max_new_tokens = "100"
90
 
@@ -101,7 +120,7 @@ def main():
101
  st.session_state.prompt = "Please answer the following question:\n\nQuestion: In which country is Zurich located?\nAnswer:"
102
 
103
  if 'temperature' not in st.session_state:
104
- st.session_state.temperature = "0.8"
105
 
106
  if 'top_p' not in st.session_state:
107
  st.session_state.top_p = "1.0"
@@ -164,4 +183,4 @@ def main():
164
  generated_area.markdown("<b>" + to_md(prompt) + "</b>" + to_md(report_text), unsafe_allow_html=True)
165
 
166
  if __name__ == '__main__':
167
- main()
 
1
+
2
  import streamlit as st
3
  import requests
4
  import time
 
19
  seed=42,
20
  stop="\n"
21
  ):
 
22
  model_name_map = {
23
  "GPT-JT-6B-v1": "Together-gpt-JT-6B-v1",
24
  }
 
25
  max_new_tokens = int(max_new_tokens)
26
  num_completions = int(num_completions)
27
  temperature = float(temperature)
 
38
  temperature = 0.01
39
 
40
  my_post_dict = {
41
+ "type": "general",
42
+ "payload": {
43
+ "max_tokens": int(max_new_tokens),
44
+ "n": int(num_completions),
45
+ "temperature": float(temperature),
46
+ "top_p": float(top_p),
47
+ "model": model_name_map[model_name],
48
+ "prompt": [prompt],
49
+ "request_type": "language-model-inference",
50
+ "stop": stop,
51
+ "best_of": 1,
52
+ "echo": False,
53
+ "seed": int(seed),
54
+ "prompt_embedding": False,
55
+ },
56
+ "returned_payload": {},
57
+ "status": "submitted",
58
+ "source": "dalle",
59
  }
60
+
61
+ job_id = requests.post("https://planetd.shift.ml/jobs", json=my_post_dict).json()['id']
62
+
63
+ for i in range(100):
64
+
65
+ time.sleep(0.5)
66
+ ret = requests.get(f"https://planetd.shift.ml/job/{job_id}", json={'id': job_id}).json()
67
+ if ret['status'] == 'finished':
68
+ break
69
+
70
+ generated_text = ret['returned_payload']['result']['inference_result'][0]['choices'][0]['text']
71
 
72
  for stop_word in stop:
73
  if stop_word in generated_text:
74
  generated_text = generated_text[:generated_text.find(stop_word)]
75
+
76
  st.session_state.updated = True
77
 
78
  return generated_text
 
103
 
104
  Input: bear, honey
105
  Story: Once upon a time,'''
106
+ st.session_state.temperature = "0.0"
107
  st.session_state.top_p = "0.9"
108
  st.session_state.max_new_tokens = "100"
109
 
 
120
  st.session_state.prompt = "Please answer the following question:\n\nQuestion: In which country is Zurich located?\nAnswer:"
121
 
122
  if 'temperature' not in st.session_state:
123
+ st.session_state.temperature = "0.0"
124
 
125
  if 'top_p' not in st.session_state:
126
  st.session_state.top_p = "1.0"
 
183
  generated_area.markdown("<b>" + to_md(prompt) + "</b>" + to_md(report_text), unsafe_allow_html=True)
184
 
185
  if __name__ == '__main__':
186
+ main()