llama-python-streamingllm / llama_cpp_python_streamingllm.py
Limour's picture
Upload 2 files
f4e6998 verified
raw
history blame
11.4 kB
import codecs
from ctypes import POINTER
from typing import Optional, Sequence, Generator

from llama_cpp import Llama, LogitsProcessorList, LlamaGrammar, llama_cpp, npt, np, StoppingCriteriaList

from KMP_list import kmp_search, compute_lps_array
def is_UTF8_incomplete(all_text):
    """Return how many bytes are missing from a truncated final UTF-8 character.

    Only the last three bytes of ``all_text`` are examined.  Each of them is
    tested against the lead-byte masks 0xC0 (2-byte), 0xE0 (3-byte) and 0xF0
    (4-byte sequences): a byte "matches" when every bit of the mask is set.
    If the sequence such a byte would start runs past the end of the buffer,
    the shortfall is recorded.  A 4-byte lead byte also satisfies the shorter
    masks, so later (longer / nearer-the-end) matches take precedence.

    Args:
        all_text: UTF-8 encoded bytes, possibly cut in the middle of a
            character.

    Returns:
        The number of bytes still needed to complete the last character, or
        0 when none of the last three bytes is an unfinished lead byte.
    """
    # Inputs shorter than three bytes are left-padded with b'0' (0x30), which
    # never satisfies any of the lead-byte masks.
    tail = (all_text if len(all_text) >= 3 else b'000' + all_text)[-3:]
    # ``remaining`` = bytes left in the buffer counting this one (3, 2, 1).
    shortfalls = [
        seq_len - remaining
        for remaining, byte in zip((3, 2, 1), tail)
        for seq_len, mask in ((2, 0xC0), (3, 0xE0), (4, 0xF0))
        if seq_len > remaining and byte & mask == mask
    ]
    # The last match wins, exactly as repeated assignment in a loop would.
    return shortfalls[-1] if shortfalls else 0
def get_complete_UTF8(all_text):
    """Decode UTF-8 bytes, leaving out a trailing incomplete multi-byte character.

    Detokenizing a partial token stream can stop part-way through a
    multi-byte UTF-8 character (one character split across several tokens).
    Rather than raising, the unfinished tail is simply omitted from the result.

    Args:
        all_text: UTF-8 encoded ``bytes`` (any bytes-like object works).

    Returns:
        The text of every complete character in ``all_text``; ``""`` for
        empty input or when the only bytes present belong to an unfinished
        character.

    Raises:
        UnicodeDecodeError: if ``all_text`` contains invalid UTF-8 anywhere
            other than as a truncated final character.
    """
    # An incremental decoder with final=False decodes everything it can and
    # only buffers an unfinished trailing sequence (of any length: 2-, 3- or
    # 4-byte characters alike).  Discarding the decoder discards that tail.
    # Slicing by "missing bytes - 3" was only correct for 3-byte characters.
    return codecs.getincrementaldecoder("utf-8")().decode(all_text, final=False)
class StreamingLLM(Llama):
    """``Llama`` subclass with manual KV-cache management.

    On top of ``Llama`` it adds:

    * ``self.venv``: a stack of token counts ("virtual environments"). Each
      entry counts the tokens evaluated by ``eval_t`` while it was the top of
      the stack, so a layer of tokens can later be merged into the layer
      below (``venv_disband``) or removed from the context (``venv_remove``).
    * ``kv_cache_seq_ltrim``: discard a span of tokens after a kept prefix and
      shift the remainder left; ``eval_t`` uses it when the context would
      overflow.
    * token-level ``eval_t`` / ``sample_t`` / ``generate_t`` helpers and raw
      session-file save/load.
    """

    def __init__(self, model_path: str, **kwargs):
        """Construct the underlying ``Llama`` and start with only the base venv layer."""
        super().__init__(model_path, **kwargs)
        # Stack of per-layer token counts; index 0 is the base layer, which
        # venv_remove refuses to remove and venv_disband never pops.
        self.venv = [0]

    def str_detokenize(self, tokens) -> str:
        """Detokenize ``tokens`` and decode the bytes with ``get_complete_UTF8``.

        ``get_complete_UTF8`` is meant to cope with a trailing incomplete
        multi-byte UTF-8 character in the detokenized bytes.
        """
        return get_complete_UTF8(self.detokenize(tokens))

    def kv_cache_seq_trim(self):
        """Remove KV-cache entries at positions ``>= self.n_tokens``.

        NOTE(review): relies on llama.cpp treating seq_id -1 as "all
        sequences" and end position -1 as "to the end" -- confirm against the
        pinned llama.cpp version.
        """
        self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)

    def venv_create(self) -> int:
        """Push a new, empty venv layer and return its index."""
        self.venv.append(0)
        return len(self.venv) - 1

    def venv_disband(self) -> int:
        """Merge the top venv layer's token count into the layer below it.

        The tokens themselves stay in the context; only the bookkeeping
        changes. The base layer is never merged away.

        Returns:
            Index of the new top layer (0 if only the base layer existed).
        """
        if len(self.venv) <= 1:
            return 0
        tmp = self.venv.pop()
        self.venv[-1] += tmp
        return len(self.venv) - 1

    def venv_remove(self, venv_idx=None) -> int:
        """Remove the tokens of venv layer ``venv_idx`` (default: top layer).

        The top layer is removed by truncating ``n_tokens`` and the KV cache;
        an inner layer is cut out of the middle with ``kv_cache_seq_ltrim``
        so the layers above it are preserved. Index 0 and out-of-range
        indices are ignored.

        Returns:
            Index of the top layer after the operation.
        """
        if venv_idx is None:
            venv_idx = len(self.venv) - 1
        if venv_idx <= 0 or venv_idx >= len(self.venv):
            return len(self.venv) - 1
        if venv_idx == len(self.venv) - 1:
            # Last (top) layer: truncate. min() keeps n_tokens from going
            # negative if the layer's count exceeds the tokens still present
            # (eval_t's context trimming does not adjust venv counts).
            self.n_tokens -= min(self.venv.pop(), self.n_tokens)
            self.kv_cache_seq_trim()
        else:
            # Not the last layer: its tokens start after everything below it,
            # i.e. at n_tokens minus the sizes of this and all higher layers.
            n_keep = self.n_tokens - sum(self.venv[i] for i in range(venv_idx, len(self.venv)))
            n_discard = self.venv.pop(venv_idx)
            self.kv_cache_seq_ltrim(n_keep, n_discard)
        return len(self.venv) - 1

    def venv_pop_token(self):
        """Drop the last evaluated token from n_tokens, the top venv layer and the KV cache.

        NOTE(review): there is no guard for ``n_tokens == 0`` or an empty top
        layer; both counters can go negative.
        """
        self.n_tokens -= 1
        self.venv[-1] -= 1
        self.kv_cache_seq_trim()

    def kv_cache_seq_ltrim(self, n_keep, n_discard=256, n_past=-1, im_start=None):
        """Discard tokens after a kept prefix and shift the remainder left.

        Removes positions ``[n_keep, n_keep + n_discard)`` from the KV cache,
        shifts positions ``[n_keep + n_discard, n_past)`` back by
        ``n_discard`` (in the KV cache and in ``input_ids``), and sets
        ``n_tokens`` to ``n_past - n_discard``. ``self.venv`` is not updated.

        Args:
            n_keep: number of leading tokens to keep.
            n_discard: number of tokens to discard after the kept prefix
                (may be adjusted when ``im_start`` is given).
            n_past: end of the valid token range; negative means
                ``self.n_tokens``.
            im_start: optional token sequence marking the start of a message
                (per the inline comment: ``[<|im_start|>, name, nl]``). When
                given, the cut is aligned to such a marker:
                - if one is found at or after ``n_keep + n_discard``,
                  ``n_discard`` grows so the kept tail starts at that marker;
                - otherwise, if one is found at or after ``n_keep``,
                  ``n_keep`` moves to just past it so at least one marker
                  is kept.

        NOTE(review): assumes ``kmp_search(seq, pattern, start, end, lps)``
        returns the index of a match in ``[start, end)`` and a value
        ``< start`` (e.g. -1) when there is none -- confirm in ``KMP_list``.
        NOTE(review): removal uses seq_id -1 while the shift uses seq_id 0,
        and after the fallback raises ``n_keep`` the new
        ``n_keep + n_discard`` is not re-checked against ``n_past``.
        """
        if n_past < 0:
            n_past = self.n_tokens
        if im_start is not None:  # [<|im_start|>, name, nl]
            lps = compute_lps_array(im_start)
            _idx = kmp_search(self.input_ids, im_start, n_keep + n_discard, n_past, lps)
            if _idx >= n_keep:  # effectively >= n_keep + n_discard, since the search starts there
                n_discard = _idx - n_keep  # cut up to the nearest im_start sequence
            else:
                _idx = kmp_search(self.input_ids, im_start, n_keep, n_past, lps)
                if _idx >= n_keep:
                    n_keep = _idx + len(im_start)  # keep at least one im_start sequence
        self._ctx.kv_cache_seq_rm(-1, n_keep, n_keep + n_discard)
        self._ctx.kv_cache_seq_shift(0, n_keep + n_discard, n_past, -n_discard)
        # Mirror the KV-cache shift in the token history.
        self.input_ids[n_keep:n_past - n_discard] = self.input_ids[n_keep + n_discard:n_past]
        self.n_tokens = n_past - n_discard

    def eval_t(self, tokens, n_keep=4, n_discard=256, im_start=None):
        """Evaluate ``tokens`` on top of the current context, trimming first if needed.

        If ``n_tokens + len(tokens)`` would exceed the context size,
        ``max(n_discard, overflow)`` tokens after the first ``n_keep`` are
        discarded via ``kv_cache_seq_ltrim`` (``im_start`` is forwarded to
        align the cut). The tokens are then decoded in chunks of ``n_batch``:
        each chunk's tokens are stored in ``input_ids``, its logits in
        ``scores`` (only the last row unless ``logits_all``), and both
        ``n_tokens`` and the top venv layer grow by the chunk length.

        Returns:
            The new ``n_tokens``.
        """
        if self._n_ctx < self.n_tokens + len(tokens):
            tmp_n_discard = max(n_discard, self.n_tokens + len(tokens) - self._n_ctx)
            self.kv_cache_seq_ltrim(n_keep, tmp_n_discard, im_start=im_start)
        for i in range(0, len(tokens), self.n_batch):
            batch = tokens[i: i + self.n_batch]
            n_past = self.n_tokens
            n_tokens = len(batch)
            self._batch.set_batch(
                batch=batch, n_past=n_past, logits_all=self.context_params.logits_all
            )
            self._ctx.decode(self._batch)
            # Save tokens
            self.input_ids[n_past: n_past + n_tokens] = batch
            # Save logits
            rows = n_tokens
            cols = self._n_vocab
            offset = (
                0 if self.context_params.logits_all else n_tokens - 1
            )  # NOTE: Only save the last token logits if logits_all is False
            self.scores[n_past + offset: n_past + n_tokens, :].reshape(-1)[
                :
            ] = self._ctx.get_logits()[offset * cols: rows * cols]
            # Update n_tokens and the top venv layer's token count
            self.n_tokens += n_tokens
            self.venv[-1] += n_tokens
        return self.n_tokens

    def sample_t(
        self,
        top_k: int = 40,
        top_p: float = 0.95,
        min_p: float = 0.05,
        typical_p: float = 1.0,
        temp: float = 0.80,
        repeat_penalty: float = 1.1,
        repeat_last_n: int = 64,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
        mirostat_eta: float = 0.1,
        mirostat_tau: float = 5.0,
        penalize_nl: bool = True,
        logits_processor: Optional[LogitsProcessorList] = None,
        grammar: Optional[LlamaGrammar] = None,
    ):
        """Sample the next token id from the logits of the last evaluated token.

        Steps: build the repetition-penalty window from the last
        ``repeat_last_n`` input ids (prefixed with
        ``max(0, repeat_last_n - n_tokens)`` zero tokens), apply
        ``logits_processor``, apply repetition/frequency/presence penalties,
        restore the newline token's logit when ``penalize_nl`` is False,
        apply ``grammar``, then pick a token:

        * ``temp < 0``: softmax, then take the first candidate;
        * ``temp == 0``: greedy;
        * ``mirostat_mode`` 1 or 2: temperature, then Mirostat v1 / v2;
        * otherwise: top-k, tail-free, typical, top-p, min-p, temperature,
          then sample.

        When ``grammar`` is given, the chosen token is accepted into it.

        Returns:
            The sampled token id.
        """
        last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
            0, repeat_last_n - self.n_tokens
        ) + self._input_ids[-repeat_last_n:].tolist()
        last_n_tokens_size = len(last_n_tokens_data)
        n_vocab = self._n_vocab
        n_ctx = self._n_ctx
        top_k = n_vocab if top_k <= 0 else top_k
        last_n_tokens_size = n_ctx if last_n_tokens_size < 0 else last_n_tokens_size
        last_n_tokens_data_c = (llama_cpp.llama_token * last_n_tokens_size)(
            *last_n_tokens_data
        )
        # Logits row of the most recently evaluated token.
        logits: npt.NDArray[np.single] = self.scores[self.n_tokens - 1: self.n_tokens, :].ravel()
        if logits_processor is not None:
            logits[:] = logits_processor(self._input_ids, logits)
        self._candidates.copy_logits(logits)
        self._ctx.sample_repetition_penalties(
            candidates=self._candidates,
            last_tokens_data=last_n_tokens_data_c,
            penalty_last_n=last_n_tokens_size,
            penalty_repeat=repeat_penalty,
            penalty_freq=frequency_penalty,
            penalty_present=presence_penalty,
        )
        if not penalize_nl:
            # Put back the newline logit as it was before the penalties.
            nl_logit = logits[self._token_nl]
            self._candidates.candidates.data[self._token_nl].logit = llama_cpp.c_float(
                nl_logit
            )
        if grammar is not None:
            self._ctx.sample_grammar(
                candidates=self._candidates,
                grammar=grammar,
            )
        # NOTE(review): mu is re-initialised to 2 * tau on every call below, so
        # any Mirostat state update is presumably not carried to the next
        # token -- confirm against the llama_cpp sampling API.
        if temp < 0.0:
            self._ctx.sample_softmax(candidates=self._candidates)
            id_ = self._candidates.candidates.data[0].id
        elif temp == 0.0:
            id_ = self._ctx.sample_token_greedy(candidates=self._candidates)
        elif mirostat_mode == 1:
            self._ctx.sample_temp(candidates=self._candidates, temp=temp)
            id_ = self._ctx.sample_token_mirostat(
                candidates=self._candidates,
                tau=mirostat_tau,
                eta=mirostat_eta,
                mu=2.0 * mirostat_tau,
                m=100,
            )
        elif mirostat_mode == 2:
            self._ctx.sample_temp(candidates=self._candidates, temp=temp)
            id_ = self._ctx.sample_token_mirostat_v2(
                candidates=self._candidates,
                tau=mirostat_tau,
                eta=mirostat_eta,
                mu=2.0 * mirostat_tau,
            )
        else:
            self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
            self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1)
            self._ctx.sample_typical(
                candidates=self._candidates, p=typical_p, min_keep=1
            )
            self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1)
            self._ctx.sample_min_p(candidates=self._candidates, p=min_p, min_keep=1)
            self._ctx.sample_temp(candidates=self._candidates, temp=temp)
            id_ = self._ctx.sample_token(candidates=self._candidates)
        if grammar is not None:
            self._ctx.grammar_accept_token(grammar=grammar, token=id_)
        return id_

    def generate_t(
        self,
        tokens: Sequence[int],
        n_keep,
        n_discard: int = 256,
        im_start=None,
        top_k: int = 40,
        top_p: float = 0.95,
        min_p: float = 0.05,
        typical_p: float = 1.0,
        temp: float = 0.80,
        repeat_penalty: float = 1.1,
        repeat_last_n: int = 64,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
        mirostat_tau: float = 5.0,
        mirostat_eta: float = 0.1,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        grammar: Optional[LlamaGrammar] = None,
    ) -> Generator[int, Optional[Sequence[int]], None]:
        """Generator that alternately evaluates tokens and yields a sampled token.

        The first step evaluates ``tokens``; each later step evaluates the
        previously yielded token followed by any tokens the caller passed in
        via ``send()``. ``n_keep``, ``n_discard`` and ``im_start`` are
        forwarded to ``eval_t``; the sampling arguments to ``sample_t``.
        The generator returns as soon as ``stopping_criteria`` is satisfied;
        the token sampled in that step is neither yielded nor evaluated.
        """
        typical_p = float(typical_p)
        frequency_penalty = float(frequency_penalty)
        presence_penalty = float(presence_penalty)
        tfs_z = float(tfs_z)
        mirostat_tau = float(mirostat_tau)
        while True:
            self.eval_t(tokens, n_keep, n_discard, im_start=im_start)
            token = self.sample_t(
                top_k=top_k,
                top_p=top_p,
                min_p=min_p,
                typical_p=typical_p,
                temp=temp,
                repeat_penalty=repeat_penalty,
                repeat_last_n=repeat_last_n,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                tfs_z=tfs_z,
                mirostat_mode=mirostat_mode,
                mirostat_tau=mirostat_tau,
                mirostat_eta=mirostat_eta,
                logits_processor=logits_processor,
                grammar=grammar,
            )
            if stopping_criteria is not None and stopping_criteria(
                self._input_ids, self._scores[-1, :]
            ):
                return
            tokens_or_none = yield token
            # Next step evaluates the sampled token plus any tokens sent in.
            tokens = [token]
            if tokens_or_none is not None:
                tokens.extend(tokens_or_none)

    def load_session(self, filepath: str):
        """Load a llama.cpp session file into the context.

        Reads up to ``n_ctx()`` tokens from the file and restores
        ``n_tokens`` and ``input_ids`` from them.

        NOTE(review): ``self.venv`` and ``scores`` are not updated, and the
        return value is not checked before ``n_tokens`` is overwritten.

        Returns:
            The return value of ``llama_cpp.llama_load_session_file``.
        """
        # Out-parameter receiving the number of tokens read from the file.
        n_tokens = POINTER(llama_cpp.c_size_t)(llama_cpp.c_size_t(0))
        tokens = (llama_cpp.llama_token * self.n_ctx())()
        retn = llama_cpp.llama_load_session_file(self._ctx.ctx,
                                                 filepath.encode('utf-8'),
                                                 tokens,
                                                 self.n_ctx(),
                                                 n_tokens)
        self.n_tokens = n_tokens.contents.value
        self.input_ids[:self.n_tokens] = tokens[:self.n_tokens]
        return retn

    def save_session(self, filepath: str):
        """Save the context and its first ``n_tokens`` input ids to a session file.

        Returns:
            The return value of ``llama_cpp.llama_save_session_file``.
        """
        tokens = self._input_ids.tolist()
        tokens = (llama_cpp.llama_token * len(tokens))(*tokens)
        return llama_cpp.llama_save_session_file(self._ctx.ctx, filepath.encode('utf-8'), tokens, self.n_tokens)