from enum import Enum
from pathlib import Path
import streamlit as st
import streamlit.components.v1 as components
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding
# Directory containing this script; the frontend build is located relative to it.
root_dir = Path(__file__).resolve().parent
# Custom Streamlit component loaded from the files under "highlighted_text/build".
highlighted_text_component = components.declare_component(
    "highlighted_text", path=root_dir / "highlighted_text" / "build"
)
def get_windows_batched(examples: BatchEncoding, window_len: int, stride: int = 1, pad_id: int = 0) -> BatchEncoding:
    """Cut every example of a batch into fixed-length sliding windows.

    For example ``i`` a window starts at every offset ``j`` in
    ``range(0, len(examples["input_ids"][i]) - 1, stride)``, so the last
    token never starts a window. Windows that run past the end of the
    example are right-padded to exactly ``window_len`` items.

    Args:
        examples: Batch mapping field names to per-example sequences. The
            ``"input_ids"`` field determines the window offsets for all
            fields. Assumes each per-example sequence is a Python list
            (slices are concatenated with a padding list).
        window_len: Number of items in every returned window.
        stride: Distance between the starts of consecutive windows.
        pad_id: Padding value for the ``"input_ids"`` and ``"labels"``
            fields; every other field is padded with ``0``.

    Returns:
        A ``BatchEncoding`` with the same keys; each value is the flat list
        of windows of all examples, in example order and then offset order.
    """
    id_fields = ("input_ids", "labels")
    # (example index, window start) pairs, shared by all fields.
    starts = [
        (i, j)
        for i, ids in enumerate(examples["input_ids"])
        for j in range(0, len(ids) - 1, stride)
    ]
    windows = {}
    for k, t in examples.items():
        fill = pad_id if k in id_fields else 0
        windows[k] = [
            # A non-positive pad count (window fully inside the example)
            # multiplies to an empty list, i.e. no padding.
            t[i][j : j + window_len] + [fill] * (j + window_len - len(t[i]))
            for i, j in starts
        ]
    return BatchEncoding(windows)
# U+FFFD REPLACEMENT CHARACTER. Its presence in decoded text is taken to mean
# that the ids decoded so far end in the middle of a character.
BAD_CHAR = chr(0xfffd)


def ids_to_readable_tokens(tokenizer, ids, strip_whitespace=False):
    """Decode token ids into display strings, one string per id.

    Ids are accumulated until ``tokenizer.decode`` yields text without
    ``BAD_CHAR``. The decoded text is then attributed to the id that
    completed it, and the earlier ids of that group get an empty string, so
    the result always has exactly ``len(ids)`` entries.

    Args:
        tokenizer: Object with a ``decode(list_of_ids) -> str`` method.
        ids: Iterable of token ids.
        strip_whitespace: If true, strip leading/trailing whitespace from
            every decoded string.

    Returns:
        List of strings aligned with ``ids``. Ids left over at the end that
        never decode cleanly are represented by empty strings.
    """
    pending = []
    result = []
    for idx in ids:
        pending.append(idx)
        decoded = tokenizer.decode(pending)
        if BAD_CHAR in decoded:
            # Incomplete character: keep the id pending, but still emit a
            # placeholder so positions stay aligned with ``ids``.
            result.append("")
            continue
        if strip_whitespace:
            decoded = decoded.strip()
        result.append(decoded)
        pending.clear()
    return result
def nll_score(logprobs, labels):
    """Return negative log-likelihoods of the labels.

    If the last dimension of ``logprobs`` has size 1, the values are taken
    to be the label log-probabilities already and that dimension is simply
    dropped. Otherwise, for each position ``t < len(labels)`` along the
    second dimension, entry ``labels[t]`` of the last dimension is selected.

    Args:
        logprobs: Tensor of log-probabilities, indexed here with three
            indices (first dimension kept, second = position, last = class).
        labels: Sequence of class indices, one per position.

    Returns:
        The negated selected log-probabilities.
    """
    already_selected = logprobs.shape[-1] == 1
    if already_selected:
        picked = logprobs.squeeze(-1)
    else:
        positions = torch.arange(len(labels))
        picked = logprobs[:, positions, labels]
    return -picked
def kl_div_score(logprobs):
    """Compute KL(p || q) between a reference and every other distribution.

    ``logprobs`` is indexed as ``[c, t, v]``. For each ``t`` the reference
    distribution ``p`` is the row ``logprobs[min(t, C - 1), t]`` where
    ``C = logprobs.shape[0]``; ``q`` ranges over all ``logprobs[c, t]``.
    (Presumably ``c`` is the context-length index and ``t`` the token
    position; confirm against the caller.)

    NOTE: to save memory the computation is done in place, so the tensor
    passed in is overwritten and must not be used afterwards.

    Args:
        logprobs: CPU tensor of log-probabilities of shape ``(C, T, V)``.

    Returns:
        Tensor of shape ``(C, T)`` holding ``sum_v p * (log p - log q)``.
    """
    positions = torch.arange(logprobs.shape[1])
    # Advanced indexing returns a copy, so ``log_p`` is unaffected by the
    # in-place updates of ``logprobs`` below.
    log_p = logprobs[
        positions.clamp(max=logprobs.shape[0] - 1),
        positions,
    ]
    # Compute things in place as much as possible
    log_p_minus_log_q = logprobs
    del logprobs
    log_p_minus_log_q *= -1
    log_p_minus_log_q += log_p
    # Use np.exp because torch.exp is not implemented for float16
    p_np = log_p.numpy()  # shares memory with log_p, which is no longer needed
    del log_p
    np.exp(p_np, out=p_np)
    result = log_p_minus_log_q
    result *= torch.as_tensor(p_np)
    return result.sum(dim=-1)
# The "?compact=true" query parameter suppresses the page title.
compact_layout = st.experimental_get_query_params().get("compact", ["false"]) == ["true"]

if not compact_layout:
    st.title("Context length probing")

# Generation mode is switched off; the flag only affects which metrics are offered below.
generation_mode = False

model_name = st.selectbox("Model", ["distilgpt2", "gpt2", "EleutherAI/gpt-neo-125m"])
metric_name = st.radio(
    "Metric", (["KL divergence"] if not generation_mode else []) + ["NLL loss"], index=0, horizontal=True
)

# Wrapped in st.cache_resource so the tokenizer is loaded once per model name
# rather than on every script rerun.
tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(model_name, use_fast=False)
# Make sure the logprobs do not use up more than ~4 GB of memory.
# MAX_MEM is expressed as a number of float16 elements (4e9 bytes / 2 bytes each), not in bytes.
MAX_MEM = 4e9 / (torch.finfo(torch.float16).bits / 8)
# Select window lengths such that we are allowed to fill the whole window without running out of memory
# (otherwise the window length is irrelevant)
# Only the KL metric keeps a full distribution per position; NLL keeps a single value.
logprobs_dim = tokenizer.vocab_size if metric_name == "KL divergence" else 1
window_len_options = [
    w for w in [8, 16, 32, 64, 128, 256, 512, 1024]
    # The smallest option is always offered so the list is never empty.
    if w == 8 or w * (2 * w) * logprobs_dim <= MAX_MEM
]
window_len = st.select_slider(
    r"Window size ($c_\text{max}$)",
    options=window_len_options,
    value=min(128, window_len_options[-1])
)
# Now figure out how many tokens we are allowed to use:
# window_len * (num_tokens + window_len) * logprobs_dim <= MAX_MEM
max_tokens = int(MAX_MEM / (logprobs_dim * window_len) - window_len)
max_tokens = min(max_tokens, 2048)
# Sample input shown when the session has no "input_text" value; line breaks
# are folded into spaces so it is a single paragraph.
DEFAULT_TEXT = """
We present context length probing, a novel explanation technique for causal
language models, based on tracking the predictions of a model as a function of the length of
available context, and allowing to assign differential importance scores to different contexts.
The technique is model-agnostic and does not rely on access to model internals beyond computing
token-level probabilities. We apply context length probing to large pre-trained language models
and offer some initial analyses and insights, including the potential for studying long-range
dependencies.
""".replace("\n", " ").strip()

text = st.text_area(
    f"Input text (≤\u2009{max_tokens} tokens)",
    st.session_state.get("input_text", DEFAULT_TEXT),
)
# Terminate the input with EOS when the tokenizer defines one.
if tokenizer.eos_token:
    text += tokenizer.eos_token
inputs = tokenizer([text])
[input_ids] = inputs["input_ids"]
# Next-token targets: the input shifted left by one, closed with EOS.
inputs["labels"] = [[*input_ids[1:], tokenizer.eos_token_id]]
# The appended EOS does not count towards the user's token budget.
num_user_tokens = len(input_ids) - (1 if tokenizer.eos_token else 0)

# Halt the script run after reporting invalid input; continuing would load the
# model and run it on input that is empty or exceeds the memory budget.
if num_user_tokens < 1:
    st.error("Please enter at least one token.", icon="🚨")
    st.stop()
if num_user_tokens > max_tokens:
    st.error(
        f"Your input has {num_user_tokens} tokens. Please enter at most {max_tokens} tokens "
        f"or try reducing the window size.",
        icon="🚨",
    )
    st.stop()

with st.spinner("Loading model…"):
    model = st.cache_resource(AutoModelForCausalLM.from_pretrained, show_spinner=False)(model_name)

# A window can never be longer than the input itself.
window_len = min(window_len, len(input_ids))
@st.cache_data(show_spinner=False)
@torch.no_grad()
def get_logprobs(_model, _inputs, cache_key):
    """Run the model on one batch and return float16 log-probabilities.

    Gradients are disabled: the output is only used for scoring, and
    tensors that require grad cannot be converted with ``.numpy()`` later.

    Args:
        _model: Callable model; ``_model(**_inputs)`` must return an object
            with a ``logits`` tensor. The leading underscore tells
            Streamlit's cache not to hash this argument.
        _inputs: Mapping of keyword arguments for the model (not hashed).
        cache_key: Hashable value that identifies the call for Streamlit's
            cache; it is the only argument that takes part in the cache key
            and is otherwise unused.

    Returns:
        ``log_softmax`` of the logits over the last dimension, as float16.
    """
    del cache_key
    return _model(**_inputs).logits.log_softmax(dim=-1).to(torch.float16)
@st.cache_data(show_spinner=False)
@torch.no_grad()
def run_context_length_probing(_model, _tokenizer, _inputs, window_len, metric, cache_key):
    """Run the model on all sliding windows and derive per-context scores.

    Args:
        _model: Causal language model, forwarded to ``get_logprobs``
            (underscore-prefixed arguments are not hashed by Streamlit's cache).
        _tokenizer: Tokenizer; its ``eos_token_id`` is used to pad windows
            that run past the end of the input.
        _inputs: Tokenized input holding a single example, with
            ``"input_ids"`` and ``"labels"`` fields.
        window_len: Window length, i.e. the maximum context length.
        metric: ``"KL divergence"`` or ``"NLL loss"``.
        cache_key: Hashable value identifying the call for Streamlit's cache.

    Returns:
        float16 tensor of shape
        ``(len(input_ids) + window_len - 2, window_len - 1)``: the negated
        metric differenced along the context axis, with NaNs replaced by 0
        and each row scaled by its maximum absolute value.

    Raises:
        ValueError: If ``metric`` is not one of the supported names.
    """
    del cache_key

    # Take the token ids from the argument rather than from a module global.
    input_ids = _inputs["input_ids"][0]

    inputs_sliding = get_windows_batched(
        _inputs,
        window_len=window_len,
        pad_id=_tokenizer.eos_token_id,
    ).convert_to_tensors("pt")

    logprobs = []
    with st.spinner("Running model…"):
        batch_size = 8
        num_items = len(inputs_sliding["input_ids"])
        pbar = st.progress(0)
        for i in range(0, num_items, batch_size):
            pbar.progress(i / num_items, f"{i}/{num_items}")
            batch = {k: v[i:i + batch_size] for k, v in inputs_sliding.items()}
            batch_logprobs = get_logprobs(
                _model,
                batch,
                cache_key=(model_name, batch["input_ids"].cpu().numpy().tobytes())
            )
            batch_labels = batch["labels"]
            if metric != "KL divergence":
                # Only the log-probability of the target token is needed;
                # keep a trailing dimension of size 1.
                batch_logprobs = torch.gather(
                    batch_logprobs, dim=-1, index=batch_labels[..., None]
                )
            logprobs.append(batch_logprobs)
        logprobs = torch.cat(logprobs, dim=0)
        pbar.empty()

    with st.spinner("Computing scores…"):
        # (windows, window_len, dim) -> (window_len, windows, dim)
        logprobs = logprobs.permute(1, 0, 2)
        # Pad each row with window_len NaN entries, flatten, drop the last
        # window_len entries and re-split into rows that are one entry
        # shorter: row r ends up shifted right by r positions.
        logprobs = F.pad(logprobs, (0, 0, 0, window_len, 0, 0), value=torch.nan)
        logprobs = logprobs.view(-1, logprobs.shape[-1])[:-window_len]
        logprobs = logprobs.view(window_len, len(input_ids) + window_len - 2, logprobs.shape[-1])

        if metric == "NLL loss":
            scores = nll_score(logprobs=logprobs, labels=input_ids[1:])
        elif metric == "KL divergence":
            scores = kl_div_score(logprobs)
        else:
            raise ValueError(f"Unknown metric: {metric!r}")
        del logprobs  # possibly destroyed by the score computation to save memory

        scores = (-scores).diff(dim=0).transpose(0, 1)
        scores = scores.nan_to_num()
        # Normalise every row; the epsilon avoids division by zero.
        scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-6
        scores = scores.to(torch.float16)

    return scores
scores = run_context_length_probing(
    _model=model,
    _tokenizer=tokenizer,
    _inputs=inputs,
    window_len=window_len,
    metric=metric_name,
    cache_key=(model_name, text),
)
tokens = ids_to_readable_tokens(tokenizer, input_ids)

# Plain HTML label above the custom component.
st.markdown('<label style="font-size: 14px;">Output</label>', unsafe_allow_html=True)
highlighted_text_component(tokens=tokens, scores=scores.tolist())