# Device-independent algorithms for LLMs.

import logging
import time

import torch
import torch.nn.functional as F
import transformers

logger = logging.getLogger(__name__)


# Helper to pull out the response tokens.
def isolate_responses_BL(output_BL, prompt_len, eos_token_id):
    responses_BL = []
    resp_BL = output_BL[:, prompt_len:]
    for i in range(resp_BL.shape[0]):
        resp_L = resp_BL[i]
        resplen = resp_L.shape[0]
        for j in range(resplen):
            if resp_L[j] == eos_token_id:
                resplen = j + 1
                break
        response_L = resp_L[:resplen].cpu().detach().numpy()
        responses_BL.append(response_L)
    return responses_BL


def tokenize_prompt(device, tokenizer, chat, quiet=False):
    # Tokenize the prompt.
    prompt_BL = tokenizer.apply_chat_template(
        [chat],
        tokenize=True,
        add_generation_prompt=True,
        return_tensors='pt',
    ).to(device)
    if not quiet:
        print('PROMPT:')
        print(tokenizer.decode(prompt_BL[0]))
    return prompt_BL


def generate(device, model, tokenizer, chat):
    """Generate a response using huggingface's generation."""
    prompt_BL = tokenize_prompt(device, tokenizer, chat)
    prompt_len = prompt_BL.shape[1]
    # Generate response.
    # Unfortunately, huggingface's generation code uses 'cumsum',
    # which doesn't have a deterministic implementation.
    torch.use_deterministic_algorithms(False)
    generation_output = model.generate(
        inputs=prompt_BL,
        max_new_tokens=512,
        do_sample=True,
        return_dict_in_generate=True,
    )
    torch.use_deterministic_algorithms(True)
    output_BL = generation_output.sequences
    for response_L in isolate_responses_BL(output_BL, prompt_len, tokenizer.eos_token_id):
        print('RESPONSE:')
        print(tokenizer.decode(response_L))
    return tokenizer.decode(response_L)


def generate_with_logits(device, model, tokenizer, chat, seed=None):
    # Huggingface generation that returns logits too.
    if seed is not None:
        transformers.set_seed(seed)
    prompt_BL = tokenize_prompt(device, tokenizer, chat)
    prompt_len = prompt_BL.shape[1]
    # Unfortunately, huggingface's generation code uses 'cumsum',
    # which doesn't have a deterministic implementation.
    torch.use_deterministic_algorithms(False)
    generation_output = model.generate(
        inputs=prompt_BL,
        max_new_tokens=16,
        do_sample=True,
        return_dict_in_generate=True,
        output_logits=True,
    )
    torch.use_deterministic_algorithms(True)
    output_BL = generation_output.sequences
    logits_BLV = torch.stack(generation_output.logits, dim=1)
    for response_L in isolate_responses_BL(output_BL, prompt_len, tokenizer.eos_token_id):
        print('RESPONSE:')
        print(tokenizer.decode(response_L))
    return response_L, logits_BLV[0]


def response_logits(device, model, tokenizer, chat, response_L):
    # Calculate logits using a single pass.
    prompt_BL = tokenize_prompt(device, tokenizer, chat)
    response_pt_L = torch.from_numpy(response_L[:-1]).to(device)
    # Concatenate along axis 1.
    input_ids = torch.cat((prompt_BL, response_pt_L[None]), dim=1)
    outputs = model(input_ids)
    logits_BLV = outputs.logits
    return logits_BLV[0][-len(response_L):]


def generate_custom(device, model, tokenizer, chat, max_tokens=512, seed=None,
                    return_tokens=False, quiet=False, return_lnprobs=False):
    """Generate a response using custom generation."""
    if seed is not None:
        torch.manual_seed(seed)
    prompt_BL = tokenize_prompt(device, tokenizer, chat, quiet=quiet)
    # Generate response.
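    # What follows is a plain incremental sampling loop: each iteration feeds only the
    # newly sampled token (plus the cached key/value state) back into the model and
    # draws the next token from the full softmax distribution, with no temperature or
    # top-k/top-p truncation.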
    input_ids = prompt_BL
    past_key_values = None
    response_L = []
    lnprobs_L = []
    while 1:
        outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values
        logits = outputs.logits[0, -1, :]
        p_V = F.softmax(logits, dim=-1)
        token = torch.multinomial(p_V, num_samples=1).item()
        response_L.append(token)
        lnprobs_L.append(F.log_softmax(logits, dim=-1)[token].item())
        p_token = p_V[token].item()
        if not quiet:
            print(f' Sampled token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')
        if token == tokenizer.eos_token_id or len(response_L) >= max_tokens:
            break
        input_ids = torch.tensor([[token]], device=device)
    if not quiet:
        print('RESPONSE:')
        print(tokenizer.decode(response_L))
    if return_tokens:
        if return_lnprobs:
            return response_L, lnprobs_L
        else:
            return response_L
    else:
        assert not return_lnprobs
        return tokenizer.decode(response_L)


def format_token(tokenizer, token_id):
    return repr(tokenizer.decode(token_id))


def apoc(device, model_x, model_y, tokenizer, chat_x, chat_y, max_tokens=512, seed=None,
         return_tokens=False, quiet=False):
    """Generate a response using APOC unconditional sampling."""
    if seed is not None:
        torch.manual_seed(seed)
    # This early implementation of the algorithm is numerically non-robust,
    # so reduce problems by using high-precision floating-point.
    logit_dtype = torch.float64
    prompt_x_BL = tokenize_prompt(device, tokenizer, chat_x, quiet=quiet)
    prompt_y_BL = tokenize_prompt(device, tokenizer, chat_y, quiet=quiet)
    # Four variables are needed: in the first iteration the input depends on the
    # prompt (X vs. Y), whereas in later iterations it depends on the response (a vs. b).
    input_ids_xa = prompt_x_BL
    input_ids_ya = prompt_y_BL
    input_ids_xb = prompt_x_BL
    input_ids_yb = prompt_y_BL
    past_key_values_xa = None
    past_key_values_ya = None
    past_key_values_xb = None
    past_key_values_yb = None
    equal = True
    a_eos = False
    b_eos = False

    def zero():
        return torch.zeros(1, dtype=logit_dtype, device=device)

    ln_pya_m_ln_pxa = zero()
    ln_pxb_m_ln_pyb = zero()
    response_a_L = []
    response_b_L = []
    i = 0
    while 1:
        if i >= max_tokens or (a_eos and b_eos):
            break
        if not quiet:
            print(f'Generating response token {i}')
        i += 1
        forward_passes_start = time.perf_counter()
        if not a_eos:
            outputs = model_x(input_ids_xa, past_key_values=past_key_values_xa, use_cache=True)
            past_key_values_xa = outputs.past_key_values
            logits = outputs.logits[0, -1, :].to(logit_dtype)
            ln_pxa_V = F.log_softmax(logits, dim=-1)
            outputs = model_y(input_ids_ya, past_key_values=past_key_values_ya, use_cache=True)
            past_key_values_ya = outputs.past_key_values
            logits = outputs.logits[0, -1, :].to(logit_dtype)
            ln_pya_V = F.log_softmax(logits, dim=-1)
        if not b_eos:
            if equal:
                # In equal mode, neither input_ids nor past_key_values depends on a vs b,
                # so we can reuse the forward pass results for a 50% time savings.
                assert not a_eos
                past_key_values_xb = past_key_values_xa
                ln_pxb_V = ln_pxa_V
                past_key_values_yb = past_key_values_ya
                ln_pyb_V = ln_pya_V
            else:
                outputs = model_x(input_ids_xb, past_key_values=past_key_values_xb, use_cache=True)
                past_key_values_xb = outputs.past_key_values
                logits = outputs.logits[0, -1, :].to(logit_dtype)
                ln_pxb_V = F.log_softmax(logits, dim=-1)
                outputs = model_y(input_ids_yb, past_key_values=past_key_values_yb, use_cache=True)
                past_key_values_yb = outputs.past_key_values
                logits = outputs.logits[0, -1, :].to(logit_dtype)
                ln_pyb_V = F.log_softmax(logits, dim=-1)
        forward_passes_end = time.perf_counter()
        if not quiet:
            print(f' Forward passes took {(forward_passes_end - forward_passes_start)*1000:.0f} ms')
        if equal:
            ln_pmeet_V = torch.minimum(
                ln_pxa_V + torch.maximum(zero(), -ln_pya_m_ln_pxa),
                ln_pya_V + torch.maximum(zero(), ln_pya_m_ln_pxa),
            )
            pmeet_V = torch.exp(ln_pmeet_V)
            pmeet = torch.sum(pmeet_V)
            if not quiet:
                print(f' Equal mode (pmeet={pmeet*100.0:.3f}%)')
            if torch.rand_like(pmeet) < pmeet:
                token_a = token_b = torch.multinomial(pmeet_V, num_samples=1).item()
                p_token_a = (pmeet_V[token_a] / pmeet).item()
                if not quiet:
                    print(f' Sampled {format_token(tokenizer, token_a)} ({p_token_a*100.0:.1f}%)')
            else:
                if not quiet:
                    print(' Exited equal mode')
                equal = False
        if not equal:
            if not a_eos:
                wxt_V = torch.maximum(zero(), torch.exp(ln_pxa_V) - torch.exp(ln_pya_V + ln_pya_m_ln_pxa))
                token_a = torch.multinomial(wxt_V, num_samples=1).item()
                p_token_a = (wxt_V[token_a] / torch.sum(wxt_V)).item()
                if not quiet:
                    print(f' Sampled token_a {format_token(tokenizer, token_a)} ({p_token_a*100.0:.3f}%)')
            if not b_eos:
                wyt_V = torch.maximum(zero(), torch.exp(ln_pyb_V) - torch.exp(ln_pxb_V + ln_pxb_m_ln_pyb))
                token_b = torch.multinomial(wyt_V, num_samples=1).item()
                p_token_b = (wyt_V[token_b] / torch.sum(wyt_V)).item()
                if not quiet:
                    print(f' Sampled token_b {format_token(tokenizer, token_b)} ({p_token_b*100.0:.3f}%)')
        if not a_eos:
            response_a_L.append(token_a)
            input_ids_xa = input_ids_ya = torch.tensor([[token_a]], device=device)
            ln_pya_m_ln_pxa += ln_pya_V[token_a] - ln_pxa_V[token_a]
            if token_a == tokenizer.eos_token_id:
                a_eos = True
        if not b_eos:
            response_b_L.append(token_b)
            input_ids_xb = input_ids_yb = torch.tensor([[token_b]], device=device)
            ln_pxb_m_ln_pyb += ln_pxb_V[token_b] - ln_pyb_V[token_b]
            if token_b == tokenizer.eos_token_id:
                b_eos = True
    if not quiet:
        print('RESPONSE X:')
        print(tokenizer.decode(response_a_L))
        print('RESPONSE Y:')
        print(tokenizer.decode(response_b_L))
    if return_tokens:
        return response_a_L, response_b_L
    else:
        return tokenizer.decode(response_a_L), tokenizer.decode(response_b_L)


# Alternative implementation.
@torch.no_grad()
def apoc_alt(device, model_x, model_y, tokenizer, chat_x, chat_y, max_tokens=512, seed=None):
    if seed is not None:
        torch.manual_seed(seed)
    prompt_x_BL = tokenize_prompt(device, tokenizer, chat_x, quiet=True)
    prompt_y_BL = tokenize_prompt(device, tokenizer, chat_y, quiet=True)
    model_pair = ModelPair(model_x, model_y, prompt_x_BL, prompt_y_BL)
    logger.debug('PROMPT X:')
    logger.debug(tokenizer.decode(prompt_x_BL[0]))
    logger.debug('PROMPT Y:')
    logger.debug(tokenizer.decode(prompt_y_BL[0]))
    return _apoc_impl(model_pair, tokenizer, max_tokens)


LOGIT_DTYPE = torch.float64


class ModelPair:
    def __init__(self, model_x, model_y, prompt_x_BL, prompt_y_BL):
        self._model_x = model_x
        self._model_y = model_y
        self._prompt_x_BL = prompt_x_BL
        self._prompt_y_BL = prompt_y_BL
        self._is_swapped = False

    def start(self):
        # Return logprobs for the initial token.
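        # Run each model over its own prompt once, priming the two KV caches and
        # producing the log-probabilities of the first response token under X and Y.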
        outputs = self._model_x(self._prompt_x_BL, use_cache=True)
        self._past_key_values_x = outputs.past_key_values
        logits = outputs.logits[0, -1, :].to(LOGIT_DTYPE)
        lnpx_V = F.log_softmax(logits, dim=-1)
        outputs = self._model_y(self._prompt_y_BL, use_cache=True)
        self._past_key_values_y = outputs.past_key_values
        logits = outputs.logits[0, -1, :].to(LOGIT_DTYPE)
        lnpy_V = F.log_softmax(logits, dim=-1)
        return self._maybe_swap(lnpx_V, lnpy_V)

    def step(self, token):
        # Append the given token, then return logprobs for the next token.
        forward_passes_start = time.perf_counter()
        input_ids = torch.tensor([[token]], device=self._prompt_x_BL.device)
        outputs = self._model_x(input_ids, past_key_values=self._past_key_values_x, use_cache=True)
        self._past_key_values_x = outputs.past_key_values
        logits = outputs.logits[0, -1, :].to(LOGIT_DTYPE)
        lnpx_V = F.log_softmax(logits, dim=-1)
        outputs = self._model_y(input_ids, past_key_values=self._past_key_values_y, use_cache=True)
        self._past_key_values_y = outputs.past_key_values
        logits = outputs.logits[0, -1, :].to(LOGIT_DTYPE)
        lnpy_V = F.log_softmax(logits, dim=-1)
        forward_passes_end = time.perf_counter()
        logger.debug(f'Incremental forward passes took {(forward_passes_end - forward_passes_start)*1000:.0f} ms')
        return self._maybe_swap(lnpx_V, lnpy_V)

    def get_position(self):
        # Return a position that can be rewound to.
        return self._past_key_values_x, self._past_key_values_y

    def rewind_to(self, position):
        # Rewind the KV cache.
        self._past_key_values_x, self._past_key_values_y = position

    def swap_models(self):
        # Exchange the order of the models.
        self._is_swapped = not self._is_swapped

    def _maybe_swap(self, a, b):
        if self._is_swapped:
            return b, a
        else:
            return a, b


def _apoc_impl(model_pair, tokenizer, max_tokens):
    prefix = []
    lnpx_V, lnpy_V = model_pair.start()
    lnpy_m_lnpx = torch.zeros(1, dtype=lnpx_V.dtype, device=lnpx_V.device)
    while 1:
        ln_peq_V = torch.minimum(
            lnpx_V + F.relu(-lnpy_m_lnpx),
            lnpy_V + F.relu(lnpy_m_lnpx),
        )
        peq_V = torch.exp(ln_peq_V)
        peq = torch.sum(peq_V)
        if torch.rand_like(peq) > peq:
            logger.debug(f'Completed common prefix ({(1-peq)*100.:.3f}%)')
            break
        logger.debug(f'Extending common prefix ({peq*100.:.3f}%)')
        token = torch.multinomial(peq_V, 1).item()
        prefix.append(token)
        lnpy_m_lnpx += lnpy_V[token] - lnpx_V[token]
        p_token = (peq_V[token] / peq).item()
        logger.debug(f'Sampled prefix token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')
        if token == tokenizer.eos_token_id or len(prefix) >= max_tokens:
            return prefix, prefix
        lnpx_V, lnpy_V = model_pair.step(token)
    remaining_tokens = max_tokens - len(prefix)
    split_pos = model_pair.get_position()
    response_a = prefix + _apoc_gen_suffix(
        model_pair, tokenizer, remaining_tokens, lnpx_V, lnpy_V, lnpy_m_lnpx)
    logger.debug('First suffix complete; rewinding')
    model_pair.rewind_to(split_pos)
    model_pair.swap_models()
    response_b = prefix + _apoc_gen_suffix(
        model_pair, tokenizer, remaining_tokens, lnpy_V, lnpx_V, -lnpy_m_lnpx)
    return response_a, response_b


def _apoc_gen_suffix(model_pair, tokenizer, max_tokens, lnpx_V, lnpy_V, lnpy_m_lnpx):
    lnpy_m_lnpx = lnpy_m_lnpx.clone()
    suffix = []
    while 1:
        wx_V = F.relu(torch.exp(lnpx_V) - torch.exp(lnpy_V + lnpy_m_lnpx))
        token = torch.multinomial(wx_V, 1).item()
        suffix.append(token)
        lnpy_m_lnpx += lnpy_V[token] - lnpx_V[token]
        p_token = (wx_V[token] / torch.sum(wx_V)).item()
        logger.debug(f'Sampled suffix token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')
        if token == tokenizer.eos_token_id or len(suffix) >= max_tokens:
            return suffix
        lnpx_V, lnpy_V = model_pair.step(token)
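

# A minimal usage sketch for apoc_alt (illustrative only; the model name, device,
# and chat contents below are placeholders/assumptions, not part of this module):
#
#   model = transformers.AutoModelForCausalLM.from_pretrained('some-chat-model').to('cuda')
#   tokenizer = transformers.AutoTokenizer.from_pretrained('some-chat-model')
#   chat_x = [{'role': 'user', 'content': 'Name a color.'}]
#   chat_y = [{'role': 'user', 'content': 'Name a warm color.'}]
#   tokens_a, tokens_b = apoc_alt('cuda', model, model, tokenizer, chat_x, chat_y,
#                                 max_tokens=32, seed=0)
#   print(tokenizer.decode(tokens_a))
#   print(tokenizer.decode(tokens_b))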


def generate_streaming(device, model, tokenizer, chat, max_tokens=512, seed=None):
    """Stream a response using custom generation."""
    prompt_BL = tokenize_prompt(device, tokenizer, chat, quiet=True)
    logger.debug('PROMPT:')
    logger.debug(tokenizer.decode(prompt_BL[0]))
    if seed is not None:
        torch.manual_seed(seed)
    return _generate_streaming_impl(device, model, tokenizer, prompt_BL, max_tokens)


def _generate_streaming_impl(device, model, tokenizer, prompt_BL, max_tokens):
    input_ids = prompt_BL
    past_key_values = None
    n_tokens = 0
    while 1:
        outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values
        logits = outputs.logits[0, -1, :]
        p_V = F.softmax(logits, dim=-1)
        token = torch.multinomial(p_V, num_samples=1).item()
        p_token = p_V[token].item()
        logger.debug(f' Sampled token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')
        yield token
        n_tokens += 1
        if token == tokenizer.eos_token_id or n_tokens >= max_tokens:
            break
        input_ids = torch.tensor([[token]], device=device)


# APOC unconditional streaming.
@torch.no_grad()
def apoc_streaming(model_x, model_y, tokenizer, chat_x, chat_y, max_tokens=512, seed=None):
    if seed is not None:
        torch.manual_seed(seed)
    prompt_x_BL = tokenize_prompt(model_x.device, tokenizer, chat_x, quiet=True)
    prompt_y_BL = tokenize_prompt(model_y.device, tokenizer, chat_y, quiet=True)
    model_pair = ModelPair(model_x, model_y, prompt_x_BL, prompt_y_BL)
    logger.debug('PROMPT X:')
    logger.debug(tokenizer.decode(prompt_x_BL[0]))
    logger.debug('PROMPT Y:')
    logger.debug(tokenizer.decode(prompt_y_BL[0]))
    return _apoc_streaming_impl(model_pair, tokenizer, max_tokens)


def _apoc_streaming_impl(model_pair, tokenizer, max_tokens):
    remaining_tokens = max_tokens
    lnpx_V, lnpy_V = model_pair.start()
    lnpy_m_lnpx = torch.zeros(1, dtype=lnpx_V.dtype, device=lnpx_V.device)
    while 1:
        ln_peq_V = torch.minimum(
            lnpx_V + F.relu(-lnpy_m_lnpx),
            lnpy_V + F.relu(lnpy_m_lnpx),
        )
        peq_V = torch.exp(ln_peq_V)
        peq = torch.sum(peq_V)
        if torch.rand_like(peq) > peq:
            logger.debug(f'Completed common prefix ({(1-peq)*100.:.3f}%)')
            break
        logger.debug(f'Extending common prefix ({peq*100.:.3f}%)')
        token = torch.multinomial(peq_V, 1).item()
        remaining_tokens -= 1
        yield token, token
        lnpy_m_lnpx += lnpy_V[token] - lnpx_V[token]
        p_token = (peq_V[token] / peq).item()
        logger.debug(f'Sampled prefix token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')
        if token == tokenizer.eos_token_id or remaining_tokens == 0:
            return
        lnpx_V, lnpy_V = model_pair.step(token)
    split_pos = model_pair.get_position()
    for token_a in _apoc_streaming_gen_suffix(
            model_pair, tokenizer, remaining_tokens, lnpx_V, lnpy_V, lnpy_m_lnpx):
        yield token_a, None
    logger.debug('Suffix a complete; rewinding')
    model_pair.rewind_to(split_pos)
    model_pair.swap_models()
    for token_b in _apoc_streaming_gen_suffix(
            model_pair, tokenizer, remaining_tokens, lnpy_V, lnpx_V, -lnpy_m_lnpx):
        yield None, token_b
    logger.debug('Suffix b complete')


def _apoc_streaming_gen_suffix(model_pair, tokenizer, max_tokens, lnpx_V, lnpy_V, lnpy_m_lnpx):
    remaining_tokens = max_tokens
    lnpy_m_lnpx = lnpy_m_lnpx.clone()
    while 1:
        wx_V = F.relu(torch.exp(lnpx_V) - torch.exp(lnpy_V + lnpy_m_lnpx))
        token = torch.multinomial(wx_V, 1).item()
        remaining_tokens -= 1
        yield token
        lnpy_m_lnpx += lnpy_V[token] - lnpx_V[token]
        p_token = (wx_V[token] / torch.sum(wx_V)).item()
        logger.debug(f'Sampled suffix token {format_token(tokenizer, token)} ({p_token*100.0:.3f}%)')
        if token == tokenizer.eos_token_id or remaining_tokens == 0:
            return
        lnpx_V, lnpy_V = model_pair.step(token)
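

# A minimal consumption sketch for apoc_streaming (illustrative only; `model`,
# `tokenizer`, `chat_x`, and `chat_y` are assumed to be a loaded chat model, its
# tokenizer, and two chat histories, none of which are defined in this module):
#
#   for token_a, token_b in apoc_streaming(model, model, tokenizer, chat_x, chat_y,
#                                          max_tokens=64, seed=0):
#       if token_a is not None:
#           print('A:', format_token(tokenizer, token_a))
#       if token_b is not None:
#           print('B:', format_token(tokenizer, token_b))
#
# While the two responses share a common prefix, each iteration yields the same token
# in both slots; once they diverge, suffix a is streamed first (token_b is None) and
# then suffix b (token_a is None).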