s3nh committed on
Commit
410dbc6
1 Parent(s): aa8103d

Create app.py

Files changed (1)
  1. app.py +339 -0
app.py ADDED
@@ -0,0 +1,339 @@
+ import os
+ import platform
+ import random
+ import time
+ from dataclasses import asdict, dataclass
+ from pathlib import Path
+
+ # from types import SimpleNamespace
+ import gradio as gr
+ import psutil
+ from about_time import about_time
+ from ctransformers import AutoModelForCausalLM
+ from dl_hf_model import dl_hf_model
+ from loguru import logger
+
+
+ # Repo id and GGML model file on the Hugging Face Hub.
+ model_repo = "s3nh/mamba-gpt-3b-GGML"
+ model_file = "mamba-gpt-3b.ggmlv3.q8_0.bin"
+ model_loc = model_file  # used below for display names
+
+ # Minimal instruction-style prompt template (assumed; adjust to the model's expected format).
+ prompt_template = """### Instruction:
+ {question}
+
+ ### Response:
+ """
+
+ # Load the GGML model via ctransformers (repo id and model file passed separately).
+ LLM = AutoModelForCausalLM.from_pretrained(
+     model_repo,
+     model_file=model_file,
+     model_type="llama",
+ )
+
+
+ @dataclass
+ class GenerationConfig:
+     temperature: float = 0.7
+     top_k: int = 50
+     top_p: float = 0.9
+     repetition_penalty: float = 1.0
+     max_new_tokens: int = 512
+     seed: int = 42
+     reset: bool = False
+     stream: bool = True
+     # threads: int = cpu_count
+     # stop: list[str] = field(default_factory=lambda: [stop_string])
+
+
+ def generate(
+     question: str,
+     llm=LLM,
+     config: GenerationConfig = GenerationConfig(),
+ ):
+     """Run model inference; returns a generator when streaming is enabled."""
+     # _ = prompt_template.format(question=question)
+     # print(_)
+
+     prompt = prompt_template.format(question=question)
+
+     return llm(
+         prompt,
+         **asdict(config),
+     )
+
+
+ logger.debug(f"{asdict(GenerationConfig())=}")
+
+
+ def user(user_message, history):
+     # return user_message, history + [[user_message, None]]
+     history.append([user_message, None])
+     return user_message, history  # keep user_message
+
+
+ def user1(user_message, history):
+     # return user_message, history + [[user_message, None]]
+     history.append([user_message, None])
+     return "", history  # clear user_message
+
+
+ def bot_(history):
+     user_message = history[-1][0]
+     resp = random.choice(["How are you?", "I love you", "I'm very hungry"])
+     bot_message = user_message + ": " + resp
+     history[-1][1] = ""
+     for character in bot_message:
+         history[-1][1] += character
+         time.sleep(0.02)
+         yield history
+
+     history[-1][1] = resp
+     yield history
+
+
+ def bot(history):
+     user_message = history[-1][0]
+     response = []
+
+     logger.debug(f"{user_message=}")
+
+     with about_time() as atime:  # type: ignore
+         flag = 1
+         prefix = ""
+         then = time.time()
+
+         logger.debug("about to generate")
+
+         config = GenerationConfig(reset=True)
+         for elm in generate(user_message, config=config):
+             if flag == 1:
+                 logger.debug("in the loop")
+                 prefix = f"({time.time() - then:.2f}s) "
+                 flag = 0
+                 print(prefix, end="", flush=True)
+                 logger.debug(f"{prefix=}")
+             print(elm, end="", flush=True)
+             # logger.debug(f"{elm}")
+
+             response.append(elm)
+             history[-1][1] = prefix + "".join(response)
+             yield history
+
+     _ = (
+         f"(time elapsed: {atime.duration_human}, "  # type: ignore
+         f"{atime.duration/len(''.join(response)):.2f}s/char)"  # type: ignore
+     )
+
+     history[-1][1] = "".join(response) + f"\n{_}"
+     yield history
+
+
+ def predict_api(prompt):
+     logger.debug(f"{prompt=}")
+     try:
+         # user_prompt = prompt
+         config = GenerationConfig(
+             temperature=0.2,
+             top_k=10,
+             top_p=0.9,
+             repetition_penalty=1.0,
+             max_new_tokens=512,  # adjust as needed
+             seed=42,
+             reset=True,  # reset history (cache)
+             stream=False,
+             # threads=cpu_count,
+             # stop=prompt_prefix[1:2],
+         )
+
+         response = generate(
+             prompt,
+             config=config,
+         )
+
+         logger.debug(f"api: {response=}")
+     except Exception as exc:
+         logger.error(exc)
+         response = f"{exc=}"
+     # bot = {"inputs": [response]}
+     # bot = [(prompt, response)]
+
+     return response
+
+
+ css = """
+     .importantButton {
+         background: linear-gradient(45deg, #7e0570, #5d1c99, #6e00ff) !important;
+         border: none !important;
+     }
+     .importantButton:hover {
+         background: linear-gradient(45deg, #ff00e0, #8500ff, #6e00ff) !important;
+         border: none !important;
+     }
+     .disclaimer {font-variant-caps: all-small-caps; font-size: xx-small;}
+     .xsmall {font-size: x-small;}
+ """
+ etext = """In America, where cars are an important part of the national psyche, a decade ago people had suddenly started to drive less, which had not happened since the oil shocks of the 1970s. """
+ examples_list = [
+     ["What NFL team won the Super Bowl in the year Justin Bieber was born?"],
+     [
+         "What NFL team won the Super Bowl in the year Justin Bieber was born? Think step by step."
+     ],
+     ["How to pick a lock? Provide detailed steps."],
+     ["If it takes 10 hours to dry 10 clothes, assuming all the clothes are hung together at the same time for drying, then how long will it take to dry one cloth?"],
+     ["Is infinity + 1 bigger than infinity?"],
+     ["Explain the plot of Cinderella in a sentence."],
+     [
+         "How long does it take to become proficient in French, and what are the best methods for retaining information?"
+     ],
+     ["What are some common mistakes to avoid when writing code?"],
+     ["Build a prompt to generate a beautiful portrait of a horse"],
+     ["Suggest four metaphors to describe the benefits of AI"],
+     ["Write a pop song about leaving home for the sandy beaches."],
+     ["Write a pop song about having hot sex on a sandy beach."],
+     ["Write a summary demonstrating my ability to tame lions"],
+     ["鲁迅和周树人什么关系? 说中文。"],
+     ["鲁迅和周树人什么关系?"],
+     ["鲁迅和周树人什么关系? 用英文回答。"],
+     ["从前有一头牛,这头牛后面有什么?"],
+     ["正无穷大加一大于正无穷大吗?"],
+     ["正无穷大加正无穷大大于正无穷大吗?"],
+     ["-2的平方根等于什么?"],
+     ["树上有5只鸟,猎人开枪打死了一只。树上还有几只鸟?"],
+     ["树上有11只鸟,猎人开枪打死了一只。树上还有几只鸟?提示:需考虑鸟可能受惊吓飞走。"],
+     ["以红楼梦的行文风格写一张委婉的请假条。不少于320字。"],
+     [f"{etext} 翻成中文,列出3个版本。"],
+     [f"{etext} \n 翻成中文,保留原意,但使用文学性的语言。不要写解释。列出3个版本。"],
+     ["假定 1 + 2 = 4, 试求 7 + 8。"],
+     ["给出判断一个数是不是质数的 javascript 码。"],
+     ["给出实现python 里 range(10)的 javascript 码。"],
+     ["给出实现python 里 [*(range(10)]的 javascript 码。"],
+     ["Erkläre die Handlung von Cinderella in einem Satz."],
+     ["Erkläre die Handlung von Cinderella in einem Satz. Auf Deutsch."],
+ ]
+
+ logger.info("start block")
+
+ with gr.Blocks(
+     title=f"{Path(model_loc).name}",
+     theme=gr.themes.Soft(text_size="sm", spacing_size="sm"),
+     css=css,
+ ) as block:
+     # buff_var = gr.State("")
+     with gr.Accordion("🎈 Info", open=False):
+         # gr.HTML(
+         #     """<center><a href="https://huggingface.co/spaces/mikeee/mpt-30b-chat?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate"></a> and spin a CPU UPGRADE to avoid the queue</center>"""
+         # )
+         gr.Markdown(
+             f"""<h5><center>{Path(model_loc).name}</center></h5>
+             Most examples are meant for another model;
+             you may want to try prompts better suited to this one.""",
+             elem_classes="xsmall",
+         )
+
+     # chatbot = gr.Chatbot().style(height=700)  # 500
+     chatbot = gr.Chatbot(height=500)
+
+     # buff = gr.Textbox(show_label=False, visible=True)
+
+     with gr.Row():
+         with gr.Column(scale=5):
+             msg = gr.Textbox(
+                 label="Chat Message Box",
+                 placeholder="Ask me anything (press Shift+Enter or click Submit to send)",
+                 show_label=False,
+                 # container=False,
+                 lines=6,
+                 max_lines=30,
+                 show_copy_button=True,
+                 # ).style(container=False)
+             )
+         with gr.Column(scale=1, min_width=50):
+             with gr.Row():
+                 submit = gr.Button("Submit", elem_classes="xsmall")
+                 stop = gr.Button("Stop", visible=True)
+                 clear = gr.Button("Clear History", visible=True)
+     with gr.Row(visible=False):
+         with gr.Accordion("Advanced Options:", open=False):
+             with gr.Row():
+                 with gr.Column(scale=2):
+                     system = gr.Textbox(
+                         label="System Prompt",
+                         value=prompt_template,
+                         show_label=False,
+                         container=False,
+                         # ).style(container=False)
+                     )
+                 with gr.Column():
+                     with gr.Row():
+                         change = gr.Button("Change System Prompt")
+                         reset = gr.Button("Reset System Prompt")
+
+     with gr.Accordion("Example Inputs", open=True):
+         examples = gr.Examples(
+             examples=examples_list,
+             inputs=[msg],
+             examples_per_page=40,
+         )
+
+     # with gr.Row():
+     with gr.Accordion("Disclaimer", open=False):
+         _ = Path(model_loc).name
+         gr.Markdown(
+             f"Disclaimer: {_} can produce factually incorrect output, and should not be relied on to produce "
+             f"factually accurate information. {_} was trained on various public datasets; while great efforts "
+             "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
+             "biased, or otherwise offensive outputs.",
+             elem_classes=["disclaimer"],
+         )
+
+     msg_submit_event = msg.submit(
+         # fn=conversation.user_turn,
+         fn=user,
+         inputs=[msg, chatbot],
+         outputs=[msg, chatbot],
+         queue=True,
+         show_progress="full",
+         # api_name=None,
+     ).then(bot, chatbot, chatbot, queue=True)
+     submit_click_event = submit.click(
+         # fn=lambda x, y: ("",) + user(x, y)[1:],  # clear msg
+         fn=user1,  # clear msg
+         inputs=[msg, chatbot],
+         outputs=[msg, chatbot],
+         queue=True,
+         # queue=False,
+         show_progress="full",
+         # api_name=None,
+     ).then(bot, chatbot, chatbot, queue=True)
+     stop.click(
+         fn=None,
+         inputs=None,
+         outputs=None,
+         cancels=[msg_submit_event, submit_click_event],
+         queue=False,
+     )
+     clear.click(lambda: None, None, chatbot, queue=False)
+
+     with gr.Accordion("For Chat/Translation API", open=False, visible=False):
+         input_text = gr.Text()
+         api_btn = gr.Button("Go", variant="primary")
+         out_text = gr.Text()
+
+     api_btn.click(
+         predict_api,
+         input_text,
+         out_text,
+         api_name="api",
+     )
+
+     # block.load(update_buff, [], buff, every=1)
+     # block.load(update_buff, [buff_var], [buff_var, buff], every=1)
+
+ # concurrency_count=5, max_size=20
+ # max_size=36, concurrency_count=14
+ # CPU cpu_count=2 16G, model 7G
+ # CPU UPGRADE cpu_count=8 32G, model 7G
+
+ # does not work
+ _ = """
+ # _ = int(psutil.virtual_memory().total / 10**9 // file_size - 1)
+ # concurrency_count = max(_, 1)
+ if psutil.cpu_count(logical=False) >= 8:
+     # concurrency_count = max(int(32 / file_size) - 1, 1)
+ else:
+     # concurrency_count = max(int(16 / file_size) - 1, 1)
+ # """
+
+ concurrency_count = 1
+ logger.info(f"{concurrency_count=}")
+
+ block.queue(concurrency_count=concurrency_count, max_size=5).launch(debug=True)