gabrielclark3330 committed
Commit d4ab11d (1 parent: 715d92e)

Streaming output
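
This commit moves the demo to streamed output: the shared `generate_response(chat_history, max_new_tokens, model, tokenizer)` helper now runs `model.generate` on a background thread with a `TextIteratorStreamer` and yields the partial assistant reply as tokens are decoded, and the `bot_response_*` handlers become generators that yield the updated chat history so the `gr.Chatbot` fills in incrementally. A minimal, self-contained sketch of the same pattern (the prompt and the usage loop at the bottom are illustrative, not code from app.py):

```python
# Streaming pattern used by this commit: generate() runs on a worker thread
# while the caller consumes decoded text chunks from a TextIteratorStreamer.
import threading

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "Zyphra/Zamba2-7B-instruct"  # any causal LM with a chat template works here
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16)

def stream_reply(messages, max_new_tokens=500):
    # Render the chat history with the model's chat template and tokenize it.
    prompt = tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    # skip_prompt=True makes the streamer emit only newly generated text.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = threading.Thread(
        target=model.generate,
        kwargs=dict(**inputs, max_new_tokens=max_new_tokens, streamer=streamer),
    )
    thread.start()

    partial = ""
    for chunk in streamer:  # blocks until the next decoded chunk is available
        partial += chunk
        yield partial       # callers see the reply grow as tokens arrive
    thread.join()

# Illustrative usage: print the reply as it streams in.
for text in stream_reply([{"role": "user", "content": "Hello!"}]):
    print(text)
```

Because the handlers now yield instead of returning, their outputs are wired directly to the Chatbot component, and the app keeps `demo.queue()` so Gradio can deliver the intermediate updates.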

Files changed (1)
  1. app.py +97 -151
app.py CHANGED
@@ -1,3 +1,4 @@
+'''
 import os
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -78,43 +79,6 @@ def generate_response(chat_history, max_new_tokens, model, tokenizer):
 with gr.Blocks() as demo:
     gr.Markdown("# Zamba2 Model Selector")
     with gr.Tabs():
-        with gr.TabItem("2.7B Instruct Model"):
-            gr.Markdown("### Zamba2-2.7B Instruct Model")
-            with gr.Column():
-                chat_history_2_7B_instruct = gr.State([])
-                chatbot_2_7B_instruct = gr.Chatbot()
-                message_2_7B_instruct = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message")
-                with gr.Accordion("Generation Parameters", open=False):
-                    max_new_tokens_2_7B_instruct = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
-                    # temperature_2_7B_instruct = gr.Slider(0.1, 1.5, step=0.1, value=0.2, label="Temperature")
-                    # top_k_2_7B_instruct = gr.Slider(1, 100, step=1, value=50, label="Top K")
-                    # top_p_2_7B_instruct = gr.Slider(0.1, 1.0, step=0.1, value=1.0, label="Top P")
-                    # repetition_penalty_2_7B_instruct = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty")
-                    # num_beams_2_7B_instruct = gr.Slider(1, 10, step=1, value=1, label="Number of Beams")
-                    # length_penalty_2_7B_instruct = gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty")
-
-                def user_message_2_7B_instruct(message, chat_history):
-                    chat_history = chat_history + [[message, None]]
-                    return gr.update(value=""), chat_history, chat_history
-
-                def bot_response_2_7B_instruct(chat_history, max_new_tokens):
-                    response = generate_response(chat_history, max_new_tokens, model_2_7B_instruct, tokenizer_2_7B_instruct)
-                    chat_history[-1][1] = response
-                    return chat_history, chat_history
-
-                send_button_2_7B_instruct = gr.Button("Send")
-                send_button_2_7B_instruct.click(
-                    fn=user_message_2_7B_instruct,
-                    inputs=[message_2_7B_instruct, chat_history_2_7B_instruct],
-                    outputs=[message_2_7B_instruct, chat_history_2_7B_instruct, chatbot_2_7B_instruct]
-                ).then(
-                    fn=bot_response_2_7B_instruct,
-                    inputs=[
-                        chat_history_2_7B_instruct,
-                        max_new_tokens_2_7B_instruct
-                    ],
-                    outputs=[chat_history_2_7B_instruct, chatbot_2_7B_instruct]
-                )
         with gr.TabItem("7B Instruct Model"):
             gr.Markdown("### Zamba2-7B Instruct Model")
             with gr.Column():
@@ -152,19 +116,58 @@ with gr.Blocks() as demo:
                     ],
                     outputs=[chat_history_7B_instruct, chatbot_7B_instruct]
                 )
+        with gr.TabItem("2.7B Instruct Model"):
+            gr.Markdown("### Zamba2-2.7B Instruct Model")
+            with gr.Column():
+                chat_history_2_7B_instruct = gr.State([])
+                chatbot_2_7B_instruct = gr.Chatbot()
+                message_2_7B_instruct = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message")
+                with gr.Accordion("Generation Parameters", open=False):
+                    max_new_tokens_2_7B_instruct = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
+                    # temperature_2_7B_instruct = gr.Slider(0.1, 1.5, step=0.1, value=0.2, label="Temperature")
+                    # top_k_2_7B_instruct = gr.Slider(1, 100, step=1, value=50, label="Top K")
+                    # top_p_2_7B_instruct = gr.Slider(0.1, 1.0, step=0.1, value=1.0, label="Top P")
+                    # repetition_penalty_2_7B_instruct = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty")
+                    # num_beams_2_7B_instruct = gr.Slider(1, 10, step=1, value=1, label="Number of Beams")
+                    # length_penalty_2_7B_instruct = gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty")
+
+                def user_message_2_7B_instruct(message, chat_history):
+                    chat_history = chat_history + [[message, None]]
+                    return gr.update(value=""), chat_history, chat_history
+
+                def bot_response_2_7B_instruct(chat_history, max_new_tokens):
+                    response = generate_response(chat_history, max_new_tokens, model_2_7B_instruct, tokenizer_2_7B_instruct)
+                    chat_history[-1][1] = response
+                    return chat_history, chat_history
+
+                send_button_2_7B_instruct = gr.Button("Send")
+                send_button_2_7B_instruct.click(
+                    fn=user_message_2_7B_instruct,
+                    inputs=[message_2_7B_instruct, chat_history_2_7B_instruct],
+                    outputs=[message_2_7B_instruct, chat_history_2_7B_instruct, chatbot_2_7B_instruct]
+                ).then(
+                    fn=bot_response_2_7B_instruct,
+                    inputs=[
+                        chat_history_2_7B_instruct,
+                        max_new_tokens_2_7B_instruct
+                    ],
+                    outputs=[chat_history_2_7B_instruct, chatbot_2_7B_instruct]
+                )
 
 if __name__ == "__main__":
     demo.queue().launch(max_threads=1)
-
 '''
 
 import os
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 import torch
+import threading
+import re
 
 model_name_2_7B_instruct = "Zyphra/Zamba2-2.7B-instruct"
 model_name_7B_instruct = "Zyphra/Zamba2-7B-instruct"
+max_context_length = 4096
 
 tokenizer_2_7B_instruct = AutoTokenizer.from_pretrained(model_name_2_7B_instruct)
 model_2_7B_instruct = AutoModelForCausalLM.from_pretrained(
@@ -176,131 +179,47 @@ model_7B_instruct = AutoModelForCausalLM.from_pretrained(
     model_name_7B_instruct, device_map="cuda", torch_dtype=torch.bfloat16
 )
 
-def extract_assistant_response(generated_text):
-    assistant_token = '<|im_start|> assistant'
-    end_token = '<|im_end|>'
-    start_idx = generated_text.rfind(assistant_token)
-    if start_idx == -1:
-        # Assistant token not found
-        return generated_text.strip()
-    start_idx += len(assistant_token)
-    end_idx = generated_text.find(end_token, start_idx)
-    if end_idx == -1:
-        # End token not found, return from start_idx to end
-        return generated_text[start_idx:].strip()
-    else:
-        return generated_text[start_idx:end_idx].strip()
-
-def generate_response_2_7B_instruct(chat_history, max_new_tokens):
+def generate_response(chat_history, max_new_tokens, model, tokenizer):
     sample = []
     for turn in chat_history:
         if turn[0]:
             sample.append({'role': 'user', 'content': turn[0]})
         if turn[1]:
             sample.append({'role': 'assistant', 'content': turn[1]})
-    chat_sample = tokenizer_2_7B_instruct.apply_chat_template(sample, tokenize=False)
-    input_ids = tokenizer_2_7B_instruct(chat_sample, return_tensors='pt', add_special_tokens=False).input_ids.to(model_2_7B_instruct.device)
+    chat_sample = tokenizer.apply_chat_template(sample, tokenize=False)
+    input_ids = tokenizer(chat_sample, return_tensors='pt', add_special_tokens=False).to(model.device)
 
-    # Handle context length limit
-    max_context_length = 4096
     max_new_tokens = int(max_new_tokens)
    max_input_length = max_context_length - max_new_tokens
-    if input_ids.size(1) > max_input_length:
-        input_ids = input_ids[:, -max_input_length:] # Truncate from the left (oldest tokens)
+    if input_ids['input_ids'].size(1) > max_input_length:
+        input_ids['input_ids'] = input_ids['input_ids'][:, -max_input_length:]
+        if 'attention_mask' in input_ids:
+            input_ids['attention_mask'] = input_ids['attention_mask'][:, -max_input_length:]
 
-    with torch.no_grad():
-        outputs = model_2_7B_instruct.generate(
-            input_ids=input_ids,
-            max_new_tokens=max_new_tokens,
-            return_dict_in_generate=False,
-            output_scores=False,
-            use_cache=True,
-            num_beams=1,
-            do_sample=False
-        )
-
-    generated_text = tokenizer_2_7B_instruct.decode(outputs[0])
-    assistant_response = extract_assistant_response(generated_text)
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+    generation_kwargs = dict(**input_ids, max_new_tokens=int(max_new_tokens), streamer=streamer)
 
-    del input_ids
-    del outputs
-    torch.cuda.empty_cache()
-    return assistant_response
+    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
 
-def generate_response_7B_instruct(chat_history, max_new_tokens):
-    sample = []
-    for turn in chat_history:
-        if turn[0]:
-            sample.append({'role': 'user', 'content': turn[0]})
-        if turn[1]:
-            sample.append({'role': 'assistant', 'content': turn[1]})
-    chat_sample = tokenizer_7B_instruct.apply_chat_template(sample, tokenize=False)
-    input_ids = tokenizer_7B_instruct(chat_sample, return_tensors='pt', add_special_tokens=False).input_ids.to(model_7B_instruct.device)
+    assistant_response = ""
 
-    # Handle context length limit
-    max_context_length = 4096
-    max_new_tokens = int(max_new_tokens)
-    max_input_length = max_context_length - max_new_tokens
-    if input_ids.size(1) > max_input_length:
-        input_ids = input_ids[:, -max_input_length:] # Truncate from the left (oldest tokens)
-
-    with torch.no_grad():
-        outputs = model_7B_instruct.generate(
-            input_ids=input_ids,
-            max_new_tokens=max_new_tokens,
-            return_dict_in_generate=False,
-            output_scores=False,
-            use_cache=True,
-            num_beams=1,
-            do_sample=False
-        )
-
-    generated_text = tokenizer_7B_instruct.decode(outputs[0])
-    assistant_response = extract_assistant_response(generated_text)
+    for new_text in streamer:
+        new_text = re.sub(r'^\s*(?i:assistant)[:\s]*', '', new_text)
+        assistant_response += new_text
+        yield assistant_response
 
+    thread.join()
     del input_ids
-    del outputs
     torch.cuda.empty_cache()
-    return assistant_response
 
 with gr.Blocks() as demo:
     gr.Markdown("# Zamba2 Model Selector")
     with gr.Tabs():
-        with gr.TabItem("2.7B Instruct Model"):
-            gr.Markdown("### Zamba2-2.7B Instruct Model")
-            with gr.Column():
-                chat_history_2_7B_instruct = gr.State([])
-                chatbot_2_7B_instruct = gr.Chatbot()
-                message_2_7B_instruct = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message")
-                with gr.Accordion("Generation Parameters", open=False):
-                    max_new_tokens_2_7B_instruct = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
-
-                def user_message_2_7B_instruct(message, chat_history):
-                    chat_history = chat_history + [[message, None]]
-                    return gr.update(value=""), chat_history, chat_history
-
-                def bot_response_2_7B_instruct(chat_history, max_new_tokens):
-                    response = generate_response_2_7B_instruct(chat_history, max_new_tokens)
-                    chat_history[-1][1] = response
-                    return chat_history, chat_history
-
-                send_button_2_7B_instruct = gr.Button("Send")
-                send_button_2_7B_instruct.click(
-                    fn=user_message_2_7B_instruct,
-                    inputs=[message_2_7B_instruct, chat_history_2_7B_instruct],
-                    outputs=[message_2_7B_instruct, chat_history_2_7B_instruct, chatbot_2_7B_instruct]
-                ).then(
-                    fn=bot_response_2_7B_instruct,
-                    inputs=[
-                        chat_history_2_7B_instruct,
-                        max_new_tokens_2_7B_instruct
-                    ],
-                    outputs=[chat_history_2_7B_instruct, chatbot_2_7B_instruct]
-                )
         with gr.TabItem("7B Instruct Model"):
             gr.Markdown("### Zamba2-7B Instruct Model")
             with gr.Column():
-                chat_history_7B_instruct = gr.State([])
+                chat_history_7B_instruct = gr.State([])
                 chatbot_7B_instruct = gr.Chatbot()
                 message_7B_instruct = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message")
                 with gr.Accordion("Generation Parameters", open=False):
@@ -311,9 +230,10 @@ with gr.Blocks() as demo:
                     return gr.update(value=""), chat_history, chat_history
 
                 def bot_response_7B_instruct(chat_history, max_new_tokens):
-                    response = generate_response_7B_instruct(chat_history, max_new_tokens)
-                    chat_history[-1][1] = response
-                    return chat_history, chat_history
+                    assistant_response_generator = generate_response(chat_history, max_new_tokens, model_7B_instruct, tokenizer_7B_instruct)
+                    for assistant_response in assistant_response_generator:
+                        chat_history[-1][1] = assistant_response
+                        yield chat_history
 
                 send_button_7B_instruct = gr.Button("Send")
                 send_button_7B_instruct.click(
@@ -322,13 +242,39 @@ with gr.Blocks() as demo:
                     outputs=[message_7B_instruct, chat_history_7B_instruct, chatbot_7B_instruct]
                 ).then(
                     fn=bot_response_7B_instruct,
-                    inputs=[
-                        chat_history_7B_instruct,
-                        max_new_tokens_7B_instruct
-                    ],
-                    outputs=[chat_history_7B_instruct, chatbot_7B_instruct]
+                    inputs=[chat_history_7B_instruct, max_new_tokens_7B_instruct],
+                    outputs=chatbot_7B_instruct,
+                )
+
+        with gr.TabItem("2.7B Instruct Model"):
+            gr.Markdown("### Zamba2-2.7B Instruct Model")
+            with gr.Column():
+                chat_history_2_7B_instruct = gr.State([])
+                chatbot_2_7B_instruct = gr.Chatbot()
+                message_2_7B_instruct = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message")
+                with gr.Accordion("Generation Parameters", open=False):
+                    max_new_tokens_2_7B_instruct = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
+
+                def user_message_2_7B_instruct(message, chat_history):
+                    chat_history = chat_history + [[message, None]]
+                    return gr.update(value=""), chat_history, chat_history
+
+                def bot_response_2_7B_instruct(chat_history, max_new_tokens):
+                    assistant_response_generator = generate_response(chat_history, max_new_tokens, model_2_7B_instruct, tokenizer_2_7B_instruct)
+                    for assistant_response in assistant_response_generator:
+                        chat_history[-1][1] = assistant_response
+                        yield chat_history
+
+                send_button_2_7B_instruct = gr.Button("Send")
+                send_button_2_7B_instruct.click(
+                    fn=user_message_2_7B_instruct,
+                    inputs=[message_2_7B_instruct, chat_history_2_7B_instruct],
+                    outputs=[message_2_7B_instruct, chat_history_2_7B_instruct, chatbot_2_7B_instruct]
+                ).then(
+                    fn=bot_response_2_7B_instruct,
+                    inputs=[chat_history_2_7B_instruct, max_new_tokens_2_7B_instruct],
+                    outputs=chatbot_2_7B_instruct,
+                )
                 )
 
 if __name__ == "__main__":
-    demo.queue().launch()
-'''
+    demo.queue().launch(max_threads=1)
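
For reference on why the reworked callbacks can simply `yield chat_history`: Gradio treats a generator event handler as a streaming handler and pushes every yielded value to the output component, which is what makes the Chatbot repaint token by token. A tiny stand-alone illustration of that mechanism (the echo bot below is invented, not taken from app.py):

```python
# Gradio streams any event handler written as a generator: each yield repaints the output.
import time

import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Type a message...")

    def user(message, history):
        # Append the user turn with an empty bot slot and clear the textbox.
        return "", history + [[message, None]]

    def bot(history):
        history[-1][1] = ""
        for ch in "echo: " + history[-1][0]:
            history[-1][1] += ch
            time.sleep(0.02)
            yield history  # each yield pushes the partial reply to the Chatbot

    msg.submit(user, [msg, chatbot], [msg, chatbot]).then(bot, chatbot, chatbot)

if __name__ == "__main__":
    demo.queue().launch()
```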