import cuda_ext
from model import ExLlama, ExLlamaCache
from tokenizer import ExLlamaTokenizer
from lora import ExLlamaLora
import torch
import torch.nn.functional as F
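
# Maximum number of encoded strings kept by cached_tokenize() before the oldest entries are evicted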
MAX_CACHED_STRINGS = 100


class ExLlamaAltGenerator:

    class Settings:

        temperature = 0.95
        top_k = 40                              # consider the most probable top_k samples, 0 to disable top_k sampling
        top_p = 0.65                            # consider tokens up to a cumulative probability of top_p, 0.0 to disable top_p sampling
        min_p = 0.0                             # Do not consider tokens with probability less than this
        typical = 0.0                           # Locally typical sampling threshold, 0.0 to disable typical sampling

        token_repetition_penalty_max = 1.15     # Repetition penalty for most recent tokens
        token_repetition_penalty_sustain = -1   # Number of most recent tokens to apply the penalty to, -1 to apply to whole context
        token_repetition_penalty_decay = 0      # Gradually decrease penalty over this many tokens

        disallowed_tokens: list[int] = None     # List of tokens to inhibit, e.g. tokenizer.eos_token_id
        lora: ExLlamaLora = None                # LoRA to apply when generating

    # Internal state

    model: ExLlama
    cache: ExLlamaCache
    tokenizer: ExLlamaTokenizer
    tokenizer_cache = {}

    settings: Settings
    stop_strings: list = []
    stop_tokens: list = []
    held_text: str = ""
    max_stop_tokens: int = 2
    sequence_ids: torch.Tensor = None
    sequence_str: str = None
    remaining_tokens: int = 0

    def __init__(self, model: ExLlama, tokenizer: ExLlamaTokenizer, cache: ExLlamaCache):

        self.model = model
        self.tokenizer = tokenizer
        self.cache = cache
        self.settings = ExLlamaAltGenerator.Settings()

    def cached_tokenize(self, text: str, encode_special_characters = False):

        # Reuse the cached encoding if this exact string has been tokenized before
        if text in self.tokenizer_cache:
            return self.tokenizer_cache[text]

        # Evict entries once the cache is full; dicts iterate in insertion order (as of Python 3.7), so this always removes the oldest entry
        while len(self.tokenizer_cache) >= MAX_CACHED_STRINGS:
            del self.tokenizer_cache[next(iter(self.tokenizer_cache))]

        new_enc = self.tokenizer.encode(text, encode_special_characters = encode_special_characters)
        self.tokenizer_cache[text] = new_enc
        return new_enc

    def get_num_tokens(self, text: str, encode_special_characters = False):

        return self.cached_tokenize(text, encode_special_characters = encode_special_characters).shape[-1]

    # Begin generating
    #
    # prompt: The input prompt. Will be tokenized and then truncated to the model's max sequence length minus max_new_tokens
    # stop_conditions: List of strings or integer token IDs that will end the sequence
    # gen_settings: ExLlamaAltGenerator.Settings
    # encode_special_characters: Set to True to tokenize special tokens such as "</s>"

    def begin_stream(self, prompt: str, stop_conditions: list, max_new_tokens: int, gen_settings: Settings, encode_special_characters = False):

        assert isinstance(prompt, str), "ExLlamaAltGenerator does not support batched generation"

        # Tokenize prompt and limit length to allow prompt and (max) new tokens within max sequence length

        max_input_tokens = self.model.config.max_seq_len - max_new_tokens
        self.remaining_tokens = max_new_tokens

        input_ids = self.cached_tokenize(prompt, encode_special_characters)
        applied_input_ids = input_ids[:, -max_input_tokens:]
        # If the prompt was truncated, re-decode the applied portion so sequence_str matches sequence_ids
        self.sequence_str = self.tokenizer.decode(applied_input_ids)[0] if applied_input_ids.shape[-1] < input_ids.shape[-1] else prompt

        # Settings

        self.stop_strings = []
        self.stop_tokens = []
        for t in stop_conditions:
            if isinstance(t, int): self.stop_tokens += [t]
            elif isinstance(t, str): self.stop_strings += [t]
            else: raise ValueError("Unsupported type in stop_conditions")

        self.held_text = ""

        self.max_stop_tokens = 2
        for ss in self.stop_strings:
            self.max_stop_tokens = max(self.max_stop_tokens, self.get_num_tokens(ss) + 2)

        self.settings = gen_settings

        # Start generation

        self.gen_begin_reuse(applied_input_ids, gen_settings)

    # Get the next chunk of text in the stream
    #
    # Returns stream_chunk: str, EOS: bool

    def stream(self):

        # Check total response length

        if self.remaining_tokens == 0:
            self.sequence_str += self.held_text
            return self.held_text, True

        self.remaining_tokens -= 1

        # Decode the current tail end of the sequence

        old_tail = self.tokenizer.decode(self.sequence_ids[:, -self.max_stop_tokens:])[0]

        # Generate a single token and append to the sequence

        next_token = self.gen_single_token(self.settings)

        # End immediately if it was a stop token

        if next_token in self.stop_tokens:
            self.sequence_str += self.held_text
            return self.held_text, True

        # Decode the tail end of the sequence with the added token to get (actual) characters added

        new_tail = self.tokenizer.decode(self.sequence_ids[:, -(self.max_stop_tokens + 1):])[0]
        self.held_text += new_tail[len(old_tail):]

        # Hold text as long as it contains part of a stop string

        partial_ss = False
        for ss in self.stop_strings:

            # Check if held_text fully contains stop string

            position = self.held_text.find(ss)
            if position != -1:
                self.sequence_str += self.held_text[:position]
                return self.held_text[:position], True

            # Check for overlap between end of held_text and start of stop string

            overlap = 0
            for j in range(1, min(len(self.held_text), len(ss)) + 1):
                if self.held_text[-j:] == ss[:j]: overlap = j
            if overlap > 0: partial_ss = True

        # If holding text because of a partial stop condition, return nothing but also EOS = False

        if partial_ss:
            return "", False

        # No stop condition, so return whatever has been generated

        stream_text = self.held_text
        self.held_text = ""
        self.sequence_str += stream_text
        return stream_text, False

    # Generate and return a full prompt completion. See begin_stream() above

    def generate(self, prompt: str, stop_conditions: list, max_new_tokens: int, gen_settings: Settings, encode_special_characters = False):

        self.begin_stream(prompt, stop_conditions, max_new_tokens, gen_settings, encode_special_characters)
        response = ""
        while True:
            chunk, eos = self.stream()
            response += chunk
            if eos: break

        return response

    # Begin generation

    def gen_begin(self, in_tokens, gen_settings):

        self.sequence_ids = in_tokens.clone()
        self.cache.current_seq_len = 0
        self.model.forward(self.sequence_ids[:, :-1], self.cache, preprocess_only = True, lora = gen_settings.lora)
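
    # Begin generation, but reuse the existing cache for whatever prefix the new context shares with the
    # previous one, so only the non-matching tail of in_tokens has to be run through the model again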
    def gen_begin_reuse(self, in_tokens, gen_settings):

        if self.sequence_ids is None or self.cache.current_seq_len == 0:
            self.gen_begin(in_tokens, gen_settings)
            return

        reuse = 0
        while reuse < self.sequence_ids.shape[-1] and reuse < in_tokens.shape[-1] and self.sequence_ids[0, reuse] == in_tokens[0, reuse]:
            reuse += 1

        if reuse < 2:
            self.gen_begin(in_tokens, gen_settings)
            return

        self.cache.current_seq_len = reuse - 1
        self.sequence_ids = in_tokens[:, :reuse]

        if reuse < in_tokens.shape[-1]: self.gen_feed_tokens(in_tokens[:, reuse:], gen_settings)
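
    # Append tokens to the current sequence and extend the cache over them, leaving the last token
    # to be processed by the next forward/sampling pass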
    def gen_feed_tokens(self, in_tokens, gen_settings):

        if self.sequence_ids is None:
            self.gen_begin(in_tokens, gen_settings)
            return

        start = self.cache.current_seq_len
        self.sequence_ids = torch.cat((self.sequence_ids, in_tokens), dim = 1)
        self.model.forward(self.sequence_ids[:, start : -1], self.cache, preprocess_only = True, lora = gen_settings.lora)

    # Generate one token in current sequence

    def gen_single_token(self, gen_settings):

        # Simple sampling case:

        logits = self.model.forward(self.sequence_ids[:, -1:], self.cache, lora = gen_settings.lora)
        token, _ = self.sample(logits, gen_settings)
        self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1)
        return token
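
    # Sample one token from the logits: apply the repetition penalty and temperature, filter the distribution
    # with top-k, top-p/min-p and locally typical sampling, then draw from what remains. Returns the sampled
    # token and its probability within the filtered distribution, both shaped (1, 1)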
    def sample(self, logits, gen_settings):

        cuda_ext.ext_apply_rep_penalty_mask_cpu(self.sequence_ids,
                                                self.settings.token_repetition_penalty_max,
                                                self.settings.token_repetition_penalty_sustain,
                                                self.settings.token_repetition_penalty_decay,
                                                logits)

        logits[:, :, self.tokenizer.bos_token_id] = -10000.0

        if logits.dim() == 3: logits = logits[0, -1, :]
        elif logits.dim() == 2: logits = logits[-1, :]
        else: raise ValueError("Bad logits dimension")

        # Disallow tokens

        if gen_settings.disallowed_tokens is not None:
            logits[gen_settings.disallowed_tokens] = float("-inf")

        # Base probabilities

        logits /= gen_settings.temperature
        logits += 1e-8
        probs = torch.softmax(logits, dim = -1)

        # Top K

        if gen_settings.top_k == 0:
            top_probs, top_indices = torch.sort(probs, descending = True)
        else:
            top_probs, top_indices = torch.topk(probs, gen_settings.top_k)
            top_probs = F.normalize(top_probs, p = 1, dim = -1)

        # Top P

        if gen_settings.top_p > 0.0:

            num_top_p_probs = 0
            cum_prob = top_probs[0].item()
            while True:
                num_top_p_probs += 1
                if num_top_p_probs == top_probs.shape[-1]: break
                if top_probs[num_top_p_probs].item() < gen_settings.min_p: break
                cum_prob += top_probs[num_top_p_probs].item()
                if cum_prob > gen_settings.top_p: break

            top_probs = top_probs[:num_top_p_probs]
            top_probs = F.normalize(top_probs, p = 1, dim = -1)
            top_indices = top_indices[:num_top_p_probs]

        # Locally typical sampling

        if gen_settings.typical > 0.0:

            epsilon = 1e-10
            log_probs = (top_probs + epsilon).log()
            neg_entropy = (top_probs * log_probs).sum()
            entropy_dev = (neg_entropy - log_probs).abs()
            _, entropy_dev_order = torch.sort(entropy_dev)

            top_probs = top_probs.gather(-1, entropy_dev_order)
            top_indices = top_indices.gather(-1, entropy_dev_order)

            num_typical_probs = 0
            cum_prob = top_probs[0].item()
            while True:
                num_typical_probs += 1
                if num_typical_probs == top_probs.shape[-1]: break
                cum_prob += top_probs[num_typical_probs].item()
                if cum_prob > gen_settings.typical: break

            top_probs = top_probs[:num_typical_probs]
            top_probs = F.normalize(top_probs, p = 1, dim = -1)
            top_indices = top_indices[:num_typical_probs]

        # Multinomial sampling from top_probs, kept in same order as top_indices

        sampled_ind = torch.multinomial(top_probs, 1)
        sampled_tokens = top_indices[sampled_ind]
        sampled_probs = top_probs[sampled_ind]  # Return probs before second norm

        # if sampled_tokens.shape[0] > 1:
        #     sampled_tokens, ind = sampled_tokens.sort()
        #     sampled_probs = sampled_probs[ind]

        return sampled_tokens.unsqueeze(0), sampled_probs.unsqueeze(0)
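
# Example usage: a minimal sketch, not part of the generator itself. The ExLlamaConfig construction and the
# model/tokenizer paths below are assumptions based on the surrounding ExLlama repo (the paths are placeholders);
# the generator calls themselves follow the API defined above.

if __name__ == "__main__":

    from model import ExLlamaConfig

    config = ExLlamaConfig("/path/to/model/config.json")            # placeholder path to the model's config.json
    config.model_path = "/path/to/model/model.safetensors"          # placeholder path to the quantized weights

    model = ExLlama(config)
    tokenizer = ExLlamaTokenizer("/path/to/model/tokenizer.model")  # placeholder path to the SentencePiece model
    cache = ExLlamaCache(model)

    generator = ExLlamaAltGenerator(model, tokenizer, cache)
    settings = ExLlamaAltGenerator.Settings()

    # One-shot completion, stopping at EOS or a newline
    print(generator.generate("Hello, my name is", stop_conditions = [tokenizer.eos_token_id, "\n"],
                             max_new_tokens = 200, gen_settings = settings))

    # Streaming completion, printed chunk by chunk
    generator.begin_stream("Once upon a time,", stop_conditions = [tokenizer.eos_token_id],
                           max_new_tokens = 200, gen_settings = settings)
    while True:
        chunk, eos = generator.stream()
        print(chunk, end = "", flush = True)
        if eos: break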