ToonTownTommy committed
Commit 0b882ba · verified · 1 Parent(s): 379a016

Create app.py

Files changed (1): app.py +418 -0
app.py ADDED
@@ -0,0 +1,418 @@
import os
import argparse
from typing import Iterator

import gradio as gr
from dotenv import load_dotenv

# Note: distutils was removed in Python 3.12; on newer interpreters,
# strtobool needs a small local replacement.
from distutils.util import strtobool

from llama2_wrapper import LLAMA2_WRAPPER

import logging

from prompts.utils import PromtsContainer


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="", help="model path")
    parser.add_argument(
        "--backend_type",
        type=str,
        default="",
        help="Backend options: llama.cpp, gptq, transformers",
    )
    # store_true instead of type=bool: argparse's bool() treats any
    # non-empty string (including "False") as True.
    parser.add_argument(
        "--load_in_8bit",
        action="store_true",
        help="Whether to load the model with bitsandbytes 8-bit quantization.",
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="Whether to create a public Gradio share link.",
    )
    args = parser.parse_args()

    load_dotenv()

    DEFAULT_SYSTEM_PROMPT = os.getenv("DEFAULT_SYSTEM_PROMPT", "")
    MAX_MAX_NEW_TOKENS = int(os.getenv("MAX_MAX_NEW_TOKENS", 2048))
    DEFAULT_MAX_NEW_TOKENS = int(os.getenv("DEFAULT_MAX_NEW_TOKENS", 1024))
    MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", 4000))

    MODEL_PATH = os.getenv("MODEL_PATH")
    assert MODEL_PATH is not None, f"MODEL_PATH is required, got: {MODEL_PATH}"
    BACKEND_TYPE = os.getenv("BACKEND_TYPE")
    assert BACKEND_TYPE is not None, f"BACKEND_TYPE is required, got: {BACKEND_TYPE}"

    LOAD_IN_8BIT = bool(strtobool(os.getenv("LOAD_IN_8BIT", "True")))

    # Command-line arguments take precedence over values loaded from .env.
    if args.model_path != "":
        MODEL_PATH = args.model_path
    if args.backend_type != "":
        BACKEND_TYPE = args.backend_type
    if args.load_in_8bit:
        LOAD_IN_8BIT = True

    llama2_wrapper = LLAMA2_WRAPPER(
        model_path=MODEL_PATH,
        backend_type=BACKEND_TYPE,
        max_tokens=MAX_INPUT_TOKEN_LENGTH,
        load_in_8bit=LOAD_IN_8BIT,
        # verbose=True,
    )

    DESCRIPTION = """
# llama2-webui
"""
    DESCRIPTION2 = """
- Supported models: [Llama-2-7b](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML)/[13b](https://huggingface.co/llamaste/Llama-2-13b-chat-hf)/[70b](https://huggingface.co/llamaste/Llama-2-70b-chat-hf), [Llama-2-GPTQ](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GPTQ), [Llama-2-GGML](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML), [CodeLlama](https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-GPTQ) ...
- Supported model backends: [transformers](https://github.com/huggingface/transformers), [bitsandbytes(8-bit inference)](https://github.com/TimDettmers/bitsandbytes), [AutoGPTQ(4-bit inference)](https://github.com/PanQiWei/AutoGPTQ), [llama.cpp](https://github.com/ggerganov/llama.cpp)
"""

    def clear_and_save_textbox(message: str) -> tuple[str, str]:
        return "", message

    def save_textbox_for_prompt(message: str) -> str:
        logging.info("start save_textbox_for_prompt")
        message = convert_summary_to_prompt(message)
        return message

    def display_input(
        message: str, history: list[tuple[str, str]]
    ) -> list[tuple[str, str]]:
        history.append((message, ""))
        return history

    def delete_prev_fn(
        history: list[tuple[str, str]]
    ) -> tuple[list[tuple[str, str]], str]:
        try:
            message, _ = history.pop()
        except IndexError:
            message = ""
        return history, message or ""

    def generate(
        message: str,
        history_with_input: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
    ) -> Iterator[list[tuple[str, str]]]:
        if max_new_tokens > MAX_MAX_NEW_TOKENS:
            raise ValueError(
                f"max_new_tokens must not exceed {MAX_MAX_NEW_TOKENS}, got {max_new_tokens}."
            )
        try:
            history = history_with_input[:-1]
            generator = llama2_wrapper.run(
                message,
                history,
                system_prompt,
                max_new_tokens,
                temperature,
                top_p,
                top_k,
            )
            # Stream partial completions to the chatbot as the generator yields them.
            try:
                first_response = next(generator)
                yield history + [(message, first_response)]
            except StopIteration:
                yield history + [(message, "")]
            for response in generator:
                yield history + [(message, response)]
        except Exception as e:
            logging.exception(e)

    def check_input_token_length(
        message: str, chat_history: list[tuple[str, str]], system_prompt: str
    ) -> None:
        input_token_length = llama2_wrapper.get_input_token_length(
            message, chat_history, system_prompt
        )
        if input_token_length > MAX_INPUT_TOKEN_LENGTH:
            raise gr.Error(
                f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again."
            )

    prompts_container = PromtsContainer()
    prompts = prompts_container.get_prompts_tab_dict()
    default_prompts_checkbox = False
    default_advanced_checkbox = False

    def convert_summary_to_prompt(summary):
        return prompts_container.get_prompt_by_summary(summary)

    def two_columns_list(tab_data, chatbot):
        result = []
        # Lay the prompt cards out two per row.
        for i in range((len(tab_data) + 1) // 2):
            row = gr.Row()
            with row:
                for j in range(2):
                    index = 2 * i + j
                    if index >= len(tab_data):
                        break
                    item = tab_data[index]
                    with gr.Group():
                        gr.HTML(
                            f'<p style="color: black; font-weight: bold;">{item["act"]}</p>'
                        )
                        prompt_text = gr.Button(
                            label="",
                            value=f"{item['summary']}",
                            size="sm",
                            elem_classes="text-left-aligned",
                        )
                        prompt_text.click(
                            fn=save_textbox_for_prompt,
                            inputs=prompt_text,
                            outputs=saved_input,
                            api_name=False,
                            queue=True,
                        ).then(
                            fn=display_input,
                            inputs=[saved_input, chatbot],
                            outputs=chatbot,
                            api_name=False,
                            queue=True,
                        ).then(
                            fn=check_input_token_length,
                            inputs=[saved_input, chatbot, system_prompt],
                            api_name=False,
                            queue=False,
                        ).success(
                            fn=generate,
                            inputs=[
                                saved_input,
                                chatbot,
                                system_prompt,
                                max_new_tokens,
                                temperature,
                                top_p,
                                top_k,
                            ],
                            outputs=chatbot,
                            api_name=False,
                        )
            result.append(row)
        return result

    # The #component-N selectors target Gradio's auto-generated element ids
    # and are sensitive to layout and version changes.
    CSS = """
    .contain { display: flex; flex-direction: column;}
    #component-0 #component-1 #component-2 #component-4 #component-5 { height:71vh !important; }
    #component-0 #component-1 #component-24 > div:nth-child(2) { height:80vh !important; overflow-y:auto }
    .text-left-aligned {text-align: left !important; font-size: 16px;}
    """
    with gr.Blocks(css=CSS) as demo:
        with gr.Row(equal_height=True):
            with gr.Column(scale=2):
                gr.Markdown(DESCRIPTION)
                with gr.Group():
                    chatbot = gr.Chatbot(label="Chatbot")
                    with gr.Row():
                        textbox = gr.Textbox(
                            container=False,
                            show_label=False,
                            placeholder="Type a message...",
                            scale=10,
                        )
                        submit_button = gr.Button(
                            "Submit", variant="primary", scale=1, min_width=0
                        )
                with gr.Row():
                    retry_button = gr.Button("🔄 Retry", variant="secondary")
                    undo_button = gr.Button("↩️ Undo", variant="secondary")
                    clear_button = gr.Button("🗑️ Clear", variant="secondary")

                saved_input = gr.State()
                with gr.Row():
                    advanced_checkbox = gr.Checkbox(
                        label="Advanced",
                        value=default_advanced_checkbox,
                        container=False,
                        elem_classes="min_check",
                    )
                    prompts_checkbox = gr.Checkbox(
                        label="Prompts",
                        value=default_prompts_checkbox,
                        container=False,
                        elem_classes="min_check",
                    )
                with gr.Column(visible=default_advanced_checkbox) as advanced_column:
                    system_prompt = gr.Textbox(
                        label="System prompt", value=DEFAULT_SYSTEM_PROMPT, lines=6
                    )
                    max_new_tokens = gr.Slider(
                        label="Max new tokens",
                        minimum=1,
                        maximum=MAX_MAX_NEW_TOKENS,
                        step=1,
                        value=DEFAULT_MAX_NEW_TOKENS,
                    )
                    temperature = gr.Slider(
                        label="Temperature",
                        minimum=0.1,
                        maximum=4.0,
                        step=0.1,
                        value=1.0,
                    )
                    top_p = gr.Slider(
                        label="Top-p (nucleus sampling)",
                        minimum=0.05,
                        maximum=1.0,
                        step=0.05,
                        value=0.95,
                    )
                    top_k = gr.Slider(
                        label="Top-k",
                        minimum=1,
                        maximum=1000,
                        step=1,
                        value=50,
                    )
            with gr.Column(scale=1, visible=default_prompts_checkbox) as prompt_column:
                gr.HTML(
                    '<p style="color: green; font-weight: bold;font-size: 16px;">\N{FOUR LEAF CLOVER} prompts</p>'
                )
                for k, v in prompts.items():
                    with gr.Tab(k, scroll_to_output=True):
                        lst = two_columns_list(v, chatbot)
        prompts_checkbox.change(
            lambda x: gr.update(visible=x),
            prompts_checkbox,
            prompt_column,
            queue=False,
        )
        advanced_checkbox.change(
            lambda x: gr.update(visible=x),
            advanced_checkbox,
            advanced_column,
            queue=False,
        )

        # Enter key: stash the message, echo it into the chat,
        # validate the token length, then stream the reply.
        textbox.submit(
            fn=clear_and_save_textbox,
            inputs=textbox,
            outputs=[textbox, saved_input],
            api_name=False,
            queue=False,
        ).then(
            fn=display_input,
            inputs=[saved_input, chatbot],
            outputs=chatbot,
            api_name=False,
            queue=False,
        ).then(
            fn=check_input_token_length,
            inputs=[saved_input, chatbot, system_prompt],
            api_name=False,
            queue=False,
        ).success(
            fn=generate,
            inputs=[
                saved_input,
                chatbot,
                system_prompt,
                max_new_tokens,
                temperature,
                top_p,
                top_k,
            ],
            outputs=chatbot,
            api_name=False,
        )

        button_event_preprocess = (
            submit_button.click(
                fn=clear_and_save_textbox,
                inputs=textbox,
                outputs=[textbox, saved_input],
                api_name=False,
                queue=False,
            )
            .then(
                fn=display_input,
                inputs=[saved_input, chatbot],
                outputs=chatbot,
                api_name=False,
                queue=False,
            )
            .then(
                fn=check_input_token_length,
                inputs=[saved_input, chatbot, system_prompt],
                api_name=False,
                queue=False,
            )
            .success(
                fn=generate,
                inputs=[
                    saved_input,
                    chatbot,
                    system_prompt,
                    max_new_tokens,
                    temperature,
                    top_p,
                    top_k,
                ],
                outputs=chatbot,
                api_name=False,
            )
        )

        retry_button.click(
            fn=delete_prev_fn,
            inputs=chatbot,
            outputs=[chatbot, saved_input],
            api_name=False,
            queue=False,
        ).then(
            fn=display_input,
            inputs=[saved_input, chatbot],
            outputs=chatbot,
            api_name=False,
            queue=False,
        ).then(
            fn=generate,
            inputs=[
                saved_input,
                chatbot,
                system_prompt,
                max_new_tokens,
                temperature,
                top_p,
                top_k,
            ],
            outputs=chatbot,
            api_name=False,
        )

        undo_button.click(
            fn=delete_prev_fn,
            inputs=chatbot,
            outputs=[chatbot, saved_input],
            api_name=False,
            queue=False,
        ).then(
            fn=lambda x: x,
            inputs=[saved_input],
            outputs=textbox,
            api_name=False,
            queue=False,
        )

        clear_button.click(
            fn=lambda: ([], ""),
            outputs=[chatbot, saved_input],
            queue=False,
            api_name=False,
        )

    demo.queue(max_size=20).launch(share=args.share)


if __name__ == "__main__":
    main()
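
For reference, a minimal way to configure and launch this app. The .env keys and CLI flags below come from the code above (load_dotenv/os.getenv and the argparse setup); the model filename and path are placeholders, not values from this commit:

# .env (example; MODEL_PATH value is a placeholder)
MODEL_PATH=./models/llama-2-7b-chat.ggmlv3.q4_0.bin
BACKEND_TYPE=llama.cpp
LOAD_IN_8BIT=False

# Equivalent CLI overrides, which take precedence over .env:
python app.py --model_path ./models/llama-2-7b-chat.ggmlv3.q4_0.bin --backend_type llama.cpp --share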