Spaces:
Sleeping
Sleeping
from pathlib import Path | |
import streamlit as st | |
import streamlit.components.v1 as components | |
import torch | |
import torch.nn.functional as F | |
from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding | |
# Folder containing this script; the component's frontend build is located
# relative to it instead of relative to the current working directory.
root_dir = Path(__file__).resolve().parent

# Declare the custom "highlighted_text" component, pointing it at the files
# under <root_dir>/highlighted_text/build.
_highlighted_text_fn = components.declare_component(
    name="highlighted_text",
    path=root_dir.joinpath("highlighted_text", "build"),
)
def highlighted_text(tokens, scores, key=None):
    """Invoke the custom ``highlighted_text`` component.

    Args:
        tokens: Forwarded to the component as its ``tokens`` argument.
        scores: Forwarded to the component as its ``scores`` argument.
        key: Optional key forwarded to the component unchanged.

    Returns:
        Whatever the component call returns; ``default=0`` is passed to it.
    """
    component_args = {"tokens": tokens, "scores": scores, "key": key, "default": 0}
    return _highlighted_text_fn(**component_args)
def get_windows_batched(examples: BatchEncoding, window_len: int, stride: int = 1, pad_id: int = 0) -> BatchEncoding:
    """Cut every example into fixed-length sliding windows.

    For each example ``i`` a window is started at every offset in
    ``range(0, len(examples["input_ids"][i]) - 1, stride)``. Windows that run
    past the end of the sequence are right-padded to ``window_len``.

    Args:
        examples: Batch whose values are indexable per example and sliceable
            per sequence; must contain an ``"input_ids"`` entry.
        window_len: Length of every produced window.
        stride: Step between consecutive window start offsets.
        pad_id: Padding value used for the ``"input_ids"`` entry; every other
            entry is padded with ``0``.

    Returns:
        A ``BatchEncoding`` with the same keys as ``examples``, where each
        value is the list of windows of all examples, in example order.
    """
    windowed = {}
    for key, sequences in examples.items():
        filler = pad_id if key == "input_ids" else 0
        rows = []
        # Window offsets are always derived from the "input_ids" entry, so all
        # keys end up with the same number of windows.
        for i, ids in enumerate(examples["input_ids"]):
            for start in range(0, len(ids) - 1, stride):
                sequence = sequences[i]
                end = start + window_len
                # A non-positive repeat count yields an empty list, so windows
                # that fit completely receive no padding.
                rows.append(sequence[start:end] + [filler] * (end - len(sequence)))
        windowed[key] = rows
    return BatchEncoding(windowed)
# U+FFFD REPLACEMENT CHARACTER; its presence in decoded text is treated below
# as "these ids do not form complete text yet".
BAD_CHAR = "\ufffd"


def ids_to_readable_tokens(tokenizer, ids, strip_whitespace=False):
    """Decode token ids into one display string per id.

    Ids are accumulated until ``tokenizer.decode`` of the accumulated ids no
    longer contains ``BAD_CHAR``. The id that completes such a group receives
    the decoded text of the whole group; every id before it in the group
    receives an empty string.

    Args:
        tokenizer: Object exposing ``decode(list_of_ids) -> str``.
        ids: Iterable of token ids.
        strip_whitespace: If true, strip surrounding whitespace from each
            decoded piece.

    Returns:
        A list of strings with exactly one entry per input id.
    """
    pending = []
    pieces = []
    for token_id in ids:
        pending.append(token_id)
        text = tokenizer.decode(pending)
        if BAD_CHAR in text:
            # Still incomplete: emit a placeholder and keep accumulating.
            pieces.append("")
            continue
        pieces.append(text.strip() if strip_whitespace else text)
        pending.clear()
    return pieces
# --- Model and input widgets -------------------------------------------------
# NOTE(review): the tokenizer and model are loaded at module level, i.e. on
# every execution of this script; consider caching them if that proves slow.
model_name = st.selectbox("Model", ["distilgpt2"])
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

window_len = st.select_slider("Window size", options=[8, 16, 32, 64, 128, 256, 512, 1024], value=512)
text = st.text_area("Input text", "The complex houses married and single soldiers and their families.")

# --- Tokenization ------------------------------------------------------------
inputs = tokenizer([text])
[input_ids] = inputs["input_ids"]  # exactly one text was tokenized

# Guard against inputs that are too short for the computation below:
#   * 0 or 1 token: get_windows_batched yields no windows at all, so the model
#     would be called on an empty batch;
#   * 2 tokens: diff(dim=0) leaves an empty dimension that max(dim=1) cannot
#     reduce.
# Show a hint instead of an exception and end this script run.
if len(input_ids) < 3:
    st.warning("Please enter a longer text (at least 3 tokens).")
    st.stop()

# A window never needs to be longer than the tokenized text.
window_len = min(window_len, len(input_ids))
tokens = ids_to_readable_tokens(tokenizer, input_ids)

inputs_sliding = get_windows_batched(
    inputs,
    window_len=window_len,
    pad_id=tokenizer.eos_token_id,
)

# --- Model forward pass ------------------------------------------------------
with torch.inference_mode():
    # presumably shaped (num_windows, window_len, vocab_size) — TODO confirm
    logits = model(**inputs_sliding.convert_to_tensors("pt")).logits.to(torch.float16)

# --- Re-arrange windowed logits ----------------------------------------------
# Append window_len - 1 NaN entries along the second-to-last dimension, flatten
# everything but the last dimension, truncate, and view the result with rows of
# length len(input_ids) + window_len - 2.
# NOTE(review): the padded row length (2 * window_len - 1) and the new row
# length differ by exactly one only when window_len == len(input_ids); for
# texts longer than the selected window the alignment produced here should be
# verified.
logits = F.pad(logits, (0, 0, 0, window_len - 1, 0, 0), value=torch.nan)
logits = logits.view(-1, logits.shape[-1])[:(window_len - 1) * (len(input_ids) + window_len - 2)]
logits = logits.view(window_len - 1, len(input_ids) + window_len - 2, logits.shape[-1])

# --- Scores ------------------------------------------------------------------
scores = logits.to(torch.float32).softmax(dim=-1)
# For each column k, keep the probability assigned to token input_ids[k + 1].
scores = scores[:, torch.arange(len(input_ids[1:])), input_ids[1:]]
# Differences between consecutive rows; afterwards dim 0 indexes the columns
# selected above.
scores = scores.diff(dim=0).transpose(0, 1)
# Entries that involve NaN padding become 0.
scores = scores.nan_to_num()
# Scale every row by its largest absolute value; the epsilon avoids a division
# by zero for all-zero rows.
scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-9
scores = scores.to(torch.float16)
print(scores)  # written to stdout, not to the page

# --- Output ------------------------------------------------------------------
st.markdown("---")
highlighted_text(tokens=tokens, scores=scores.tolist())