slush0 committed
Commit b70859f · 1 Parent(s): 2caf2e7

Reformatted with black

Files changed (3)
  1. app.py +12 -7
  2. chat.py +141 -86
  3. prompt.py +127 -51
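The commit message only says the sources were reformatted; below is a minimal sketch of how such a pass could be reproduced with Black's Python API. The file list is taken from this commit, but the helper itself and the default 88-character mode are assumptions, not part of the commit.

# Hypothetical helper: re-run Black over the three files touched by this commit.
# Assumes `pip install black`; black.format_str and black.Mode are Black's public API.
from pathlib import Path

import black

for name in ["app.py", "chat.py", "prompt.py"]:  # file list assumed from the diff below
    path = Path(name)
    source = path.read_text()
    # Mode() defaults to Black's standard 88-character line length,
    # which matches the call wrapping visible in the diff.
    formatted = black.format_str(source, mode=black.Mode())
    path.write_text(formatted)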
app.py CHANGED
@@ -1,18 +1,23 @@
  import gradio as gr
 
- from prompt import iface_prompt
  from chat import iface_chat
 
  with gr.Blocks() as iface:
- gr.Markdown("""# Petals playground
- **Let's play with prompts and inference settings for BLOOM and BLOOMZ 176B models!**
 
- This space uses websocket API of [chat.petals.ml](http://chat.petals.ml). Health status of Petals network [lives here](http://health.petals.ml).
 
- Do NOT talk to BLOOM as an entity, it's not a chatbot but a webpage/blog/article completion model.
- For the best results: MIMIC a few sentences of a webpage similar to the content you want to generate.
 
- BLOOMZ performs better in chat mode and understands the instructions better.""")
 
  gr.TabbedInterface([iface_prompt, iface_chat], ["Prompt mode", "Chat mode"])
 
+ #!/usr/bin/env python
+ # or gradio app.py
+
  import gradio as gr
 
  from chat import iface_chat
+ from prompt import iface_prompt
 
  with gr.Blocks() as iface:
+ gr.Markdown(
+ """# Petals playground
+ **Let's play with prompts and inference settings for BLOOM and BLOOMZ 176B models!**
 
+ This space uses websocket API of [chat.petals.ml](http://chat.petals.ml). Health status of Petals network [lives here](http://health.petals.ml).
 
+ Do NOT talk to BLOOM as an entity, it's not a chatbot but a webpage/blog/article completion model.
+ For the best results: MIMIC a few sentences of a webpage similar to the content you want to generate.
 
+ BLOOMZ performs better in chat mode and understands the instructions better."""
+ )
 
  gr.TabbedInterface([iface_prompt, iface_chat], ["Prompt mode", "Chat mode"])
 
chat.py CHANGED
@@ -1,26 +1,26 @@
- #!/usr/bin/env python
- # or gradio app.py
-
  import traceback
  import gradio as gr
  import chat_client
- import time
- import json
- import re
 
- CHAT_URL='ws://chat.petals.ml/api/v2/generate'
- #CHAT_URL='ws://localhost:8000/api/v2/generate'
 
  EMPTY_STATE = {
- 'generate': False,
- 'model': None,
- 'client': None,
- 'history': [],
  }
 
  def generate(state, prompt, model, context, output, *args):
  # Save that we're in generating loop
- state['generate'] = True
 
  try:
  yield from _generate(state, prompt, model, context, output, *args)
@@ -29,16 +29,28 @@ def generate(state, prompt, model, context, output, *args):
  # TODO This is a bit fragile because of recursive call...
  print("Retrying session...")
  context = output
- output = ''
  yield from generate(state, prompt, model, context, output, *args)
  finally:
- state['generate'] = False
-
- def _generate(state, prompt, model, context, output, endseq, max_length,
- do_sample, top_k, top_p, temperature):
 
  start = time.time()
- cnt = 0 # Tokens generated
 
  def stats():
  # Produces inline stats for generation speed
@@ -50,34 +62,37 @@ def _generate(state, prompt, model, context, output, endseq, max_length,
  sec_per_item = (time.time() - start) / cnt
  return f" | {sec_per_item:.1f} sec/t"
 
- print('prompt', prompt)
  eos = "</s>\n" if "bloomz" in model else "\n\n"
 
- if state['model'] != model and output:
  # If the connection is resumed, output is truncated in generate().
- # So this happen when user change model.
  context = output
- output = ''
 
- if state['model'] != model or \
- state['client'] == None or state['client'].is_session() == False:
 
  try:
- state['client'] = chat_client.ModelClient(CHAT_URL)
- state['client'].open_session(f"bigscience/{model}-petals", max_length)
- state['model'] = model
  except Exception:
  print(traceback.format_exc())
  raise gr.Error(traceback.format_exc(limit=3))
 
  else:
- context = ''
-
- client = state['client']
 
  context += eos
- #for question, answer in state['history']:
- # context += f"Human: {question}{eos}AI: {answer}{eos}"
 
  # Fix eventual eos token mismatch and add eos token to context and prompt
  if "bloomz" in model:
@@ -87,7 +102,7 @@ def _generate(state, prompt, model, context, output, endseq, max_length,
  context = context.replace("</s>", eos)
  context = re.sub(r"\n\n+", "\n\n", context)
  prompt2 = prompt.replace("</s>", eos) + "\n\n"
-
  prompt2 = f"{context}Human: {prompt2}AI:"
 
  # Translate checkbox items to actual sequences
@@ -119,24 +134,22 @@ def _generate(state, prompt, model, context, output, endseq, max_length,
 
  output += prompt2
 
- # Update widgets even before we get the first response
- yield state, state['history'] + [[prompt, stats()]], None, output
-
- orig_history = state['history']
- new_line = ''
  try:
- for out in client.generate(prompt2,
- max_new_tokens=1,
- do_sample=do_sample,
- temperature=temperature,
- top_k=top_k,
- top_p=top_p,
- extra_stop_sequences=seq
- ):
-
- if not state['generate']:
  client.close_session()
- yield state, [], None, ''
  # Stopping generation
  return
 
@@ -149,64 +162,84 @@ def _generate(state, prompt, model, context, output, endseq, max_length,
  spl = new_line.split(s)
  new_line = spl[0]
  if len(spl) > 1:
- state['history'] = orig_history + [[prompt, new_line]]
  output += new_line
- yield state, state['history'], None, output
  # Stopping generation
  return
-
  # Keep original history untouched as we're adding just
  # a chunks at one moment.
- state['history'] = orig_history + [[prompt, new_line + stats()]]
- yield state, state['history'], None, output
 
  # Final line w/o statistics
- yield state, state['history'], None, output
 
  except (json.decoder.JSONDecodeError, BrokenPipeError):
  # Session was interrupted
  # Handled in upstream func
  client.close_session()
- state['client'] = None
- state['model'] = None
 
  print("Broken session!")
  raise
  except Exception:
  client.close_session()
- state['client'] = None
- state['model'] = None
 
  print(traceback.format_exc())
  raise gr.Error(traceback.format_exc(limit=3))
 
  def reset(state):
  """Resets the session and clears the chat window."""
  state.update(EMPTY_STATE)
- return state, [], ''
 
  with gr.Blocks() as iface_chat:
  gr.Markdown("""**Let's talk to Bloom in a chat!**""")
 
  with gr.Row():
- model = gr.Radio(["bloom", "bloomz", "bloom-7b1"], value='bloomz', label="Use model")
 
  # Additional ending sequence, at which generation shoud stop
- endseq = gr.CheckboxGroup(["Human:", "AI:", "\\n", "</s>", "? (question mark)", ". (dot)"],
- value=["Human:", "AI:", "\\n", "</s>"], label='Extra end sequences')
 
  # Maximum length of inference session
- max_length = gr.Radio([64, 128, 256, 512, 1024, 2048], value=1024, interactive=True, label="Max length")
 
  with gr.Row():
  with gr.Column():
  # Switch between sampling and greedy generation
  do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")
- context = gr.Textbox(lines=3, label="Initial context:", interactive=True,
- value="A human talks to a powerful AI that follows the human's instructions.\n"
- "AI is talkative, friendly, positive and provides detailed answers to any question.</s>\n"
- "Human: Hi!</s>\n"
- "AI: How can I help you?")
 
  # Only one of top_k and top_p can be set. Requires "do_sample=True" to work.
  top_k = gr.Number(value=0, precision=0, interactive=True, label="top_k")
@@ -214,12 +247,14 @@ with gr.Blocks() as iface_chat:
  # TODO num_beams
 
  # Generation temperature
- temperature = gr.Number(value=0.75, precision=2, interactive=True, label="Temperature")
-
 
- chat = gr.Chatbot(label='Chat window')
- prompt = gr.Textbox(show_label=False, label='Prompt',
- placeholder="Prompt Here and press Enter...").style(container=False)
 
  with gr.Row():
  button_generate = gr.Button("Generate")
@@ -231,20 +266,40 @@ with gr.Blocks() as iface_chat:
  # Chat history
  state = gr.State(EMPTY_STATE)
 
- inputs = [state, prompt, model, context, output, endseq,
- max_length, do_sample, top_k, top_p, temperature]
- outputs=[state, chat, prompt, output]
 
  prompt.submit(generate, inputs=inputs, outputs=outputs)
  button_generate.click(generate, inputs=inputs, outputs=outputs)
  button_reset.click(reset, inputs=[state], outputs=[state, chat, output])
 
- examples = gr.Examples(inputs=[context, prompt, model, do_sample, top_k, top_p, temperature],
  examples=[
- ["A Human talks to a powerful AI that follows the Human's instructions. "
- "AI is talkative, friendly, positive and provides detailed answers to any question.</s>\n"
- "Human: Hi!</s>\n"
- "AI: Hi! How can I help you?",
- "Could you remind me please who was Neil Armstrong?",
- "bloomz", True, 0, 0.9, 0.75],
- ])
+ import json
+ import re
+ import time
  import traceback
+
  import gradio as gr
+
  import chat_client
 
+ CHAT_URL = "ws://chat.petals.ml/api/v2/generate"
+ # CHAT_URL='ws://localhost:8000/api/v2/generate'
 
  EMPTY_STATE = {
+ "generate": False,
+ "model": None,
+ "client": None,
+ "history": [],
  }
 
+
  def generate(state, prompt, model, context, output, *args):
  # Save that we're in generating loop
+ state["generate"] = True
 
  try:
  yield from _generate(state, prompt, model, context, output, *args)
 
  # TODO This is a bit fragile because of recursive call...
  print("Retrying session...")
  context = output
+ output = ""
  yield from generate(state, prompt, model, context, output, *args)
  finally:
+ state["generate"] = False
+
+
+ def _generate(
+ state,
+ prompt,
+ model,
+ context,
+ output,
+ endseq,
+ max_length,
+ do_sample,
+ top_k,
+ top_p,
+ temperature,
+ ):
 
  start = time.time()
+ cnt = 0 # Tokens generated
 
  def stats():
  # Produces inline stats for generation speed
 
  sec_per_item = (time.time() - start) / cnt
  return f" | {sec_per_item:.1f} sec/t"
 
  eos = "</s>\n" if "bloomz" in model else "\n\n"
 
+ if state["model"] != model and output:
  # If the connection is resumed, output is truncated in generate().
+ # So this executes when user change model.
  context = output
+ output = ""
 
+ # Update widgets even before we get the first response
+ print("prompt", prompt)
+ yield state, state["history"] + [[prompt, stats()]], "", output
+
+ if (
+ state["model"] != model
+ or state["client"] == None
+ or state["client"].is_session() == False
+ ):
 
  try:
+ state["client"] = chat_client.ModelClient(CHAT_URL)
+ state["client"].open_session(f"bigscience/{model}-petals", max_length)
+ state["model"] = model
  except Exception:
  print(traceback.format_exc())
  raise gr.Error(traceback.format_exc(limit=3))
 
  else:
+ context = ""
 
+ client = state["client"]
  context += eos
 
  # Fix eventual eos token mismatch and add eos token to context and prompt
  if "bloomz" in model:
 
  context = context.replace("</s>", eos)
  context = re.sub(r"\n\n+", "\n\n", context)
  prompt2 = prompt.replace("</s>", eos) + "\n\n"
+
  prompt2 = f"{context}Human: {prompt2}AI:"
 
  # Translate checkbox items to actual sequences
 
  output += prompt2
 
+ orig_history = state["history"]
+ new_line = ""
  try:
+ for out in client.generate(
+ prompt2,
+ max_new_tokens=1,
+ do_sample=do_sample,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ extra_stop_sequences=seq,
+ ):
+
+ if not state["generate"]:
  client.close_session()
+ yield state, [], "", ""
  # Stopping generation
  return
 
  spl = new_line.split(s)
  new_line = spl[0]
  if len(spl) > 1:
+ state["history"] = orig_history + [[prompt, new_line]]
  output += new_line
+ yield state, state["history"], "", output
  # Stopping generation
  return
+
  # Keep original history untouched as we're adding just
  # a chunks at one moment.
+ state["history"] = orig_history + [[prompt, new_line + stats()]]
+ yield state, state["history"], "", output
 
  # Final line w/o statistics
+ yield state, state["history"], "", output
 
  except (json.decoder.JSONDecodeError, BrokenPipeError):
  # Session was interrupted
  # Handled in upstream func
  client.close_session()
+ state["client"] = None
+ state["model"] = None
 
  print("Broken session!")
  raise
  except Exception:
  client.close_session()
+ state["client"] = None
+ state["model"] = None
 
  print(traceback.format_exc())
  raise gr.Error(traceback.format_exc(limit=3))
 
+
  def reset(state):
  """Resets the session and clears the chat window."""
  state.update(EMPTY_STATE)
+ return state, [], ""
 
+
+ # ---------------------------------------------------------
+ # Defining Gradio layout
  with gr.Blocks() as iface_chat:
  gr.Markdown("""**Let's talk to Bloom in a chat!**""")
 
  with gr.Row():
+ model = gr.Radio(
+ ["bloom", "bloomz", "bloom-7b1"], value="bloomz", label="Use model"
+ )
 
  # Additional ending sequence, at which generation shoud stop
+ endseq = gr.CheckboxGroup(
+ ["Human:", "AI:", "\\n", "</s>", "? (question mark)", ". (dot)"],
+ value=["Human:", "AI:", "\\n", "</s>"],
+ label="Extra end sequences",
+ )
 
  # Maximum length of inference session
+ max_length = gr.Radio(
+ [64, 128, 256, 512, 1024, 2048],
+ value=1024,
+ interactive=True,
+ label="Max length",
+ )
 
  with gr.Row():
  with gr.Column():
  # Switch between sampling and greedy generation
  do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")
+ context = gr.Textbox(
+ lines=3,
+ label="Initial context:",
+ interactive=True,
+ value="A Human talks to a powerful AI that follows "
+ "the Human's instructions.\n"
+ "AI is talkative, friendly, positive and provides "
+ "detailed answers to any question.</s>\n"
+ "Human: Hi!</s>\n"
+ "AI: How can I help you?",
+ )
 
  # Only one of top_k and top_p can be set. Requires "do_sample=True" to work.
  top_k = gr.Number(value=0, precision=0, interactive=True, label="top_k")
 
  # TODO num_beams
 
  # Generation temperature
+ temperature = gr.Number(
+ value=0.75, precision=2, interactive=True, label="Temperature"
+ )
 
+ chat = gr.Chatbot(label="Chat window")
+ prompt = gr.Textbox(
+ show_label=False, label="Prompt", placeholder="Prompt Here and press Enter..."
+ ).style(container=False)
 
  with gr.Row():
  button_generate = gr.Button("Generate")
 
  # Chat history
  state = gr.State(EMPTY_STATE)
 
+ # Define button actions
+ inputs = [
+ state,
+ prompt,
+ model,
+ context,
+ output,
+ endseq,
+ max_length,
+ do_sample,
+ top_k,
+ top_p,
+ temperature,
+ ]
+ outputs = [state, chat, prompt, output]
 
  prompt.submit(generate, inputs=inputs, outputs=outputs)
  button_generate.click(generate, inputs=inputs, outputs=outputs)
  button_reset.click(reset, inputs=[state], outputs=[state, chat, output])
 
+ examples = gr.Examples(
+ inputs=[context, prompt, model, do_sample, top_k, top_p, temperature],
  examples=[
+ [
+ "A Human talks to a powerful AI that follows the Human's instructions. "
+ "AI is talkative, friendly, positive and provides detailed answers to any question.</s>\n"
+ "Human: Hi!</s>\n"
+ "AI: Hi! How can I help you?",
+ "Could you remind me please who was Neil Armstrong?",
+ "bloomz",
+ True,
+ 0,
+ 0.9,
+ 0.75,
+ ],
+ ],
+ )
prompt.py CHANGED
@@ -1,33 +1,44 @@
- #!/usr/bin/env python
- # or gradio app.py
-
  import traceback
  import gradio as gr
  import chat_client
- import time
 
- CHAT_URL='ws://chat.petals.ml/api/v2/generate'
- #CHAT_URL='ws://localhost:8000/api/v2/generate'
 
  def generate(state, *args):
  # Save that we're in generating loop
- state['generate'] = True
 
  try:
- for x in _generate(state, *args):
- yield x
  finally:
- state['generate'] = False
-
- def _generate(state, prompt, model, endseq, max_length,
- do_sample, top_k, top_p, temperature,
- add_stoptoken, copy_output):
 
  start = time.time()
  cnt = 0
 
  def stats():
  # Produces inline stats for generation speed
  if cnt == 0:
  return "\u2026 | ? sec/t"
  if cnt > time.time() - start:
@@ -70,23 +81,24 @@ def _generate(state, prompt, model, endseq, max_length,
  temperature = 1.0
 
  prompt2 = prompt
- output = ''
 
  # This render prompt dialog immediately and
  # don't wait to generator to return first result
  yield [state, prompt2, stats()]
 
  try:
- for out in client.generate(prompt,
- max_new_tokens=1,
- do_sample=do_sample,
- temperature=temperature,
- top_k=top_k,
- top_p=top_p,
- extra_stop_sequences=seq
- ):
-
- if not state['generate']:
  client.close_session()
  return
 
@@ -104,31 +116,53 @@ def _generate(state, prompt, model, endseq, max_length,
  print(traceback.format_exc())
  raise gr.Error(traceback.format_exc(limit=3))
 
  def stop(state):
  """Stops generating."""
  state.update({"generate": False})
  return state
 
  with gr.Blocks() as iface_prompt:
- gr.Markdown("""**Useful for testing raw prompts with zero, one or few-shot prompting.**""")
 
  with gr.Row():
- model = gr.Radio(["bloom", "bloomz", "bloom-7b1"], value='bloom', label="Use model")
 
  # Additional ending sequence, at which generation shoud stop
- endseq = gr.CheckboxGroup(["\\n", "</s>", "? (question mark)", ". (dot)"],
- value=["\\n", "</s>"], label='Extra end sequences')
 
  # Maximum length of inference session
- max_length = gr.Radio([64, 128, 256, 512, 1024, 2048], value=512, interactive=True, label="Max length")
 
  with gr.Row():
  with gr.Column():
  # Switch between sampling and greedy generation
  do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")
 
- # Should the app append stop sequence at the end of prompt or should it leave the prompt open?
- add_stoptoken = gr.Checkbox(value=True, interactive=True, label="Automatically add eos token to the prompt.")
 
  # Only one of top_k and top_p can be set. Requires "do_sample=True" to work.
  top_k = gr.Number(value=0, precision=0, interactive=True, label="top_k")
@@ -136,10 +170,12 @@ with gr.Blocks() as iface_prompt:
  # TODO num_beams
 
  # Generation temperature
- temperature = gr.Number(value=0.75, precision=2, interactive=True, label="Temperature")
 
- prompt = gr.Textbox(lines=3, label='Prompt', placeholder="Prompt Here...")
- state = gr.State({'generate': False})
 
  with gr.Row():
  button_generate = gr.Button("Generate")
@@ -148,22 +184,62 @@ with gr.Blocks() as iface_prompt:
  # Automatically copy the output at the end of prompt
  copy_output = gr.Checkbox(label="Output -> Prompt")
 
- output = gr.Textbox(lines=3, label='Output')
-
- inputs = [state, prompt, model, endseq, max_length, do_sample,
- top_k, top_p, temperature, add_stoptoken, copy_output]
- outputs = [state, prompt, output]
- button_generate.click(generate, inputs=inputs, outputs=outputs)
  button_stop.click(stop, inputs=[state], outputs=[state])
 
- examples = gr.Examples(inputs=[prompt, model, do_sample, top_k, top_p, temperature, add_stoptoken],
  examples=[
- ["The SQL command to extract all the users whose name starts with A is: ", "bloom-7b1", False, 0, 0, 1, False],
- ["The Spanish translation of thank you for your help is: ", "bloom-7b1", False, 0, 0, 1, False],
- ["A human talks to a powerful AI that follows the Human's instructions.\n"
- "AI is talkative, friendly, positive and provides detailed answers to any question.</s>\n"
- "Human: Hi!</s>\n"
- "AI: Hi! How can I help you?</s>\n"
- "Human: What's the capital of Portugal?</s>\n"
- "AI: ", "bloomz", True, 0, 0.9, 0.75, False]
- ])
+ import time
  import traceback
+
  import gradio as gr
+
  import chat_client
 
+ CHAT_URL = "ws://chat.petals.ml/api/v2/generate"
+ # CHAT_URL='ws://localhost:8000/api/v2/generate'
+
 
  def generate(state, *args):
  # Save that we're in generating loop
+ state["generate"] = True
 
  try:
+ yield from _generate(state, *args)
  finally:
+ state["generate"] = False
+
+
+ def _generate(
+ state,
+ prompt,
+ model,
+ endseq,
+ max_length,
+ do_sample,
+ top_k,
+ top_p,
+ temperature,
+ add_stoptoken,
+ copy_output,
+ ):
 
  start = time.time()
  cnt = 0
 
  def stats():
  # Produces inline stats for generation speed
+ # sec/t or t/sec depending on the speed
  if cnt == 0:
  return "\u2026 | ? sec/t"
  if cnt > time.time() - start:
 
  temperature = 1.0
 
  prompt2 = prompt
+ output = ""
 
  # This render prompt dialog immediately and
  # don't wait to generator to return first result
  yield [state, prompt2, stats()]
 
  try:
+ for out in client.generate(
+ prompt,
+ max_new_tokens=1,
+ do_sample=do_sample,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ extra_stop_sequences=seq,
+ ):
+
+ if not state["generate"]:
  client.close_session()
  return
 
  print(traceback.format_exc())
  raise gr.Error(traceback.format_exc(limit=3))
 
+
  def stop(state):
  """Stops generating."""
  state.update({"generate": False})
  return state
 
+
+ # ---------------------------------------------------------
+ # Defining Gradio layout
  with gr.Blocks() as iface_prompt:
+ gr.Markdown(
+ """**Useful for testing raw prompts with zero,
+ one or few-shot prompting.**"""
+ )
 
  with gr.Row():
+ model = gr.Radio(
+ ["bloom", "bloomz", "bloom-7b1"], value="bloom", label="Use model"
+ )
 
  # Additional ending sequence, at which generation shoud stop
+ endseq = gr.CheckboxGroup(
+ ["\\n", "</s>", "? (question mark)", ". (dot)"],
+ value=["\\n", "</s>"],
+ label="Extra end sequences",
+ )
 
  # Maximum length of inference session
+ max_length = gr.Radio(
+ [64, 128, 256, 512, 1024, 2048],
+ value=512,
+ interactive=True,
+ label="Max length",
+ )
 
  with gr.Row():
  with gr.Column():
  # Switch between sampling and greedy generation
  do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")
 
+ # Should the app append stop sequence at the end of prompt
+ # or should it leave the prompt open?
+ add_stoptoken = gr.Checkbox(
+ value=True,
+ interactive=True,
+ label="Automatically add eos token to the prompt.",
+ )
 
  # Only one of top_k and top_p can be set. Requires "do_sample=True" to work.
  top_k = gr.Number(value=0, precision=0, interactive=True, label="top_k")
 
  # TODO num_beams
 
  # Generation temperature
+ temperature = gr.Number(
+ value=0.75, precision=2, interactive=True, label="Temperature"
+ )
 
+ prompt = gr.Textbox(lines=3, label="Prompt", placeholder="Prompt Here...")
+ state = gr.State({"generate": False})
 
  with gr.Row():
  button_generate = gr.Button("Generate")
 
  # Automatically copy the output at the end of prompt
  copy_output = gr.Checkbox(label="Output -> Prompt")
 
+ output = gr.Textbox(lines=3, label="Output")
+
+ # Define button actions
+ button_generate.click(
+ generate,
+ inputs=[
+ state,
+ prompt,
+ model,
+ endseq,
+ max_length,
+ do_sample,
+ top_k,
+ top_p,
+ temperature,
+ add_stoptoken,
+ copy_output,
+ ],
+ outputs=[state, prompt, output],
+ )
  button_stop.click(stop, inputs=[state], outputs=[state])
 
+ examples = gr.Examples(
+ inputs=[prompt, model, do_sample, top_k, top_p, temperature, add_stoptoken],
  examples=[
+ [
+ "The SQL command to extract all the users whose name starts with A is: ",
+ "bloom-7b1",
+ False,
+ 0,
+ 0,
+ 1,
+ False,
+ ],
+ [
+ "The Spanish translation of thank you for your help is: ",
+ "bloom-7b1",
+ False,
+ 0,
+ 0,
+ 1,
+ False,
+ ],
+ [
+ "A human talks to a powerful AI that follows the Human's instructions.\n"
+ "AI is talkative, friendly, positive and provides detailed answers to any question.</s>\n"
+ "Human: Hi!</s>\n"
+ "AI: Hi! How can I help you?</s>\n"
+ "Human: What's the capital of Portugal?</s>\n"
+ "AI: ",
+ "bloomz",
+ True,
+ 0,
+ 0.9,
+ 0.75,
+ False,
+ ],
+ ],
+ )
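
Both chat.py and prompt.py above drive the same streaming client. A minimal sketch of that session flow follows, assuming the repo's chat_client module and the exact calls used in the diff (ModelClient, open_session, generate, close_session); whether each yielded item is a plain text chunk is an assumption, since that part of the loop is elided above.

# Hypothetical standalone use of chat_client, mirroring the calls in chat.py.
import chat_client

CHAT_URL = "ws://chat.petals.ml/api/v2/generate"

# Open a websocket inference session for one model and a fixed token budget.
client = chat_client.ModelClient(CHAT_URL)
client.open_session("bigscience/bloomz-petals", 128)

output = ""
# Stream roughly one token per iteration, with extra client-side stop sequences.
for out in client.generate(
    "Human: Hi!</s>\nAI:",
    max_new_tokens=1,
    do_sample=True,
    temperature=0.75,
    top_k=0,
    top_p=0.9,
    extra_stop_sequences=["</s>"],
):
    output += out  # assumption: each item is the newly generated text chunk
    print(out, end="", flush=True)

client.close_session()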