from enum import Enum
from pathlib import Path

import streamlit as st
import streamlit.components.v1 as components
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding

root_dir = Path(__file__).resolve().parent
highlighted_text_component = components.declare_component(
    "highlighted_text", path=root_dir / "highlighted_text" / "build"
)


def get_windows_batched(
    examples: BatchEncoding, window_len: int, stride: int = 1, pad_id: int = 0
) -> BatchEncoding:
    # Split each sequence into overlapping windows of length `window_len` (one window
    # starting every `stride` tokens), padding the final windows with `pad_id`
    # (for input_ids) or 0 (for the other features, e.g. the attention mask).
    return BatchEncoding({
        k: [
            t[i][j : j + window_len] + [
                pad_id if k == "input_ids" else 0
            ] * (j + window_len - len(t[i]))
            for i in range(len(examples["input_ids"]))
            for j in range(0, len(examples["input_ids"][i]) - 1, stride)
        ]
        for k, t in examples.items()
    })


BAD_CHAR = chr(0xfffd)


def ids_to_readable_tokens(tokenizer, ids, strip_whitespace=False):
    # Decode token IDs one at a time, grouping IDs that only decode to valid text
    # together (e.g. multi-byte characters split across several tokens); the extra
    # IDs in a group become empty strings so the output stays aligned with the input.
    cur_ids = []
    result = []
    for idx in ids:
        cur_ids.append(idx)
        decoded = tokenizer.decode(cur_ids)
        if BAD_CHAR not in decoded:
            if strip_whitespace:
                decoded = decoded.strip()
            result.append(decoded)
            del cur_ids[:]
        else:
            result.append("")
    return result


compact_layout = st.experimental_get_query_params().get("compact", ["false"]) == ["true"]

if not compact_layout:
    st.title("Context length probing")
    st.markdown(
        """[📃 Paper](https://arxiv.org/abs/2212.14815) |
        [🌍 Website](https://cifkao.github.io/context-probing) |
        [🧑‍💻 Code](https://github.com/cifkao/context-probing)
        """
    )

model_name = st.selectbox("Model", ["distilgpt2", "gpt2", "EleutherAI/gpt-neo-125m"])
metric_name = st.selectbox("Metric", ["KL divergence", "Cross entropy"], index=1)

tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(
    model_name, use_fast=False
)

# Make sure the logprobs do not use up more than ~4 GB of memory
MAX_MEM = 4e9 / (torch.finfo(torch.float16).bits / 8)

# Select window lengths such that we are allowed to fill the whole window without
# running out of memory (otherwise the window length is irrelevant)
window_len_options = [
    w for w in [8, 16, 32, 64, 128, 256, 512, 1024]
    if w == 8 or w * (2 * w) * tokenizer.vocab_size <= MAX_MEM
]
window_len = st.select_slider(
    r"Window size ($c_\text{max}$)",
    options=window_len_options,
    value=min(128, window_len_options[-1])
)

# Now figure out how many tokens we are allowed to use:
# window_len * (num_tokens + window_len) * vocab_size <= MAX_MEM
# (e.g. with GPT-2's ~50k-token vocabulary and window_len=128, this comes out to
# roughly 180 tokens)
max_tokens = int(MAX_MEM / (tokenizer.vocab_size * window_len) - window_len)

DEFAULT_TEXT = """
We present context length probing, a novel explanation technique for causal language
models, based on tracking the predictions of a model as a function of the length of
available context, and allowing to assign differential importance scores to different
contexts. The technique is model-agnostic and does not rely on access to model
internals beyond computing token-level probabilities. We apply context length probing
to large pre-trained language models and offer some initial analyses and insights,
including the potential for studying long-range dependencies.
""".replace("\n", " ").strip() text = st.text_area("Input text", DEFAULT_TEXT) if tokenizer.eos_token: text += tokenizer.eos_token inputs = tokenizer([text]) [input_ids] = inputs["input_ids"] num_user_tokens = len(input_ids) - (1 if tokenizer.eos_token else 0) if num_user_tokens < 1 or num_user_tokens > max_tokens: st.caption(f":red[{num_user_tokens}]/{max_tokens} tokens") else: st.caption(f"{num_user_tokens}/{max_tokens} tokens") if num_user_tokens < 1: st.error("Please enter at least one token.", icon="🚨") st.stop() if num_user_tokens > max_tokens: st.error( f"Please enter at most {max_tokens} tokens or try reducing the window size.", icon="🚨" ) st.stop() if metric_name == "KL divergence": st.error("KL divergence is not supported yet. Stay tuned!", icon="😭") st.stop() with st.spinner("Loading model…"): model = st.cache_resource(AutoModelForCausalLM.from_pretrained, show_spinner=False)(model_name) window_len = min(window_len, len(input_ids)) @st.cache_data(show_spinner=False) @torch.inference_mode() def get_logprobs(_model, _inputs, cache_key): del cache_key return _model(**_inputs).logits.log_softmax(dim=-1).to(torch.float16) @st.cache_data(show_spinner=False) @torch.inference_mode() def run_context_length_probing(_model, _tokenizer, _inputs, window_len, cache_key): del cache_key inputs_sliding = get_windows_batched( _inputs, window_len=window_len, pad_id=_tokenizer.eos_token_id ).convert_to_tensors("pt") logprobs = [] with st.spinner("Running model…"): batch_size = 8 num_items = len(inputs_sliding["input_ids"]) pbar = st.progress(0) for i in range(0, num_items, batch_size): pbar.progress(i / num_items, f"{i}/{num_items}") batch = {k: v[i:i + batch_size] for k, v in inputs_sliding.items()} logprobs.append( get_logprobs( _model, batch, cache_key=(model_name, batch["input_ids"].cpu().numpy().tobytes()) ) ) logprobs = torch.cat(logprobs, dim=0) pbar.empty() with st.spinner("Computing scores…"): logprobs = logprobs.permute(1, 0, 2) logprobs = F.pad(logprobs, (0, 0, 0, window_len, 0, 0), value=torch.nan) logprobs = logprobs.view(-1, logprobs.shape[-1])[:-window_len] logprobs = logprobs.view(window_len, len(input_ids) + window_len - 2, logprobs.shape[-1]) scores = logprobs[:, torch.arange(len(input_ids[1:])), input_ids[1:]] scores = scores.diff(dim=0).transpose(0, 1) scores = scores.nan_to_num() scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-6 scores = scores.to(torch.float16) return scores scores = run_context_length_probing( _model=model, _tokenizer=tokenizer, _inputs=inputs, window_len=window_len, cache_key=(model_name, text), ) tokens = ids_to_readable_tokens(tokenizer, input_ids) st.markdown('', unsafe_allow_html=True) highlighted_text_component(tokens=tokens, scores=scores.tolist())