Spaces:
streaming token support
pass mapping
make sure to call different models
fix fill value on iter
app.py CHANGED
@@ -7,8 +7,12 @@ import re
 import traceback
 import uuid
 import datetime
+from collections import deque
+import itertools
+
 from collections import defaultdict
 from time import sleep
+from typing import Generator, Tuple
 
 import boto3
 import gradio as gr
@@ -56,7 +60,7 @@ class Pipeline:
             "stop": ["</s>", "USER:", "### Instruction:"] + stop_tokens,
         }
 
-    def __call__(self, prompt):
+    def __call__(self, prompt) -> Generator[str, None, None]:
         input = self.generation_config.copy()
         input["prompt"] = prompt
 
@@ -71,12 +75,26 @@ class Pipeline:
 
         if response.status_code == 200:
             data = response.json()
-
-
-
-
-
-
+            task_id = data.get('id')
+            return self.stream_output(task_id)
+
+    def stream_output(self, task_id) -> Generator[str, None, None]:
+        url = f"https://api.runpod.ai/v2/{self.endpoint_id}/stream/{task_id}"
+        headers = {
+            "Authorization": f"Bearer {os.environ['RUNPOD_AI_API_KEY']}"
+        }
+
+        while True:
+            response = requests.get(url, headers=headers)
+            if response.status_code == 200:
+                data = response.json()
+                yield [{"generated_text": "".join([s["output"] for s in data["stream"]])}]
+                if data.get('status') == 'COMPLETED':
+                    return
+            elif response.status_code >= 400:
+                logging.error(response.json())
+            # Sleep for 0.5 seconds between each request
+            sleep(0.5)
 
     def poll_for_status(self, task_id):
         url = f"https://api.runpod.ai/v2/{self.endpoint_id}/status/{task_id}"
@@ -134,6 +152,19 @@ def user(message, nudge_msg, history1, history2)
     return "", nudge_msg, history1, history2
 
 
+def token_generator(generator1, generator2, mapping_fn=None, fillvalue=None):
+    if not fillvalue:
+        fillvalue = ''
+    if not mapping_fn:
+        mapping_fn = lambda x: x
+    for output1, output2 in itertools.zip_longest(generator1, generator2, fillvalue=fillvalue):
+        tokens1 = re.findall(r'\s*\S+\s*', mapping_fn(output1))
+        tokens2 = re.findall(r'\s*\S+\s*', mapping_fn(output2))
+
+        for token1, token2 in itertools.zip_longest(tokens1, tokens2, fillvalue=''):
+            yield token1, token2
+
+
 def chat(history1, history2, system_msg):
     history1 = history1 or []
     history2 = history2 or []
@@ -151,34 +182,17 @@ def chat(history1, history2, system_msg)
     messages1 = messages1.rstrip()
     messages2 = messages2.rstrip()
 
-
-
-
-
-
-
-
-
-        # If desired, you can check for exceptions here...
-        if future.exception() is not None:
-            print('Exception: {}'.format(future.exception()))
-            traceback.print_exception(type(future.exception()), future.exception(), future.exception().__traceback__)
-
-    tokens_model1 = re.findall(r'\s*\S+\s*', futures[0].result()[0]['generated_text'])
-    tokens_model2 = re.findall(r'\s*\S+\s*', futures[1].result()[0]['generated_text'])
-    len_tokens_model1 = len(tokens_model1)
-    len_tokens_model2 = len(tokens_model2)
-    max_tokens = max(len_tokens_model1, len_tokens_model2)
-    for i in range(0, max_tokens):
-        if i < len_tokens_model1:
-            answer1 = tokens_model1[i]
-            history1[-1][1] += answer1
-        if i < len_tokens_model2:
-            answer2 = tokens_model2[i]
-            history2[-1][1] += answer2
+    model1_res = model1(messages1)  # type: Generator[str, None, None]
+    model2_res = model2(messages2)  # type: Generator[str, None, None]
+    res = token_generator(model1_res, model2_res, lambda x: x[0]['generated_text'], fillvalue=[{'generated_text': ''}])  # type: Generator[Tuple[str, str], None, None]
+    for t1, t2 in res:
+        if t1 is not None:
+            history1[-1][1] += t1
+        if t2 is not None:
+            history2[-1][1] += t2
         # stream the response
         yield history1, history2, "", gr.update(value=random_battle[0]), gr.update(value=random_battle[1]), {"models": [model1.name, model2.name]}
-        sleep(0.
+        sleep(0.2)
 
 
 def chosen_one(label, choice1_history, choice2_history, system_msg, nudge_msg, rlhf_persona, state):
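The change to __call__ swaps blocking status polls for streaming: the job id from the submit response is handed to stream_output, which re-polls RunPod's /stream/{task_id} endpoint and yields the text accumulated so far until the job reports COMPLETED. A minimal sketch of consuming that generator, assuming a Pipeline constructed with an endpoint id and name (the constructor signature and values below are illustrative, not from this diff):

# Hypothetical usage of the streaming Pipeline; the constructor
# arguments and the prompt are placeholders for illustration only.
pipe = Pipeline(endpoint_id="abc123", name="model1")  # assumed signature
for partial in pipe("USER: Hello\nASSISTANT:"):
    # each yielded item has the shape chat() expects:
    # [{"generated_text": "<text so far>"}]
    print(partial[0]["generated_text"])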
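token_generator interleaves the two model streams with itertools.zip_longest so the battle keeps streaming even after one model finishes. The padding item has to match the shape of real stream items: with the bare-string default, the mapping function chat() passes in would hit an IndexError on ''[0] as soon as the shorter stream ran out, which is what the "fix fill value on iter" commit addresses. A self-contained toy run, with in-memory generators standing in for the live RunPod streams:

# Toy demonstration of token_generator's interleaving; the fake
# generators below stand in for Pipeline.stream_output.
import itertools
import re

def fake_stream(*chunks):
    # mimics the item shape stream_output yields per poll:
    # [{"generated_text": "<text>"}]
    for chunk in chunks:
        yield [{"generated_text": chunk}]

fast = fake_stream("one two ", "three")  # still producing on poll 2
slow = fake_stream("alpha")              # exhausted after poll 1

mapping_fn = lambda x: x[0]["generated_text"]
# The padding must be shaped like a real stream item; the bare ''
# default would make mapping_fn raise IndexError once slow ran out.
fill = [{"generated_text": ""}]

for out1, out2 in itertools.zip_longest(fast, slow, fillvalue=fill):
    tokens1 = re.findall(r"\s*\S+\s*", mapping_fn(out1))
    tokens2 = re.findall(r"\s*\S+\s*", mapping_fn(out2))
    for t1, t2 in itertools.zip_longest(tokens1, tokens2, fillvalue=""):
        print(repr(t1), repr(t2))
# -> ('one ', 'alpha'), ('two ', ''), ('three', '')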
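In chat(), the removed futures-based block waited for both full completions before re-splitting them into tokens; the new code consumes both generators as they arrive and updates the two histories one token pair per iteration, yielding after each so Gradio re-renders both chatbots, then sleeping 0.2 s to pace the stream. (The `is not None` guards are effectively always true, since token_generator pads with empty strings rather than None, and appending '' would be a no-op anyway.) A toy sketch of that update loop, with hard-coded token pairs in place of the live streams:

# Toy sketch of chat()'s streaming update loop; the token pairs and
# histories are fabricated stand-ins for the live generator output.
from time import sleep

history1 = [["hi", ""]]  # [user, assistant] pairs, as in Gradio chatbot state
history2 = [["hi", ""]]

for t1, t2 in [("Hel", "Hey"), ("lo", " there"), ("!", "!")]:
    history1[-1][1] += t1
    history2[-1][1] += t2
    # the real loop yields (history1, history2, ...) here so the UI
    # shows both partial responses, then paces the next token pair
    sleep(0.2)

print(history1[-1][1])  # -> Hello!
print(history2[-1][1])  # -> Hey there!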