Spaces:
Runtime error
Runtime error
Upload 3 files
Browse files- chat_template.py +36 -0
- gradio_streamingllm.py +4 -12
- llama_cpp_python_streamingllm.py +6 -26
chat_template.py
CHANGED
@@ -3,6 +3,7 @@ import copy
|
|
3 |
|
4 |
class ChatTemplate:
|
5 |
cache = {}
|
|
|
6 |
|
7 |
def __init__(self, model, im_start=r'<|im_start|>', im_end=r'<|im_end|>', nl='\n'):
|
8 |
self.model = model
|
@@ -31,7 +32,42 @@ class ChatTemplate:
|
|
31 |
self.cache[key] = copy.deepcopy(value) # 深拷贝一下
|
32 |
return value
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
def __call__(self, _role, prompt=None):
|
|
|
35 |
if prompt is None:
|
36 |
return self._get(_role)
|
37 |
# print(_role, prompt, self.cache)
|
|
|
3 |
|
4 |
class ChatTemplate:
|
5 |
cache = {}
|
6 |
+
roles = set()
|
7 |
|
8 |
def __init__(self, model, im_start=r'<|im_start|>', im_end=r'<|im_end|>', nl='\n'):
|
9 |
self.model = model
|
|
|
32 |
self.cache[key] = copy.deepcopy(value) # 深拷贝一下
|
33 |
return value
|
34 |
|
35 |
+
def _add_role(self, _role):
|
36 |
+
if _role:
|
37 |
+
self.roles.add('\n' + _role)
|
38 |
+
|
39 |
+
def eos_in_role(self, history: str, t_bot):
|
40 |
+
if not (history.endswith('\n') or history.endswith('\r')):
|
41 |
+
return 0
|
42 |
+
tmp = history.rstrip()
|
43 |
+
for _role in self.roles:
|
44 |
+
if tmp.endswith(_role):
|
45 |
+
n = len(t_bot)
|
46 |
+
for i in range(1, n): # 找出需要弃置的tokens长度
|
47 |
+
tmp = self.model.str_detokenize(t_bot[n - i:])
|
48 |
+
if tmp.rstrip().endswith(_role):
|
49 |
+
print('eos_in_role', t_bot[n - i:], repr(tmp))
|
50 |
+
return i
|
51 |
+
print('eos_in_role missing')
|
52 |
+
break
|
53 |
+
return 0
|
54 |
+
|
55 |
+
def eos_in_nlnl(self, history: str, t_bot):
|
56 |
+
if not (history.endswith('\n\n') or history.endswith('\n\r\n')):
|
57 |
+
return 0
|
58 |
+
n = len(t_bot)
|
59 |
+
for i in range(1, n): # 找出需要弃置的tokens长度
|
60 |
+
tmp = self.model.str_detokenize(t_bot[n - i:])
|
61 |
+
if tmp.endswith('\n\n') or tmp.endswith('\n\r\n'):
|
62 |
+
if tmp.startswith(']'): # 避免误判
|
63 |
+
return 0
|
64 |
+
print('eos_in_nlnl', t_bot[n - i:], repr(tmp))
|
65 |
+
return i
|
66 |
+
print('eos_in_nlnl missing')
|
67 |
+
return 0
|
68 |
+
|
69 |
def __call__(self, _role, prompt=None):
|
70 |
+
self._add_role(_role)
|
71 |
if prompt is None:
|
72 |
return self._get(_role)
|
73 |
# print(_role, prompt, self.cache)
|
gradio_streamingllm.py
CHANGED
@@ -28,6 +28,9 @@ from mods.btn_reset import init as btn_reset_init
|
|
28 |
# ========== 聊天的模版 默认 chatml ==========
|
29 |
from chat_template import ChatTemplate
|
30 |
|
|
|
|
|
|
|
31 |
# ========== 全局锁,确保只能进行一个会话 ==========
|
32 |
cfg['session_lock'] = threading.Lock()
|
33 |
cfg['session_active'] = False
|
@@ -84,8 +87,6 @@ with gr.Blocks() as role:
|
|
84 |
cfg['role_chat_style'] = gr.Textbox(lines=10, label="回复示例", **cfg['role_chat_style'])
|
85 |
|
86 |
# ========== 加载角色卡-缓存 ==========
|
87 |
-
from mods.load_cache import init as load_cache_init
|
88 |
-
|
89 |
text_display_init(cfg)
|
90 |
load_cache_init(cfg)
|
91 |
|
@@ -99,15 +100,6 @@ with gr.Blocks() as chatting:
|
|
99 |
cfg['vo'] = gr.Textbox(label='VO', show_copy_button=True, elem_id="VO-area")
|
100 |
cfg['s_info'] = gr.Textbox(value=cfg['model'].venv_info, max_lines=1, label='info', interactive=False)
|
101 |
cfg['msg'] = gr.Textbox(label='Prompt', lines=2, max_lines=2, elem_id='prompt', autofocus=True, **cfg['msg'])
|
102 |
-
with gr.Row():
|
103 |
-
cfg['btn_vo'] = gr.Button("旁白")
|
104 |
-
cfg['btn_rag'] = gr.Button("RAG")
|
105 |
-
cfg['btn_retry'] = gr.Button("Retry")
|
106 |
-
cfg['btn_com1'] = gr.Button("自定义1")
|
107 |
-
cfg['btn_reset'] = gr.Button("Reset")
|
108 |
-
cfg['btn_debug'] = gr.Button("Debug")
|
109 |
-
cfg['btn_submit'] = gr.Button("Submit")
|
110 |
-
cfg['btn_suggest'] = gr.Button("建议")
|
111 |
|
112 |
cfg['gr'] = gr
|
113 |
btn_com_init(cfg)
|
@@ -164,4 +156,4 @@ demo = gr.TabbedInterface([chatting, setting, role],
|
|
164 |
["聊天", "设置", '角色'],
|
165 |
css=custom_css)
|
166 |
gr.close_all()
|
167 |
-
demo.queue(api_open=False, max_size=1).launch(share=False)
|
|
|
28 |
# ========== 聊天的模版 默认 chatml ==========
|
29 |
from chat_template import ChatTemplate
|
30 |
|
31 |
+
# ========== 加载角色卡-缓存 ==========
|
32 |
+
from mods.load_cache import init as load_cache_init
|
33 |
+
|
34 |
# ========== 全局锁,确保只能进行一个会话 ==========
|
35 |
cfg['session_lock'] = threading.Lock()
|
36 |
cfg['session_active'] = False
|
|
|
87 |
cfg['role_chat_style'] = gr.Textbox(lines=10, label="回复示例", **cfg['role_chat_style'])
|
88 |
|
89 |
# ========== 加载角色卡-缓存 ==========
|
|
|
|
|
90 |
text_display_init(cfg)
|
91 |
load_cache_init(cfg)
|
92 |
|
|
|
100 |
cfg['vo'] = gr.Textbox(label='VO', show_copy_button=True, elem_id="VO-area")
|
101 |
cfg['s_info'] = gr.Textbox(value=cfg['model'].venv_info, max_lines=1, label='info', interactive=False)
|
102 |
cfg['msg'] = gr.Textbox(label='Prompt', lines=2, max_lines=2, elem_id='prompt', autofocus=True, **cfg['msg'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
cfg['gr'] = gr
|
105 |
btn_com_init(cfg)
|
|
|
156 |
["聊天", "设置", '角色'],
|
157 |
css=custom_css)
|
158 |
gr.close_all()
|
159 |
+
demo.queue(api_open=False, max_size=1).launch(share=False, show_error=True, show_api=False)
|
llama_cpp_python_streamingllm.py
CHANGED
@@ -6,35 +6,13 @@ from ctypes import POINTER
|
|
6 |
from KMP_list import kmp_search, compute_lps_array
|
7 |
|
8 |
|
9 |
-
def is_UTF8_incomplete(all_text):
|
10 |
-
multibyte_fix = 0
|
11 |
-
if len(all_text) < 3:
|
12 |
-
all_text = b'000' + all_text
|
13 |
-
for k, char in enumerate(all_text[-3:]):
|
14 |
-
k = 3 - k
|
15 |
-
for num, pattern in [(2, 192), (3, 224), (4, 240)]:
|
16 |
-
# Bitwise AND check
|
17 |
-
if num > k and pattern & char == pattern:
|
18 |
-
multibyte_fix = num - k
|
19 |
-
return multibyte_fix
|
20 |
-
|
21 |
-
|
22 |
-
def get_complete_UTF8(all_text):
|
23 |
-
multibyte_fix = is_UTF8_incomplete(all_text)
|
24 |
-
if multibyte_fix > 0:
|
25 |
-
multibyte_fix = multibyte_fix - 3
|
26 |
-
return all_text[:multibyte_fix].decode("utf-8")
|
27 |
-
else:
|
28 |
-
return all_text.decode("utf-8")
|
29 |
-
|
30 |
-
|
31 |
class StreamingLLM(Llama):
|
32 |
def __init__(self, model_path: str, **kwargs):
|
33 |
super().__init__(model_path, **kwargs)
|
34 |
self._venv_init()
|
35 |
|
36 |
def str_detokenize(self, tokens) -> str:
|
37 |
-
return
|
38 |
|
39 |
def kv_cache_seq_trim(self):
|
40 |
self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
|
@@ -103,9 +81,9 @@ class StreamingLLM(Llama):
|
|
103 |
break
|
104 |
return True
|
105 |
|
106 |
-
def venv_pop_token(self):
|
107 |
-
self.n_tokens -=
|
108 |
-
self.venv[-1] -=
|
109 |
self.kv_cache_seq_trim()
|
110 |
|
111 |
@property
|
@@ -113,6 +91,8 @@ class StreamingLLM(Llama):
|
|
113 |
return str((self.n_tokens, self.venv, self.venv_idx_map))
|
114 |
|
115 |
def kv_cache_seq_ltrim(self, n_keep, n_discard=256, n_past=-1, im_start=None):
|
|
|
|
|
116 |
if n_past < 0:
|
117 |
n_past = self.n_tokens
|
118 |
if im_start is not None: # [<|im_start|>, name, nl]
|
|
|
6 |
from KMP_list import kmp_search, compute_lps_array
|
7 |
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
class StreamingLLM(Llama):
|
10 |
def __init__(self, model_path: str, **kwargs):
|
11 |
super().__init__(model_path, **kwargs)
|
12 |
self._venv_init()
|
13 |
|
14 |
def str_detokenize(self, tokens) -> str:
|
15 |
+
return self.detokenize(tokens).decode('utf-8', errors='ignore')
|
16 |
|
17 |
def kv_cache_seq_trim(self):
|
18 |
self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
|
|
|
81 |
break
|
82 |
return True
|
83 |
|
84 |
+
def venv_pop_token(self, n=1):
|
85 |
+
self.n_tokens -= n
|
86 |
+
self.venv[-1] -= n
|
87 |
self.kv_cache_seq_trim()
|
88 |
|
89 |
@property
|
|
|
91 |
return str((self.n_tokens, self.venv, self.venv_idx_map))
|
92 |
|
93 |
def kv_cache_seq_ltrim(self, n_keep, n_discard=256, n_past=-1, im_start=None):
|
94 |
+
if n_keep < 0:
|
95 |
+
return
|
96 |
if n_past < 0:
|
97 |
n_past = self.n_tokens
|
98 |
if im_start is not None: # [<|im_start|>, name, nl]
|