Limour commited on
Commit
bac870d
·
verified ·
1 Parent(s): 3f9dd7e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +436 -0
app.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import re
4
+ import json
5
+
6
+ import gradio as gr
7
+
8
+ from chat_template import ChatTemplate
9
+ from llama_cpp_python_streamingllm import StreamingLLM
10
+
11
+ # ========== 让聊天界面的文本框等高 ==========
12
+ custom_css = r'''
13
+ #area > div {
14
+ height: 100%;
15
+ }
16
+ #RAG-area {
17
+ flex-grow: 1;
18
+ }
19
+ #RAG-area > label {
20
+ height: 100%;
21
+ display: flex;
22
+ flex-direction: column;
23
+ }
24
+ #RAG-area > label > textarea {
25
+ flex-grow: 1;
26
+ max-height: 20vh;
27
+ }
28
+ #VO-area {
29
+ flex-grow: 1;
30
+ }
31
+ #VO-area > label {
32
+ height: 100%;
33
+ display: flex;
34
+ flex-direction: column;
35
+ }
36
+ #VO-area > label > textarea {
37
+ flex-grow: 1;
38
+ max-height: 20vh;
39
+ }
40
+ #prompt > label > textarea {
41
+ max-height: 63px;
42
+ }
43
+ '''
44
+
45
+
46
+ # ========== 适配 SillyTavern 的模版 ==========
47
+ def text_format(text: str, _env=None, **env):
48
+ if _env is not None:
49
+ for k, v in _env.items():
50
+ text = text.replace(r'{{' + k + r'}}', v)
51
+ for k, v in env.items():
52
+ text = text.replace(r'{{' + k + r'}}', v)
53
+ return text
54
+
55
+
56
+ # ========== 哈希函数 ==========
57
+ def x_hash(x: str):
58
+ return hashlib.sha1(x.encode('utf-8')).hexdigest()
59
+
60
+
61
+ # ========== 读取配置文件 ==========
62
+ with open('rp_config.json', encoding='utf-8') as f:
63
+ tmp = f.read()
64
+ with open('rp_sample_config.json', encoding='utf-8') as f:
65
+ cfg = json.load(f)
66
+ cfg['setting_cache_path']['value'] += x_hash(tmp)
67
+ cfg.update(json.loads(tmp))
68
+
69
+ # ========== 给引号加粗 ==========
70
+ reg_q = re.compile(r'“(.+?)”')
71
+
72
+
73
+ def chat_display_format(text: str):
74
+ return reg_q.sub(r' **\g<0>** ', text)
75
+
76
+
77
+ # ========== 温度、采样之类的设置 ==========
78
+ with gr.Blocks() as setting:
79
+ with gr.Row():
80
+ setting_path = gr.Textbox(label="模型路径", max_lines=1, scale=2, **cfg['setting_path'])
81
+ setting_cache_path = gr.Textbox(label="缓存路径", max_lines=1, scale=2, **cfg['setting_cache_path'])
82
+ setting_seed = gr.Number(label="随机种子", scale=1, **cfg['setting_seed'])
83
+ setting_n_gpu_layers = gr.Number(label="n_gpu_layers", scale=1, **cfg['setting_n_gpu_layers'])
84
+ with gr.Row():
85
+ setting_ctx = gr.Number(label="上下文大小(Tokens)", **cfg['setting_ctx'])
86
+ setting_max_tokens = gr.Number(label="最大响应长度(Tokens)", interactive=True, **cfg['setting_max_tokens'])
87
+ setting_n_keep = gr.Number(value=10, label="n_keep", interactive=False)
88
+ setting_n_discard = gr.Number(label="n_discard", interactive=True, **cfg['setting_n_discard'])
89
+ with gr.Row():
90
+ setting_temperature = gr.Number(label="温度", interactive=True, **cfg['setting_temperature'])
91
+ setting_repeat_penalty = gr.Number(label="重复惩罚", interactive=True, **cfg['setting_repeat_penalty'])
92
+ setting_frequency_penalty = gr.Number(label="频率惩罚", interactive=True, **cfg['setting_frequency_penalty'])
93
+ setting_presence_penalty = gr.Number(label="存在惩罚", interactive=True, **cfg['setting_presence_penalty'])
94
+ setting_repeat_last_n = gr.Number(label="惩罚范围", interactive=True, **cfg['setting_repeat_last_n'])
95
+ with gr.Row():
96
+ setting_top_k = gr.Number(label="Top-K", interactive=True, **cfg['setting_top_k'])
97
+ setting_top_p = gr.Number(label="Top P", interactive=True, **cfg['setting_top_p'])
98
+ setting_min_p = gr.Number(label="Min P", interactive=True, **cfg['setting_min_p'])
99
+ setting_typical_p = gr.Number(label="Typical", interactive=True, **cfg['setting_typical_p'])
100
+ setting_tfs_z = gr.Number(label="TFS", interactive=True, **cfg['setting_tfs_z'])
101
+ with gr.Row():
102
+ setting_mirostat_mode = gr.Number(label="Mirostat 模式", **cfg['setting_mirostat_mode'])
103
+ setting_mirostat_eta = gr.Number(label="Mirostat 学习率", interactive=True, **cfg['setting_mirostat_eta'])
104
+ setting_mirostat_tau = gr.Number(label="Mirostat 目标熵", interactive=True, **cfg['setting_mirostat_tau'])
105
+
106
+ # ========== 加载模型 ==========
107
+ model = StreamingLLM(model_path=setting_path.value,
108
+ seed=setting_seed.value,
109
+ n_gpu_layers=setting_n_gpu_layers.value,
110
+ n_ctx=setting_ctx.value)
111
+ setting_ctx.value = model.n_ctx()
112
+
113
+ # ========== 聊天的模版 默认 chatml ==========
114
+ chat_template = ChatTemplate(model)
115
+
116
+ # ========== 展示角色卡 ==========
117
+ with gr.Blocks() as role:
118
+ with gr.Row():
119
+ role_usr = gr.Textbox(label="用户名称", max_lines=1, interactive=False, **cfg['role_usr'])
120
+ role_char = gr.Textbox(label="角色名称", max_lines=1, interactive=False, **cfg['role_char'])
121
+
122
+ role_char_d = gr.Textbox(lines=10, label="故事描述", **cfg['role_char_d'])
123
+ role_chat_style = gr.Textbox(lines=10, label="回复示例", **cfg['role_chat_style'])
124
+
125
+ # model.eval_t([1]) # 这个暖机的 bos [1] 删了就不正常了
126
+ if os.path.exists(setting_cache_path.value):
127
+ # ========== 加载角色卡-缓存 ==========
128
+ tmp = model.load_session(setting_cache_path.value)
129
+ print(f'load cache from {setting_cache_path.value} {tmp}')
130
+ tmp = chat_template('system',
131
+ text_format(role_char_d.value,
132
+ char=role_char.value,
133
+ user=role_usr.value))
134
+ setting_n_keep.value = len(tmp)
135
+ tmp = chat_template(role_char.value,
136
+ text_format(role_chat_style.value,
137
+ char=role_char.value,
138
+ user=role_usr.value))
139
+ setting_n_keep.value += len(tmp)
140
+ # ========== 加载角色卡-第一条消息 ==========
141
+ chatbot = []
142
+ for one in cfg["role_char_first"]:
143
+ one['name'] = text_format(one['name'],
144
+ char=role_char.value,
145
+ user=role_usr.value)
146
+ one['value'] = text_format(one['value'],
147
+ char=role_char.value,
148
+ user=role_usr.value)
149
+ if one['name'] == role_char.value:
150
+ chatbot.append((None, chat_display_format(one['value'])))
151
+ print(one)
152
+ else:
153
+ # ========== 加载角色卡-角色描述 ==========
154
+ tmp = chat_template('system',
155
+ text_format(role_char_d.value,
156
+ char=role_char.value,
157
+ user=role_usr.value))
158
+ setting_n_keep.value = model.eval_t(tmp) # 此内容永久存在
159
+
160
+ # ========== 加载角色卡-回复示例 ==========
161
+ tmp = chat_template(role_char.value,
162
+ text_format(role_chat_style.value,
163
+ char=role_char.value,
164
+ user=role_usr.value))
165
+ setting_n_keep.value = model.eval_t(tmp) # 此内容永久存在
166
+
167
+ # ========== 加载角色卡-第一条消息 ==========
168
+ chatbot = []
169
+ for one in cfg["role_char_first"]:
170
+ one['name'] = text_format(one['name'],
171
+ char=role_char.value,
172
+ user=role_usr.value)
173
+ one['value'] = text_format(one['value'],
174
+ char=role_char.value,
175
+ user=role_usr.value)
176
+ if one['name'] == role_char.value:
177
+ chatbot.append((None, chat_display_format(one['value'])))
178
+ print(one)
179
+ tmp = chat_template(one['name'], one['value'])
180
+ model.eval_t(tmp) # 此内容随上下文增加将被丢弃
181
+
182
+ # ========== 保存角色卡-缓存 ==========
183
+ with open(setting_cache_path.value, 'wb') as f:
184
+ pass
185
+ tmp = model.save_session(setting_cache_path.value)
186
+ print(f'save cache {tmp}')
187
+
188
+
189
+ # ========== 流式输出函数 ==========
190
+ def btn_submit_com(_n_keep, _n_discard,
191
+ _temperature, _repeat_penalty, _frequency_penalty,
192
+ _presence_penalty, _repeat_last_n, _top_k,
193
+ _top_p, _min_p, _typical_p,
194
+ _tfs_z, _mirostat_mode, _mirostat_eta,
195
+ _mirostat_tau, _role, _max_tokens):
196
+ # ========== 初始化输出模版 ==========
197
+ t_bot = chat_template(_role)
198
+ completion_tokens = [] # 有可能多个 tokens 才能构成一个 utf-8 编码的文字
199
+ history = ''
200
+ # ========== 流式输出 ==========
201
+ for token in model.generate_t(
202
+ tokens=t_bot,
203
+ n_keep=_n_keep,
204
+ n_discard=_n_discard,
205
+ im_start=chat_template.im_start_token,
206
+ top_k=_top_k,
207
+ top_p=_top_p,
208
+ min_p=_min_p,
209
+ typical_p=_typical_p,
210
+ temp=_temperature,
211
+ repeat_penalty=_repeat_penalty,
212
+ repeat_last_n=_repeat_last_n,
213
+ frequency_penalty=_frequency_penalty,
214
+ presence_penalty=_presence_penalty,
215
+ tfs_z=_tfs_z,
216
+ mirostat_mode=_mirostat_mode,
217
+ mirostat_tau=_mirostat_tau,
218
+ mirostat_eta=_mirostat_eta,
219
+ ):
220
+ if token in chat_template.eos or token == chat_template.nlnl:
221
+ t_bot.extend(completion_tokens)
222
+ print('token in eos', token)
223
+ break
224
+ completion_tokens.append(token)
225
+ all_text = model.str_detokenize(completion_tokens)
226
+ if not all_text:
227
+ continue
228
+ t_bot.extend(completion_tokens)
229
+ history += all_text
230
+ yield history
231
+ if token in chat_template.onenl:
232
+ # ========== 移除末尾的换行符 ==========
233
+ if t_bot[-2] in chat_template.onenl:
234
+ model.venv_pop_token()
235
+ break
236
+ if t_bot[-2] in chat_template.onerl and t_bot[-3] in chat_template.onenl:
237
+ model.venv_pop_token()
238
+ break
239
+ if history[-2:] == '\n\n': # 各种 'x\n\n' 的token,比如'。\n\n'
240
+ print('t_bot[-4:]', t_bot[-4:], repr(model.str_detokenize(t_bot[-4:])),
241
+ repr(model.str_detokenize(t_bot[-1:])))
242
+ break
243
+ if len(t_bot) > _max_tokens:
244
+ break
245
+ completion_tokens = []
246
+ # ========== 查看末尾的换行符 ==========
247
+ print('history', repr(history))
248
+ # ========== 给 kv_cache 加上输出结束符 ==========
249
+ model.eval_t(chat_template.im_end_nl, _n_keep, _n_discard)
250
+ t_bot.extend(chat_template.im_end_nl)
251
+
252
+
253
+ # ========== 显示用户消息 ==========
254
+ def btn_submit_usr(message: str, history):
255
+ # print('btn_submit_usr', message, history)
256
+ if history is None:
257
+ history = []
258
+ return "", history + [[message.strip(), '']], gr.update(interactive=False)
259
+
260
+
261
+ # ========== 模型流式响应 ==========
262
+ def btn_submit_bot(history, _n_keep, _n_discard,
263
+ _temperature, _repeat_penalty, _frequency_penalty,
264
+ _presence_penalty, _repeat_last_n, _top_k,
265
+ _top_p, _min_p, _typical_p,
266
+ _tfs_z, _mirostat_mode, _mirostat_eta,
267
+ _mirostat_tau, _usr, _char,
268
+ _rag, _max_tokens):
269
+ # ========== 需要临时注入的内容 ==========
270
+ rag_idx = None
271
+ if len(_rag) > 0:
272
+ rag_idx = model.venv_create() # 记录 venv_idx
273
+ t_rag = chat_template('system', _rag)
274
+ model.eval_t(t_rag, _n_keep, _n_discard)
275
+ model.venv_create() # 与 t_rag 隔离
276
+ # ========== 用户输入 ==========
277
+ t_msg = history[-1][0]
278
+ t_msg = chat_template(_usr, t_msg)
279
+ model.eval_t(t_msg, _n_keep, _n_discard)
280
+ # ========== 模型输出 ==========
281
+ _tmp = btn_submit_com(_n_keep, _n_discard,
282
+ _temperature, _repeat_penalty, _frequency_penalty,
283
+ _presence_penalty, _repeat_last_n, _top_k,
284
+ _top_p, _min_p, _typical_p,
285
+ _tfs_z, _mirostat_mode, _mirostat_eta,
286
+ _mirostat_tau, _char, _max_tokens)
287
+ for _h in _tmp:
288
+ history[-1][1] = _h
289
+ yield history, str((model.n_tokens, model.venv))
290
+ # ========== 输出完毕后格式化输出 ==========
291
+ history[-1][1] = chat_display_format(history[-1][1])
292
+ yield history, str((model.n_tokens, model.venv))
293
+ # ========== 及时清理上一次生成的旁白 ==========
294
+ if vo_idx > 0:
295
+ print('vo_idx', vo_idx, model.venv)
296
+ model.venv_remove(vo_idx)
297
+ print('vo_idx', vo_idx, model.venv)
298
+ if rag_idx and vo_idx < rag_idx:
299
+ rag_idx -= 1
300
+ # ========== 响应完毕后清除注入的内容 ==========
301
+ if rag_idx is not None:
302
+ model.venv_remove(rag_idx) # 销毁对应的 venv
303
+ model.venv_disband() # 退出隔离环境
304
+ yield history, str((model.n_tokens, model.venv))
305
+ print('venv_disband', vo_idx, model.venv)
306
+
307
+
308
+ # ========== 待实现 ==========
309
+ def btn_rag_(_rag, _msg):
310
+ retn = ''
311
+ return retn
312
+
313
+
314
+ vo_idx = 0
315
+
316
+
317
+ # ========== 输出一段旁白 ==========
318
+ def btn_submit_vo(_n_keep, _n_discard,
319
+ _temperature, _repeat_penalty, _frequency_penalty,
320
+ _presence_penalty, _repeat_last_n, _top_k,
321
+ _top_p, _min_p, _typical_p,
322
+ _tfs_z, _mirostat_mode, _mirostat_eta,
323
+ _mirostat_tau, _max_tokens):
324
+ global vo_idx
325
+ vo_idx = model.venv_create() # 创建隔离环境
326
+ # ========== 模型输出旁白 ==========
327
+ _tmp = btn_submit_com(_n_keep, _n_discard,
328
+ _temperature, _repeat_penalty, _frequency_penalty,
329
+ _presence_penalty, _repeat_last_n, _top_k,
330
+ _top_p, _min_p, _typical_p,
331
+ _tfs_z, _mirostat_mode, _mirostat_eta,
332
+ _mirostat_tau, '旁白', _max_tokens)
333
+ for _h in _tmp:
334
+ yield _h, str((model.n_tokens, model.venv))
335
+
336
+
337
+ # ========== 给用户提供默认回复 ==========
338
+ def btn_submit_suggest(_n_keep, _n_discard,
339
+ _temperature, _repeat_penalty, _frequency_penalty,
340
+ _presence_penalty, _repeat_last_n, _top_k,
341
+ _top_p, _min_p, _typical_p,
342
+ _tfs_z, _mirostat_mode, _mirostat_eta,
343
+ _mirostat_tau, _usr, _max_tokens):
344
+ model.venv_create() # 创建隔离环境
345
+ # ========== 模型输出 ==========
346
+ _tmp = btn_submit_com(_n_keep, _n_discard,
347
+ _temperature, _repeat_penalty, _frequency_penalty,
348
+ _presence_penalty, _repeat_last_n, _top_k,
349
+ _top_p, _min_p, _typical_p,
350
+ _tfs_z, _mirostat_mode, _mirostat_eta,
351
+ _mirostat_tau, _usr, _max_tokens)
352
+ _h = ''
353
+ for _h in _tmp:
354
+ yield _h, str((model.n_tokens, model.venv))
355
+ model.venv_remove() # 销毁隔离环境
356
+ yield _h, str((model.n_tokens, model.venv))
357
+
358
+
359
+ # ========== 聊天页面 ==========
360
+ with gr.Blocks() as chatting:
361
+ with gr.Row(equal_height=True):
362
+ chatbot = gr.Chatbot(height='60vh', scale=2, value=chatbot,
363
+ avatar_images=(r'assets/user.png', r'assets/chatbot.webp'))
364
+ with gr.Column(scale=1, elem_id="area"):
365
+ rag = gr.Textbox(label='RAG', show_copy_button=True, elem_id="RAG-area")
366
+ vo = gr.Textbox(label='VO', show_copy_button=True, elem_id="VO-area")
367
+ s_info = gr.Textbox(value=str((model.n_tokens, model.venv)), max_lines=1, label='info', interactive=False)
368
+ msg = gr.Textbox(label='Prompt', lines=2, max_lines=2, elem_id='prompt', autofocus=True, **cfg['msg'])
369
+ with gr.Row():
370
+ btn_rag = gr.Button("RAG")
371
+ btn_submit = gr.Button("Submit")
372
+ btn_retry = gr.Button("Retry")
373
+ btn_com1 = gr.Button("自定义1")
374
+ btn_com2 = gr.Button("自定义2")
375
+ btn_com3 = gr.Button("自定义3")
376
+
377
+ btn_rag.click(fn=btn_rag_, outputs=rag,
378
+ inputs=[rag, msg])
379
+
380
+ btn_submit.click(
381
+ fn=btn_submit_usr, api_name="submit",
382
+ inputs=[msg, chatbot],
383
+ outputs=[msg, chatbot, btn_submit]
384
+ ).then(
385
+ fn=btn_submit_bot,
386
+ inputs=[chatbot, setting_n_keep, setting_n_discard,
387
+ setting_temperature, setting_repeat_penalty, setting_frequency_penalty,
388
+ setting_presence_penalty, setting_repeat_last_n, setting_top_k,
389
+ setting_top_p, setting_min_p, setting_typical_p,
390
+ setting_tfs_z, setting_mirostat_mode, setting_mirostat_eta,
391
+ setting_mirostat_tau, role_usr, role_char,
392
+ rag, setting_max_tokens],
393
+ outputs=[chatbot, s_info]
394
+ ).then(
395
+ fn=btn_submit_vo,
396
+ inputs=[setting_n_keep, setting_n_discard,
397
+ setting_temperature, setting_repeat_penalty, setting_frequency_penalty,
398
+ setting_presence_penalty, setting_repeat_last_n, setting_top_k,
399
+ setting_top_p, setting_min_p, setting_typical_p,
400
+ setting_tfs_z, setting_mirostat_mode, setting_mirostat_eta,
401
+ setting_mirostat_tau, setting_max_tokens],
402
+ outputs=[vo, s_info]
403
+ ).then(
404
+ fn=btn_submit_suggest,
405
+ inputs=[setting_n_keep, setting_n_discard,
406
+ setting_temperature, setting_repeat_penalty, setting_frequency_penalty,
407
+ setting_presence_penalty, setting_repeat_last_n, setting_top_k,
408
+ setting_top_p, setting_min_p, setting_typical_p,
409
+ setting_tfs_z, setting_mirostat_mode, setting_mirostat_eta,
410
+ setting_mirostat_tau, role_usr, setting_max_tokens],
411
+ outputs=[msg, s_info]
412
+ ).then(
413
+ fn=lambda: gr.update(interactive=True),
414
+ outputs=btn_submit
415
+ )
416
+
417
+ # ========== 用于调试 ==========
418
+ btn_com1.click(fn=lambda: model.str_detokenize(model._input_ids), outputs=rag)
419
+
420
+
421
+ @btn_com2.click(inputs=setting_cache_path,
422
+ outputs=s_info)
423
+ def btn_com2(_cache_path):
424
+ _tmp = model.load_session(setting_cache_path.value)
425
+ print(f'load cache from {setting_cache_path.value} {_tmp}')
426
+ global vo_idx
427
+ vo_idx = 0
428
+ model.venv = [0]
429
+ return str((model.n_tokens, model.venv))
430
+
431
+ # ========== 开始运行 ==========
432
+ demo = gr.TabbedInterface([chatting, setting, role],
433
+ ["聊天", "设置", '角色'],
434
+ css=custom_css)
435
+ gr.close_all()
436
+ demo.queue().launch(share=False)