Spaces commit: Reformatted with black
Changed files: app.py, chat.py, prompt.py
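The commit title says the three files below were reformatted with the black code formatter. As a minimal sketch (not part of the commit itself; it assumes black is installed via `pip install black`), the same formatting pass could be reproduced from Python like this:

# Re-run black over the three changed files.
# Assumes the "black" formatter is installed in the current environment.
import subprocess

subprocess.run(["black", "app.py", "chat.py", "prompt.py"], check=True)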
app.py (CHANGED)

New version of the file after this commit:

@@ -1,18 +1,23 @@

#!/usr/bin/env python
# or gradio app.py

import gradio as gr

from chat import iface_chat
from prompt import iface_prompt

with gr.Blocks() as iface:
    gr.Markdown(
        """# Petals playground
**Let's play with prompts and inference settings for BLOOM and BLOOMZ 176B models!**

This space uses websocket API of [chat.petals.ml](http://chat.petals.ml). Health status of Petals network [lives here](http://health.petals.ml).

Do NOT talk to BLOOM as an entity, it's not a chatbot but a webpage/blog/article completion model.
For the best results: MIMIC a few sentences of a webpage similar to the content you want to generate.

BLOOMZ performs better in chat mode and understands the instructions better."""
    )

    gr.TabbedInterface([iface_prompt, iface_chat], ["Prompt mode", "Chat mode"])
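app.py builds a Blocks layout but, as shown, never calls launch(); the "# or gradio app.py" comment suggests the Space is started via the gradio CLI. For local testing, a minimal sketch (not part of the diff) of launching the combined interface directly would be:

# Hypothetical local entry point; the Space itself may rely on "gradio app.py".
if __name__ == "__main__":
    iface.launch()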
chat.py (CHANGED)

New version of the changed hunks (unchanged lines between hunks are omitted, as in the original diff):

@@ -1,26 +1,26 @@

import json
import re
import time
import traceback

import gradio as gr

import chat_client

CHAT_URL = "ws://chat.petals.ml/api/v2/generate"
# CHAT_URL='ws://localhost:8000/api/v2/generate'

EMPTY_STATE = {
    "generate": False,
    "model": None,
    "client": None,
    "history": [],
}


def generate(state, prompt, model, context, output, *args):
    # Save that we're in the generating loop
    state["generate"] = True

    try:
        yield from _generate(state, prompt, model, context, output, *args)

@@ -29,16 +29,28 @@

        # TODO This is a bit fragile because of the recursive call...
        print("Retrying session...")
        context = output
        output = ""
        yield from generate(state, prompt, model, context, output, *args)
    finally:
        state["generate"] = False


def _generate(
    state,
    prompt,
    model,
    context,
    output,
    endseq,
    max_length,
    do_sample,
    top_k,
    top_p,
    temperature,
):

    start = time.time()
    cnt = 0  # Tokens generated

    def stats():
        # Produces inline stats for generation speed

@@ -50,34 +62,37 @@

        sec_per_item = (time.time() - start) / cnt
        return f" | {sec_per_item:.1f} sec/t"

    eos = "</s>\n" if "bloomz" in model else "\n\n"

    if state["model"] != model and output:
        # If the connection is resumed, output is truncated in generate().
        # So this executes when the user changes the model.
        context = output
        output = ""

    # Update the widgets even before we get the first response
    print("prompt", prompt)
    yield state, state["history"] + [[prompt, stats()]], "", output

    if (
        state["model"] != model
        or state["client"] == None
        or state["client"].is_session() == False
    ):

        try:
            state["client"] = chat_client.ModelClient(CHAT_URL)
            state["client"].open_session(f"bigscience/{model}-petals", max_length)
            state["model"] = model
        except Exception:
            print(traceback.format_exc())
            raise gr.Error(traceback.format_exc(limit=3))

    else:
        context = ""

    client = state["client"]
    context += eos

    # Fix a possible eos token mismatch and add the eos token to context and prompt
    if "bloomz" in model:

@@ -87,7 +102,7 @@

    context = context.replace("</s>", eos)
    context = re.sub(r"\n\n+", "\n\n", context)
    prompt2 = prompt.replace("</s>", eos) + "\n\n"

    prompt2 = f"{context}Human: {prompt2}AI:"

    # Translate checkbox items to actual sequences

@@ -119,24 +134,22 @@


    output += prompt2

    orig_history = state["history"]
    new_line = ""
    try:
        for out in client.generate(
            prompt2,
            max_new_tokens=1,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            extra_stop_sequences=seq,
        ):

            if not state["generate"]:
                client.close_session()
                yield state, [], "", ""
                # Stopping generation
                return


@@ -149,64 +162,84 @@

                spl = new_line.split(s)
                new_line = spl[0]
                if len(spl) > 1:
                    state["history"] = orig_history + [[prompt, new_line]]
                    output += new_line
                    yield state, state["history"], "", output
                    # Stopping generation
                    return

            # Keep the original history untouched, as we're adding just
            # one chunk at a time.
            state["history"] = orig_history + [[prompt, new_line + stats()]]
            yield state, state["history"], "", output

        # Final line w/o statistics
        yield state, state["history"], "", output

    except (json.decoder.JSONDecodeError, BrokenPipeError):
        # Session was interrupted
        # Handled in the upstream func
        client.close_session()
        state["client"] = None
        state["model"] = None

        print("Broken session!")
        raise
    except Exception:
        client.close_session()
        state["client"] = None
        state["model"] = None

        print(traceback.format_exc())
        raise gr.Error(traceback.format_exc(limit=3))


def reset(state):
    """Resets the session and clears the chat window."""
    state.update(EMPTY_STATE)
    return state, [], ""


# ---------------------------------------------------------
# Defining the Gradio layout
with gr.Blocks() as iface_chat:
    gr.Markdown("""**Let's talk to Bloom in a chat!**""")

    with gr.Row():
        model = gr.Radio(
            ["bloom", "bloomz", "bloom-7b1"], value="bloomz", label="Use model"
        )

        # Additional ending sequences, at which generation should stop
        endseq = gr.CheckboxGroup(
            ["Human:", "AI:", "\\n", "</s>", "? (question mark)", ". (dot)"],
            value=["Human:", "AI:", "\\n", "</s>"],
            label="Extra end sequences",
        )

        # Maximum length of the inference session
        max_length = gr.Radio(
            [64, 128, 256, 512, 1024, 2048],
            value=1024,
            interactive=True,
            label="Max length",
        )

    with gr.Row():
        with gr.Column():
            # Switch between sampling and greedy generation
            do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")
            context = gr.Textbox(
                lines=3,
                label="Initial context:",
                interactive=True,
                value="A Human talks to a powerful AI that follows "
                "the Human's instructions.\n"
                "AI is talkative, friendly, positive and provides "
                "detailed answers to any question.</s>\n"
                "Human: Hi!</s>\n"
                "AI: How can I help you?",
            )

            # Only one of top_k and top_p can be set. Requires "do_sample=True" to work.
            top_k = gr.Number(value=0, precision=0, interactive=True, label="top_k")

@@ -214,12 +247,14 @@

            # TODO num_beams

            # Generation temperature
            temperature = gr.Number(
                value=0.75, precision=2, interactive=True, label="Temperature"
            )

    chat = gr.Chatbot(label="Chat window")
    prompt = gr.Textbox(
        show_label=False, label="Prompt", placeholder="Prompt Here and press Enter..."
    ).style(container=False)

    with gr.Row():
        button_generate = gr.Button("Generate")

@@ -231,20 +266,40 @@

    # Chat history
    state = gr.State(EMPTY_STATE)

    # Define the button actions
    inputs = [
        state,
        prompt,
        model,
        context,
        output,
        endseq,
        max_length,
        do_sample,
        top_k,
        top_p,
        temperature,
    ]
    outputs = [state, chat, prompt, output]

    prompt.submit(generate, inputs=inputs, outputs=outputs)
    button_generate.click(generate, inputs=inputs, outputs=outputs)
    button_reset.click(reset, inputs=[state], outputs=[state, chat, output])

    examples = gr.Examples(
        inputs=[context, prompt, model, do_sample, top_k, top_p, temperature],
        examples=[
            [
                "A Human talks to a powerful AI that follows the Human's instructions. "
                "AI is talkative, friendly, positive and provides detailed answers to any question.</s>\n"
                "Human: Hi!</s>\n"
                "AI: Hi! How can I help you?",
                "Could you remind me please who was Neil Armstrong?",
                "bloomz",
                True,
                0,
                0.9,
                0.75,
            ],
        ],
    )
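For reference, chat.py drives the websocket backend through the repo's own chat_client module. Below is a minimal sketch of using that client outside Gradio, based only on the calls visible in the hunks above (ModelClient, open_session, is_session, generate, close_session); the exact signatures and streaming behavior of chat_client are an assumption here, not documented API.

# Sketch: assumes chat_client exposes ModelClient with the methods used in chat.py.
import chat_client

CHAT_URL = "ws://chat.petals.ml/api/v2/generate"

# Open a websocket inference session for a Petals-hosted model.
client = chat_client.ModelClient(CHAT_URL)
client.open_session("bigscience/bloomz-petals", 512)

text = ""
# Like the app, request one new token per step and accumulate the stream.
for out in client.generate(
    "Human: Hi!</s>\nAI:",
    max_new_tokens=1,
    do_sample=True,
    temperature=0.75,
    top_k=0,
    top_p=0.9,
    extra_stop_sequences=["</s>"],
):
    text += out
    if "</s>" in text:
        break

client.close_session()
print(text)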
prompt.py (CHANGED)

New version of the changed hunks (unchanged lines between hunks are omitted, as in the original diff):

@@ -1,33 +1,44 @@

import time
import traceback

import gradio as gr

import chat_client

CHAT_URL = "ws://chat.petals.ml/api/v2/generate"
# CHAT_URL='ws://localhost:8000/api/v2/generate'


def generate(state, *args):
    # Save that we're in the generating loop
    state["generate"] = True

    try:
        yield from _generate(state, *args)
    finally:
        state["generate"] = False


def _generate(
    state,
    prompt,
    model,
    endseq,
    max_length,
    do_sample,
    top_k,
    top_p,
    temperature,
    add_stoptoken,
    copy_output,
):

    start = time.time()
    cnt = 0

    def stats():
        # Produces inline stats for generation speed
        # (sec/t or t/sec depending on the speed)
        if cnt == 0:
            return "\u2026 | ? sec/t"
        if cnt > time.time() - start:

@@ -70,23 +81,24 @@

        temperature = 1.0

    prompt2 = prompt
    output = ""

    # This renders the prompt dialog immediately and
    # doesn't wait for the generator to return the first result
    yield [state, prompt2, stats()]

    try:
        for out in client.generate(
            prompt,
            max_new_tokens=1,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            extra_stop_sequences=seq,
        ):

            if not state["generate"]:
                client.close_session()
                return


@@ -104,31 +116,53 @@

        print(traceback.format_exc())
        raise gr.Error(traceback.format_exc(limit=3))


def stop(state):
    """Stops generating."""
    state.update({"generate": False})
    return state


# ---------------------------------------------------------
# Defining the Gradio layout
with gr.Blocks() as iface_prompt:
    gr.Markdown(
        """**Useful for testing raw prompts with zero,
        one or few-shot prompting.**"""
    )

    with gr.Row():
        model = gr.Radio(
            ["bloom", "bloomz", "bloom-7b1"], value="bloom", label="Use model"
        )

        # Additional ending sequences, at which generation should stop
        endseq = gr.CheckboxGroup(
            ["\\n", "</s>", "? (question mark)", ". (dot)"],
            value=["\\n", "</s>"],
            label="Extra end sequences",
        )

        # Maximum length of the inference session
        max_length = gr.Radio(
            [64, 128, 256, 512, 1024, 2048],
            value=512,
            interactive=True,
            label="Max length",
        )

    with gr.Row():
        with gr.Column():
            # Switch between sampling and greedy generation
            do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")

            # Should the app append the stop sequence at the end of the prompt,
            # or should it leave the prompt open?
            add_stoptoken = gr.Checkbox(
                value=True,
                interactive=True,
                label="Automatically add eos token to the prompt.",
            )

            # Only one of top_k and top_p can be set. Requires "do_sample=True" to work.
            top_k = gr.Number(value=0, precision=0, interactive=True, label="top_k")

@@ -136,10 +170,12 @@

            # TODO num_beams

            # Generation temperature
            temperature = gr.Number(
                value=0.75, precision=2, interactive=True, label="Temperature"
            )

    prompt = gr.Textbox(lines=3, label="Prompt", placeholder="Prompt Here...")
    state = gr.State({"generate": False})

    with gr.Row():
        button_generate = gr.Button("Generate")

@@ -148,22 +184,62 @@

        # Automatically copy the output to the end of the prompt
        copy_output = gr.Checkbox(label="Output -> Prompt")

    output = gr.Textbox(lines=3, label="Output")

    # Define the button actions
    button_generate.click(
        generate,
        inputs=[
            state,
            prompt,
            model,
            endseq,
            max_length,
            do_sample,
            top_k,
            top_p,
            temperature,
            add_stoptoken,
            copy_output,
        ],
        outputs=[state, prompt, output],
    )
    button_stop.click(stop, inputs=[state], outputs=[state])

    examples = gr.Examples(
        inputs=[prompt, model, do_sample, top_k, top_p, temperature, add_stoptoken],
        examples=[
            [
                "The SQL command to extract all the users whose name starts with A is: ",
                "bloom-7b1",
                False,
                0,
                0,
                1,
                False,
            ],
            [
                "The Spanish translation of thank you for your help is: ",
                "bloom-7b1",
                False,
                0,
                0,
                1,
                False,
            ],
            [
                "A human talks to a powerful AI that follows the Human's instructions.\n"
                "AI is talkative, friendly, positive and provides detailed answers to any question.</s>\n"
                "Human: Hi!</s>\n"
                "AI: Hi! How can I help you?</s>\n"
                "Human: What's the capital of Portugal?</s>\n"
                "AI: ",
                "bloomz",
                True,
                0,
                0.9,
                0.75,
                False,
            ],
        ],
    )