import unicodedata
from typing import List, Optional
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from .logit_lens import LogitLens


def generate_interactive(
model: AutoModelForCausalLM,
tok: AutoTokenizer,
top_k: int = 5,
max_out_len: int = 200,
compare_against: Optional[AutoModelForCausalLM] = None,
use_logit_lens: bool = False,
layer_module_tmp: str = "transformer.h.{}",
ln_f_module: str = "transformer.ln_f",
lm_head_module: str = "lm_head",
):
"""
    Runs generation in an interactive loop, repeatedly reading prompts from the
    user and printing the generated continuations.
"""
if use_logit_lens:
llens_gen = LogitLens(
model,
tok,
layer_module_tmp,
ln_f_module,
lm_head_module,
disabled=not use_logit_lens,
)
if compare_against:
llens_vanilla = LogitLens(
compare_against,
tok,
layer_module_tmp,
ln_f_module,
lm_head_module,
disabled=not use_logit_lens,
)
while True:
prompt = input("Enter a prompt: ").strip(" \r\t\n")
print(
f"Argument Model: "
f"{generate_fast(model, tok, [prompt], n_gen_per_prompt=1, top_k=top_k, max_out_len=max_out_len)}"
)
if compare_against:
print(
f"Baseline Model: "
f"{generate_fast(compare_against, tok, [prompt], n_gen_per_prompt=1, top_k=top_k, max_out_len=max_out_len)}"
)
if use_logit_lens:
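            # Re-run the prompt through the model so the LogitLens context manager can
            # capture intermediate hidden states and print per-layer next-token readouts.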
inp_prompt = tok([prompt], padding=True, return_tensors="pt").to(
next(model.parameters()).device
)
with llens_gen:
model(**inp_prompt)
print("\n--- Argument Model Logit Lens ---")
llens_gen.pprint()
if compare_against:
with llens_vanilla:
compare_against(**inp_prompt)
print("--- Baseline Model Logit Lens ---")
llens_vanilla.pprint()
        print()


def generate_fast(
model: AutoModelForCausalLM,
tok: AutoTokenizer,
prompts: List[str],
n_gen_per_prompt: int = 1,
top_k: int = 5,
max_out_len: int = 200,
    vanilla_generation: bool = False,
):
"""
Fast, parallelized auto-regressive text generation with top-k sampling.
Our custom implementation.
"""
# Unroll prompts and tokenize
inp = [prompt for prompt in prompts for _ in range(n_gen_per_prompt)]
inp_tok = tok(inp, padding=True, return_tensors="pt").to(
next(model.parameters()).device
)
input_ids, attention_mask = inp_tok["input_ids"], inp_tok["attention_mask"]
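    # Optionally delegate to Hugging Face's built-in `model.generate` instead of the
    # custom cached sampling loop below.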
if vanilla_generation:
gen_txt = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=max_out_len
)
txt = [tok.decode(x, skip_special_tokens=True) for x in gen_txt.detach().cpu().numpy().tolist()]
txt = [
unicodedata.normalize("NFKD", x)
.replace("\n\n", " ")
.replace("<|endoftext|>", "")
for x in txt
]
return txt
batch_size = input_ids.size(0)
# Setup storage of fast generation with attention caches.
# `cur_context` is used to define the range of inputs that are not yet
    # stored in `past_key_values`. At each step, we generate the next token,
    # which is written at index `cur_context.stop`.
past_key_values, cur_context = None, slice(0, attention_mask.sum(1).min().item())
with torch.no_grad():
while input_ids.size(1) < max_out_len: # while not exceeding max output length
model_out = model(
input_ids=input_ids[:, cur_context],
                # LLaMA/Baichuan models are called without an explicit attention mask here.
                attention_mask=None
                if any(name in model.name_or_path.lower() for name in ("llama", "baichuan"))
                else attention_mask[:, cur_context],
past_key_values=past_key_values,
use_cache=True,
)
logits, past_key_values = model_out.logits, model_out.past_key_values
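            # Only the final position's logits are needed to sample the next token.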
softmax_out = torch.nn.functional.softmax(logits[:, -1, :], dim=1)
# Top-k sampling
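            # Keep the `top_k` most probable tokens, renormalize their probabilities,
            # and sample one continuation token per sequence from that distribution.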
tk = torch.topk(softmax_out, top_k, dim=1).indices
softmax_out_top_k = torch.gather(softmax_out, 1, tk)
softmax_out_top_k = softmax_out_top_k / softmax_out_top_k.sum(1)[:, None]
new_tok_indices = torch.multinomial(softmax_out_top_k, 1)
new_toks = torch.gather(tk, 1, new_tok_indices)
# If we're currently generating the continuation for the last token in `input_ids`,
# create a new index so we can insert the new token
if cur_context.stop == input_ids.size(1):
attention_mask = torch.cat(
[attention_mask, attention_mask.new_zeros(batch_size, 1)], dim=1
)
input_ids = torch.cat(
[
input_ids,
input_ids.new_ones(batch_size, 1) * tok.pad_token_id,
],
dim=1,
)
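            # Write the sampled token into each sequence whose prompt ends exactly at the
            # current context boundary; longer prompts still have real tokens to consume.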
last_non_masked = attention_mask.sum(1) - 1
for i in range(batch_size):
new_idx = last_non_masked[i] + 1
if last_non_masked[i].item() + 1 != cur_context.stop:
continue
# Stop generating if we've already maxed out for this prompt
if new_idx < max_out_len:
input_ids[i][new_idx] = new_toks[i]
attention_mask[i][new_idx] = 1
cur_context = slice(cur_context.stop, cur_context.stop + 1)
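    # Decode the padded batch back to strings and apply light text cleanup.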
txt = [tok.decode(x, skip_special_tokens=True) for x in input_ids.detach().cpu().numpy().tolist()]
txt = [
unicodedata.normalize("NFKD", x)
.replace("\n\n", " ")
.replace("<|endoftext|>", "")
for x in txt
]
return txt
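
# Example usage of `generate_fast` (an illustrative sketch; the "gpt2-xl" checkpoint,
# the CUDA device, and the pad-token choice below are assumptions, not requirements):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   model = AutoModelForCausalLM.from_pretrained("gpt2-xl").to("cuda")
#   tok = AutoTokenizer.from_pretrained("gpt2-xl")
#   tok.pad_token = tok.eos_token  # padding is required for batched prompts
#
#   texts = generate_fast(model, tok, ["The capital of France is"], top_k=5, max_out_len=50)
#   print(texts)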