winglian committed
Commit
b360328
1 Parent(s): 96772fe

streaming token support


pass mapping
make sure to call different models
fix fill value on iter

Files changed (1)
  1. app.py +47 -33

app.py CHANGED
@@ -7,8 +7,12 @@ import re
 import traceback
 import uuid
 import datetime
+from collections import deque
+import itertools
+
 from collections import defaultdict
 from time import sleep
+from typing import Generator, Tuple
 
 import boto3
 import gradio as gr
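Note: `itertools` is the import that matters below; `zip_longest` pads the shorter iterable with a fill value instead of stopping early the way `zip()` does. A standalone refresher (not part of the commit):

```python
import itertools

# zip() stops at the shorter input; zip_longest() pads it instead.
print(list(zip([1, 2, 3], "ab")))                                  # [(1, 'a'), (2, 'b')]
print(list(itertools.zip_longest([1, 2, 3], "ab", fillvalue="")))  # [(1, 'a'), (2, 'b'), (3, '')]
```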
@@ -56,7 +60,7 @@ class Pipeline:
             "stop": ["</s>", "USER:", "### Instruction:"] + stop_tokens,
         }
 
-    def __call__(self, prompt):
+    def __call__(self, prompt) -> Generator[str, None, None]:
         input = self.generation_config.copy()
         input["prompt"] = prompt
 
@@ -71,12 +75,26 @@ class Pipeline:
 
         if response.status_code == 200:
             data = response.json()
-            status = data.get('status')
-            if status == 'COMPLETED':
-                return [{"generated_text": data["output"]}]
-            else:
-                task_id = data.get('id')
-                return self.poll_for_status(task_id)
+            task_id = data.get('id')
+            return self.stream_output(task_id)
+
+    def stream_output(self,task_id) -> Generator[str, None, None]:
+        url = f"https://api.runpod.ai/v2/{self.endpoint_id}/stream/{task_id}"
+        headers = {
+            "Authorization": f"Bearer {os.environ['RUNPOD_AI_API_KEY']}"
+        }
+
+        while True:
+            response = requests.get(url, headers=headers)
+            if response.status_code == 200:
+                data = response.json()
+                yield [{"generated_text": "".join([s["output"] for s in data["stream"]])}]
+                if data.get('status') == 'COMPLETED':
+                    return
+            elif response.status_code >= 400:
+                logging.error(response.json())
+            # Sleep for 0.5 seconds between each request
+            sleep(0.5)
 
     def poll_for_status(self, task_id):
         url = f"https://api.runpod.ai/v2/{self.endpoint_id}/status/{task_id}"
@@ -134,6 +152,19 @@ def user(message, nudge_msg, history1, history2):
     return "", nudge_msg, history1, history2
 
 
+def token_generator(generator1, generator2, mapping_fn=None, fillvalue=None):
+    if not fillvalue:
+        fillvalue = ''
+    if not mapping_fn:
+        mapping_fn = lambda x: x
+    for output1, output2 in itertools.zip_longest(generator1, generator2, fillvalue=fillvalue):
+        tokens1 = re.findall(r'\s*\S+\s*', mapping_fn(output1))
+        tokens2 = re.findall(r'\s*\S+\s*', mapping_fn(output2))
+
+        for token1, token2 in itertools.zip_longest(tokens1, tokens2, fillvalue=''):
+            yield token1, token2
+
+
 def chat(history1, history2, system_msg):
     history1 = history1 or []
     history2 = history2 or []
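A quick standalone demonstration of how `token_generator` interleaves two streams (identity `mapping_fn`, so plain strings in). The outer `zip_longest` keeps yielding after the shorter stream ends, padding it with `fillvalue`; that padding is what the commit message's "fix fill value on iter" refers to:

```python
# assumes token_generator from app.py above is in scope
left = iter(["Hello there ", "general"])
right = iter(["Fine weather "])
for t1, t2 in token_generator(left, right):
    print(repr(t1), repr(t2))
# 'Hello ' 'Fine '
# 'there ' 'weather '
# 'general' ''   <- right stream exhausted, padded with ''
```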
@@ -151,34 +182,17 @@ def chat(history1, history2, system_msg):
     messages1 = messages1.rstrip()
     messages2 = messages2.rstrip()
 
-
-    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
-        futures = []
-        futures.append(executor.submit(model1, messages1))
-        futures.append(executor.submit(model2, messages2))
-
-        # Wait for all threads to finish...
-        for future in concurrent.futures.as_completed(futures):
-            # If desired, you can check for exceptions here...
-            if future.exception() is not None:
-                print('Exception: {}'.format(future.exception()))
-                traceback.print_exception(type(future.exception()), future.exception(), future.exception().__traceback__)
-
-    tokens_model1 = re.findall(r'\s*\S+\s*', futures[0].result()[0]['generated_text'])
-    tokens_model2 = re.findall(r'\s*\S+\s*', futures[1].result()[0]['generated_text'])
-    len_tokens_model1 = len(tokens_model1)
-    len_tokens_model2 = len(tokens_model2)
-    max_tokens = max(len_tokens_model1, len_tokens_model2)
-    for i in range(0, max_tokens):
-        if i < len_tokens_model1:
-            answer1 = tokens_model1[i]
-            history1[-1][1] += answer1
-        if i < len_tokens_model2:
-            answer2 = tokens_model2[i]
-            history2[-1][1] += answer2
+    model1_res = model1(messages1)  # type: Generator[str, None, None]
+    model2_res = model2(messages2)  # type: Generator[str, None, None]
+    res = token_generator(model1_res, model2_res, lambda x: x[0]['generated_text'], fillvalue=[{'generated_text': ''}])  # type: Generator[Tuple[str, str], None, None]
+    for t1, t2 in res:
+        if t1 is not None:
+            history1[-1][1] += t1
+        if t2 is not None:
+            history2[-1][1] += t2
         # stream the response
         yield history1, history2, "", gr.update(value=random_battle[0]), gr.update(value=random_battle[1]), {"models": [model1.name, model2.name]}
-        sleep(0.15)
+        sleep(0.2)
 
 
 def chosen_one(label, choice1_history, choice2_history, system_msg, nudge_msg, rlhf_persona, state):
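Two notes on the rewritten loop. First, because `token_generator`'s inner `zip_longest` uses `fillvalue=''`, an exhausted stream yields empty strings rather than `None`, so the `is not None` guards look defensive rather than load-bearing. Second, the per-token `yield` is what streams to the UI: Gradio re-renders an event handler's outputs on every yield. A minimal, hypothetical sketch of that mechanism (component wiring assumed, not from app.py):

```python
import gradio as gr
from time import sleep

def fake_chat(history):
    # append an empty assistant turn, then grow it token by token
    history = (history or []) + [["hi", ""]]
    for tok in ["Hel", "lo ", "wor", "ld"]:
        history[-1][1] += tok
        yield history  # each yield repaints the Chatbot
        sleep(0.2)     # same pacing as chat() above

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    gr.Button("Go").click(fake_chat, inputs=chatbot, outputs=chatbot)
# demo.launch()
```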
 