from pathlib import Path
from typing import Any, Dict, Hashable

import numpy as np
import streamlit as st
import streamlit.components.v1 as components
import torch
import torch.nn.functional as F
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, BatchEncoding, GPT2LMHeadModel, PreTrainedTokenizer
)
from transformers import (
    LogitsProcessorList, RepetitionPenaltyLogitsProcessor, TemperatureLogitsWarper,
    TopPLogitsWarper, TypicalLogitsWarper
)

root_dir = Path(__file__).resolve().parent
highlighted_text_component = components.declare_component(
    "highlighted_text", path=root_dir / "highlighted_text" / "build"
)


def get_windows_batched(
    examples: BatchEncoding, window_len: int, start: int = 0, stride: int = 1, pad_id: int = 0
) -> BatchEncoding:
    # Split each example into overlapping windows of length window_len, padding windows
    # that extend past the end of the sequence.
    return BatchEncoding({
        k: [
            t[i][j : j + window_len] + [
                pad_id if k in ["input_ids", "labels"] else 0
            ] * (j + window_len - len(t[i]))
            for i in range(len(examples["input_ids"]))
            for j in range(start, len(examples["input_ids"][i]), stride)
        ]
        for k, t in examples.items()
    })

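# Illustrative example (not executed): for input_ids [[1, 2, 3, 4]] with window_len=2,
# start=0, stride=1 and pad_id=0, get_windows_batched produces input_ids
# [[1, 2], [2, 3], [3, 4], [4, 0]] -- one window starting at every position,
# right-padded with pad_id up to the window length.
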
BAD_CHAR = chr(0xfffd)


def ids_to_readable_tokens(tokenizer, ids, strip_whitespace=False):
    # Decode token IDs one at a time, accumulating IDs that only decode to a valid string
    # together (e.g. multi-byte characters split across tokens).
    cur_ids = []
    result = []
    for idx in ids:
        cur_ids.append(idx)
        decoded = tokenizer.decode(cur_ids)
        if BAD_CHAR not in decoded:
            if strip_whitespace:
                decoded = decoded.strip()
            result.append(decoded)
            cur_ids.clear()
        else:
            result.append("")
    return result


def nll_score(logprobs, labels):
    if logprobs.shape[-1] == 1:
        return -logprobs.squeeze(-1)
    else:
        return -logprobs[:, torch.arange(len(labels)), labels]


def kl_div_score(logprobs):
    # For each target token, use the distribution given the longest available context
    # as the reference distribution p
    log_p = logprobs[
        torch.arange(logprobs.shape[1]).clamp(max=logprobs.shape[0] - 1),
        torch.arange(logprobs.shape[1])
    ]
    # Compute things in place as much as possible
    log_p_minus_log_q = logprobs
    del logprobs
    log_p_minus_log_q *= -1
    log_p_minus_log_q += log_p

    # Use np.exp because torch.exp is not implemented for float16
    p_np = log_p.numpy()
    del log_p
    np.exp(p_np, out=p_np)

    result = log_p_minus_log_q
    result *= torch.as_tensor(p_np)
    return result.sum(dim=-1)


compact_layout = st.experimental_get_query_params().get("compact", ["false"]) == ["true"]

if not compact_layout:
    st.title("Context length probing")
    st.markdown(
        """[📃 Paper](https://arxiv.org/abs/2212.14815) | [🌍 Website](https://cifkao.github.io/context-probing) | [🧑‍💻 Code](https://github.com/cifkao/context-probing)"""
    )

generation_mode = st.radio("Mode", ["Standard", "Generation"], horizontal=True) == "Generation"
st.caption(
    "In standard mode, we analyze the model's predictions on the input text. "
    "In generation mode, we generate a continuation of the input text (prompt) "
    "and visualize the contributions of different contexts to each generated token."
)

model_name = st.selectbox("Model", ["distilgpt2", "gpt2", "EleutherAI/gpt-neo-125m"])
metric_name = st.radio(
    "Metric", (["KL divergence"] if not generation_mode else []) + ["NLL loss"],
    index=0, horizontal=True
)

tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(model_name, use_fast=False)

# Make sure the logprobs do not use up more than ~4 GB of memory
MAX_MEM = 4e9 / (torch.finfo(torch.float16).bits / 8)
# Select window lengths such that we are allowed to fill the whole window without running out of memory
# (otherwise the window length is irrelevant); if using NLL, memory is not a consideration, but we want
# to limit runtime
multiplier = tokenizer.vocab_size if metric_name == "KL divergence" else 16384  # arbitrary number
window_len_options = [
    w for w in [8, 16, 32, 64, 128, 256, 512, 1024]
    if w == 8 or w * (2 * w) * multiplier <= MAX_MEM
]
window_len = st.select_slider(
    r"Window size ($c_\text{max}$)",
    options=window_len_options,
    value=min(128, window_len_options[-1])
)
# Now figure out how many tokens we are allowed to use:
# window_len * (num_tokens + window_len) * vocab_size <= MAX_MEM
max_tokens = int(MAX_MEM / (multiplier * window_len) - window_len)
max_tokens = min(max_tokens, 4096)
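# Worked example (illustrative): in KL divergence mode with a GPT-2-sized vocabulary of
# 50257 tokens, MAX_MEM is 2e9 float16 elements (~4 GB), so for window_len=128 we get
# max_tokens ≈ 2e9 / (50257 * 128) - 128 ≈ 182 input tokens.
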
generate_kwargs = {}
if generation_mode:
    with st.expander("Generation options", expanded=False):
        generate_kwargs["max_new_tokens"] = st.slider(
            "Max. number of generated tokens",
            min_value=8, max_value=min(1024, max_tokens), value=min(128, max_tokens)
        )
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            generate_kwargs["temperature"] = st.number_input(
                min_value=0.01, value=0.9, step=0.05, label="`temperature`"
            )
        with col2:
            generate_kwargs["top_p"] = st.number_input(
                min_value=0., value=0.95, max_value=1., step=0.05, label="`top_p`"
            )
        with col3:
            generate_kwargs["typical_p"] = st.number_input(
                min_value=0., value=1., max_value=1., step=0.05, label="`typical_p`"
            )
        with col4:
            generate_kwargs["repetition_penalty"] = st.number_input(
                min_value=1., value=1., step=0.05, label="`repetition_penalty`"
            )

DEFAULT_TEXT = """
We present context length probing, a novel explanation technique for causal language models,
based on tracking the predictions of a model as a function of the length of available context,
and allowing to assign differential importance scores to different contexts. The technique is
model-agnostic and does not rely on access to model internals beyond computing token-level
probabilities. We apply context length probing to large pre-trained language models and offer
some initial analyses and insights, including the potential for studying long-range dependencies.
""".replace("\n", " ").strip()

text = st.text_area(
    "Prompt" if generation_mode else f"Input text (≤\u2009{max_tokens} tokens)",
    st.session_state.get("input_text", DEFAULT_TEXT),
    key="input_text",
)

inputs = tokenizer([text])
[input_ids] = inputs["input_ids"]
label_ids = [*input_ids[1:], tokenizer.eos_token_id]
inputs["labels"] = [label_ids]

num_user_tokens = len(input_ids)

if num_user_tokens < 1:
    st.error("Please enter at least one token.", icon="🚨")
    st.stop()
if not generation_mode and num_user_tokens > max_tokens:
    st.error(
        f"Your input has {num_user_tokens} tokens. Please enter at most {max_tokens} tokens "
        f"or try reducing the window size.",
        icon="🚨"
    )
    st.stop()

with st.spinner("Loading model…"):
    model = st.cache_resource(AutoModelForCausalLM.from_pretrained, show_spinner=False)(model_name)


@torch.inference_mode()
def get_logprobs(model, inputs, metric):
    logprobs = []
    batch_size = 8
    num_items = len(inputs["input_ids"])
    pbar = st.progress(0)
    for i in range(0, num_items, batch_size):
        pbar.progress(i / num_items, f"{i}/{num_items}")
        batch = {k: v[i:i + batch_size] for k, v in inputs.items()}
        batch_logprobs = model(**batch).logits.log_softmax(dim=-1).to(torch.float16)
        if metric != "KL divergence":
            # For NLL we only need the log-probability of the target token, not the full distribution
            batch_logprobs = torch.gather(
                batch_logprobs, dim=-1, index=batch["labels"][..., None]
            )
        logprobs.append(batch_logprobs)
    logprobs = torch.cat(logprobs, dim=0)
    pbar.empty()
    return logprobs


def get_logits_processor(temperature, top_p, typical_p, repetition_penalty) -> LogitsProcessorList:
    processor = LogitsProcessorList()
    if repetition_penalty != 1.0:
        processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
    if temperature != 1.0:
        processor.append(TemperatureLogitsWarper(temperature))
    if top_p < 1.0:
        processor.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1))
    if typical_p < 1.0:
        processor.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=1))
    return processor


@torch.inference_mode()
def generate(model, inputs, metric, window_len, max_new_tokens, **kwargs):
    assert metric == "NLL loss"

    start = max(0, inputs["input_ids"].shape[1] - window_len + 1)
    inputs_window = {k: v[:, start:] for k, v in inputs.items()}
    del inputs_window["labels"]

    logits_warper = get_logits_processor(**kwargs)

    new_ids, logprobs = [], []
    eos_idx = None
    pbar = st.progress(0)
    max_steps = max_new_tokens + window_len - 1
    for i in range(max_steps):
        pbar.progress(i / max_steps, f"{i}/{max_steps}")
        inputs_window["attention_mask"] = torch.ones_like(inputs_window["input_ids"], dtype=torch.long)
        logits_window = model(**inputs_window).logits.squeeze(0)
        logprobs_window = logits_window.log_softmax(dim=-1)
        if eos_idx is None:
            # Still generating: sample the next token from the warped distribution
            probs_next = logits_warper(inputs_window["input_ids"], logits_window[[-1]]).softmax(dim=-1)
            next_token = torch.multinomial(probs_next, num_samples=1).item()
            if next_token == tokenizer.eos_token_id or i >= max_new_tokens - 1:
                eos_idx = i
        else:
            # Generation has finished; keep sliding the window to score the last generated tokens
            next_token = tokenizer.eos_token_id
        new_ids.append(next_token)

        inputs_window["input_ids"] = torch.cat(
            [inputs_window["input_ids"], torch.tensor([[next_token]])], dim=1
        )
        if inputs_window["input_ids"].shape[1] > window_len:
            inputs_window["input_ids"] = inputs_window["input_ids"][:, 1:]

        if logprobs_window.shape[0] == window_len:
            logprobs.append(
                logprobs_window[torch.arange(window_len), inputs_window["input_ids"].squeeze(0)]
            )

        if eos_idx is not None and i - eos_idx >= window_len - 1:
            break
    pbar.empty()

    return torch.as_tensor(new_ids[:eos_idx + 1]), torch.stack(logprobs)[:, :, None]

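# The function below rearranges the per-window log-probabilities so that entry [j, t]
# holds the prediction for target token t given the j + 1 most recent context tokens;
# differencing along the context-length axis then yields each context token's individual
# contribution (the differential importance scores from the paper).
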
@torch.inference_mode()
def run_context_length_probing(
    _model: GPT2LMHeadModel,
    _tokenizer: PreTrainedTokenizer,
    _inputs: Dict[str, torch.Tensor],
    window_len: int,
    metric: str,
    generation_mode: bool,
    generate_kwargs: Dict[str, Any],
    cache_key: Hashable
):
    del cache_key

    [input_ids] = _inputs["input_ids"]
    [label_ids] = _inputs["labels"]

    with st.spinner("Running model…"):
        if generation_mode:
            new_ids, logprobs = generate(
                model=_model,
                inputs=_inputs.convert_to_tensors("pt"),
                metric=metric,
                window_len=window_len,
                **generate_kwargs
            )
            output_ids = [*input_ids, *new_ids]
            window_len = logprobs.shape[1]
        else:
            window_len = min(window_len, len(input_ids))
            inputs_sliding = get_windows_batched(
                _inputs,
                window_len=window_len,
                start=0,
                pad_id=_tokenizer.eos_token_id
            ).convert_to_tensors("pt")
            logprobs = get_logprobs(model=_model, inputs=inputs_sliding, metric=metric)
            output_ids = [*input_ids, label_ids[-1]]

    num_tgt_tokens = logprobs.shape[0]

    with st.spinner("Computing scores…"):
        # Shift the rows so that entry [j, t] corresponds to target token t and context length j + 1
        logprobs = logprobs.permute(1, 0, 2)
        logprobs = F.pad(logprobs, (0, 0, 0, window_len, 0, 0), value=torch.nan)
        logprobs = logprobs.view(-1, logprobs.shape[-1])[:-window_len]
        logprobs = logprobs.view(window_len, num_tgt_tokens + window_len - 1, logprobs.shape[-1])

        if metric == "NLL loss":
            scores = nll_score(logprobs=logprobs, labels=label_ids)
        elif metric == "KL divergence":
            scores = kl_div_score(logprobs)
        del logprobs  # possibly destroyed by the score computation to save memory

        # Differential importance scores: how much the metric improves with each added
        # context token, normalized per target token
        scores = (-scores).diff(dim=0).transpose(0, 1)
        scores = scores.nan_to_num()
        scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-6
        scores = scores.to(torch.float16)

        if generation_mode:
            scores = F.pad(scores, (0, 0, max(0, len(input_ids) - window_len + 1), 0), value=0.)

    return output_ids, scores


if not generation_mode:
    run_context_length_probing = st.cache_data(run_context_length_probing, show_spinner=False)

output_ids, scores = run_context_length_probing(
    _model=model,
    _tokenizer=tokenizer,
    _inputs=inputs,
    window_len=window_len,
    metric=metric_name,
    generation_mode=generation_mode,
    generate_kwargs=generate_kwargs,
    cache_key=(model_name, text),
)
tokens = ids_to_readable_tokens(tokenizer, output_ids)

st.markdown('', unsafe_allow_html=True)
highlighted_text_component(
    tokens=tokens, scores=scores.tolist(), prefix_len=len(input_ids) if generation_mode else 0
)
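# To try the app locally (assuming this file is saved as app.py and the compiled
# highlighted_text component is available under ./highlighted_text/build), run:
#   streamlit run app.py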