Tobias Bergmann committed on
Commit 5ac9a35 · 1 Parent(s): 6ba0c05

simple GUI

Files changed (1): app.py +22 -246
app.py CHANGED
@@ -26,256 +26,32 @@ pipe = Llama(
     model_path=model_path
 )
 
-# Setup the engine
-#pipe = Pipeline.create(
-#    task="text-generation",
-#    model_path=MODEL_ID,
-#    sequence_length=MAX_MAX_NEW_TOKENS,
-#    prompt_sequence_length=8,
-#    num_cores=8,
-#)
-
-
-def clear_and_save_textbox(message: str) -> Tuple[str, str]:
-    return "", message
-
-
-def display_input(
-    message: str, history: List[Tuple[str, str]]
-) -> List[Tuple[str, str]]:
-    history.append((message, ""))
-    return history
-
-
-def delete_prev_fn(history: List[Tuple[str, str]]) -> Tuple[List[Tuple[str, str]], str]:
-    try:
-        message, _ = history.pop()
-    except IndexError:
-        message = ""
-    return history, message or ""
-
-theme = gr.themes.Soft(
-    primary_hue="blue",
-    secondary_hue="green",
-)
-
-with gr.Blocks(theme=theme) as demo:
+def predict(message: str, history: List[List[str]], max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS):
+    if not message:
+        return "", history
+    prompt = message
+    output = pipe(
+        prompt,
+        max_tokens=max_new_tokens,
+        stop=["</s>"],
+    )
+    reply = output['choices'][0]['text']
+    history.append([message, reply])
+    return "", history
+
+with gr.Blocks() as demo:
     gr.Markdown(DESCRIPTION)
-
-    with gr.Group():
-        chatbot = gr.Chatbot(label="Chatbot")
-        with gr.Row():
-            textbox = gr.Textbox(
-                container=False,
-                show_label=False,
-                placeholder="Type a message...",
-                scale=10,
-            )
-            submit_button = gr.Button("Submit", variant="primary", scale=1, min_width=0)
-
+    chatbot = gr.Chatbot()
     with gr.Row():
-        retry_button = gr.Button("🔄 Retry", variant="secondary")
-        undo_button = gr.Button("↩️ Undo", variant="secondary")
-        clear_button = gr.Button("🗑️ Clear", variant="secondary")
-
-    saved_input = gr.State()
-
-    gr.Examples(
-        examples=[
-            "Write a story about sparse neurons.",
-            "Write a story about a summer camp.",
-            "Make a recipe for banana bread.",
-            "Write a cookbook for gluten-free snacks.",
-            "Write about the role of animation in video games."
-        ],
-        inputs=[textbox],
-    )
-
-    max_new_tokens = gr.Slider(
-        label="Max new tokens",
-        value=DEFAULT_MAX_NEW_TOKENS,
-        minimum=0,
-        maximum=MAX_MAX_NEW_TOKENS,
-        step=1,
-        interactive=True,
-        info="The maximum numbers of new tokens",
-    )
-    temperature = gr.Slider(
-        label="Temperature",
-        value=0.9,
-        minimum=0.05,
-        maximum=1.0,
-        step=0.05,
-        interactive=True,
-        info="Higher values produce more diverse outputs",
-    )
-    top_p = gr.Slider(
-        label="Top-p (nucleus) sampling",
-        value=0.40,
-        minimum=0.0,
-        maximum=1,
-        step=0.05,
-        interactive=True,
-        info="Higher values sample more low-probability tokens",
-    )
-    top_k = gr.Slider(
-        label="Top-k sampling",
-        value=20,
+        textbox = gr.Textbox(placeholder="Type here and press enter")
+        max_new_tokens_slider = gr.Slider(
             minimum=1,
-        maximum=100,
-        step=1,
-        interactive=True,
-        info="Sample from the top_k most likely tokens",
-    )
-    reptition_penalty = gr.Slider(
-        label="Repetition penalty",
-        value=1.2,
-        minimum=1.0,
-        maximum=2.0,
-        step=0.05,
-        interactive=True,
-        info="Penalize repeated tokens",
-    )
-
-    # Generation inference
-    def generate(
-        message,
-        history,
-        max_new_tokens: int,
-        temperature: float,
-        top_p: float,
-        top_k: int,
-        reptition_penalty: float,
-    ):
-        generation_config = {
-            "max_new_tokens": max_new_tokens,
-            "do_sample": True,
-            "temperature": temperature,
-            "top_p": top_p,
-            "top_k": top_k,
-            "reptition_penalty": reptition_penalty,
-        }
-
-        conversation = []
-        conversation.append({"role": "user", "content": message})
-
-        formatted_conversation = pipe.tokenizer.apply_chat_template(
-            conversation, tokenize=False, add_generation_prompt=True
-        )
-
-        inference = pipe(
-            sequences=formatted_conversation,
-            generation_config=generation_config,
-            streaming=True,
-        )
-
-        for token in inference:
-            history[-1][1] += token.generations[0].text
-            yield history
-
-        print(pipe.timer_manager)
-
-    # Hooking up all the buttons
-    textbox.submit(
-        fn=clear_and_save_textbox,
-        inputs=textbox,
-        outputs=[textbox, saved_input],
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=display_input,
-        inputs=[saved_input, chatbot],
-        outputs=chatbot,
-        api_name=False,
-        queue=False,
-    ).success(
-        generate,
-        inputs=[
-            saved_input,
-            chatbot,
-            max_new_tokens,
-            temperature,
-            top_p,
-            top_k,
-            reptition_penalty,
-        ],
-        outputs=[chatbot],
-        api_name=False,
-    )
-
-    submit_button.click(
-        fn=clear_and_save_textbox,
-        inputs=textbox,
-        outputs=[textbox, saved_input],
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=display_input,
-        inputs=[saved_input, chatbot],
-        outputs=chatbot,
-        api_name=False,
-        queue=False,
-    ).success(
-        generate,
-        inputs=[
-            saved_input,
-            chatbot,
-            max_new_tokens,
-            temperature,
-            top_p,
-            top_k,
-            reptition_penalty,
-        ],
-        outputs=[chatbot],
-        api_name=False,
-    )
-
-    retry_button.click(
-        fn=delete_prev_fn,
-        inputs=chatbot,
-        outputs=[chatbot, saved_input],
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=display_input,
-        inputs=[saved_input, chatbot],
-        outputs=chatbot,
-        api_name=False,
-        queue=False,
-    ).then(
-        generate,
-        inputs=[
-            saved_input,
-            chatbot,
-            max_new_tokens,
-            temperature,
-            top_p,
-            top_k,
-            reptition_penalty,
-        ],
-        outputs=[chatbot],
-        api_name=False,
-    )
-
-    undo_button.click(
-        fn=delete_prev_fn,
-        inputs=chatbot,
-        outputs=[chatbot, saved_input],
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=lambda x: x,
-        inputs=[saved_input],
-        outputs=textbox,
-        api_name=False,
-        queue=False,
+            maximum=MAX_MAX_NEW_TOKENS,
+            value=DEFAULT_MAX_NEW_TOKENS,
+            label="Max New Tokens",
         )
-
-    clear_button.click(
-        fn=lambda: ([], ""),
-        outputs=[chatbot, saved_input],
-        queue=False,
-        api_name=False,
-    )
+    textbox.submit(predict, [textbox, chatbot, max_new_tokens_slider], [textbox, chatbot])
+
 
 demo.queue().launch(share=True)
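
Note: the hunk begins at app.py line 26, so the imports and constants the new code relies on (gradio, llama_cpp.Llama, List, DESCRIPTION, MAX_MAX_NEW_TOKENS, DEFAULT_MAX_NEW_TOKENS, model_path) are defined earlier in the file and are not shown in this diff. A minimal sketch of what such a preamble could look like; every value below is an illustrative assumption, not the repository's actual configuration:

    from typing import List

    import gradio as gr
    from llama_cpp import Llama

    # Illustrative placeholders only; the real definitions live above the hunk.
    DESCRIPTION = "# Simple llama.cpp chat demo"  # hypothetical description markdown
    MAX_MAX_NEW_TOKENS = 2048                     # assumed slider upper bound
    DEFAULT_MAX_NEW_TOKENS = 256                  # assumed default token budget
    model_path = "./model.gguf"                   # assumed path to a local GGUF model

    # This call is shown as context in the hunk header above.
    pipe = Llama(
        model_path=model_path
    )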
 
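The removed generate() streamed tokens into the chat window as they were produced, while the new predict() returns the whole completion in one step. If streaming is wanted back, llama-cpp-python supports it via stream=True, which makes the call yield partial completions chunk by chunk. A sketch (not part of this commit), reusing pipe and DEFAULT_MAX_NEW_TOKENS from above:

    def predict_stream(message: str, history: List[List[str]], max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS):
        if not message:
            yield "", history
            return
        # Show the user turn immediately; fill in the reply as chunks arrive.
        history.append([message, ""])
        for chunk in pipe(message, max_tokens=max_new_tokens, stop=["</s>"], stream=True):
            history[-1][1] += chunk["choices"][0]["text"]
            yield "", history

Gradio treats generator callbacks as streaming updates, so swapping predict for predict_stream in the textbox.submit(...) wiring would be the only change needed.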