from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from ..rome import repr_tools
from ...util import nethook
from .rome_hparams import ROMEHyperParams


def compute_v(
model: AutoModelForCausalLM,
tok: AutoTokenizer,
request: Dict,
hparams: ROMEHyperParams,
layer: int,
left_vector: torch.Tensor,
context_templates: List[str],
) -> torch.Tensor:
"""
Computes the value (right) vector for the rank-1 update.
Runs a simple optimization procedure.
"""
print("Computing right vector (v)")
# Tokenize target into list of int token IDs
    # Keep tensors on the model's device so the edited forward passes work on CPU or GPU
    target_ids = tok.encode(request["target_new"], return_tensors="pt", add_special_tokens=False).to(model.device)[0]
# if target_ids[0] == tok.bos_token_id or target_ids[0] == tok.unk_token_id:
# target_ids = target_ids[1:]
# Compile list of rewriting and KL x/y pairs
rewriting_prompts, kl_prompts = [
context.format(request["prompt"]) + tok.decode(target_ids[:-1])
for context in context_templates
], ["{} is a"]
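    # Each rewriting prompt ends with all target tokens except the last, so the
    # loss below can supervise every target token; the "{} is a" KL prompt
    # probes the subject's essence distribution, which the KL term preserves.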
all_prompts = rewriting_prompts + kl_prompts
input_tok = tok(
[prompt.format(request["subject"]) for prompt in all_prompts],
return_tensors="pt",
padding=True,
).to("cpu")
# Compute rewriting targets
    rewriting_targets = torch.tensor(-100, device=model.device).repeat(
len(rewriting_prompts), *input_tok["input_ids"].shape[1:]
)
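    # Labels default to -100 (the conventional ignore index for the masked loss
    # below); the loop writes the real target ids into the last len(target_ids)
    # non-pad positions of each prompt, assuming right padding.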
for i in range(len(rewriting_prompts)):
ex_len = input_tok["attention_mask"][i].sum()
rewriting_targets[i, ex_len - len(target_ids) : ex_len] = target_ids
# Compute indices of the tokens where the fact is looked up
vanilla_input_prompts = [
context.format(request["prompt"]).format(request['subject'])
for context in context_templates
] + [f"{request['subject']} is a"]
lookup_idxs = [
find_fact_lookup_idx(
prompt, request["subject"], tok, hparams.fact_token, verbose=(i == 0), input_prompt=vanilla_input_prompts[i]
)
for i, prompt in enumerate(all_prompts)
]
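    # lookup_idxs[i] is the position where the fact is hypothesized to be
    # retrieved (per hparams.fact_token, e.g. the last subject token); delta is
    # injected at exactly these positions in edit_output_fn below.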
# Finalize rewrite and loss layers
loss_layer = max(hparams.v_loss_layer, layer)
print(f"Rewrite layer is {layer}")
print(f"Tying optimization objective to {loss_layer}")
# Set up an optimization over a latent vector that, when output at the
# rewrite layer, i.e. hypothesized fact lookup location, will induce the
# target token to be predicted at the final layer.
    # GPT-2-style configs expose the hidden size as n_embd; most others use hidden_size
    if hasattr(model.config, 'n_embd'):
        delta = torch.zeros((model.config.n_embd,), requires_grad=True, device=model.device)
    else:
        delta = torch.zeros((model.config.hidden_size,), requires_grad=True, device=model.device)
target_init, kl_distr_init = None, None
# Inserts new "delta" variable at the appropriate part of the computation
def edit_output_fn(cur_out, cur_layer):
nonlocal target_init
if cur_layer == hparams.mlp_module_tmp.format(layer):
# Store initial value of the vector of interest
if target_init is None:
print("Recording initial value of v*")
# Initial value is recorded for the clean sentence
target_init = cur_out[0, lookup_idxs[0]].detach().clone()
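            # Some modules emit activations as [batch, seq, hidden], others as
            # [seq, batch, hidden]; comparing the batch size (len(lookup_idxs))
            # with cur_out's leading dimension selects the matching indexing.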
            for i, idx in enumerate(lookup_idxs):
                if len(lookup_idxs) != len(cur_out):
                    cur_out[idx, i, :] += delta
                else:
                    cur_out[i, idx, :] += delta
return cur_out
# Optimizer
opt = torch.optim.Adam([delta], lr=hparams.v_lr)
nethook.set_requires_grad(False, model)
# Execute optimization
for it in range(hparams.v_num_grad_steps):
opt.zero_grad()
# Forward propagation
with nethook.TraceDict(
module=model,
layers=[
hparams.layer_module_tmp.format(loss_layer),
hparams.mlp_module_tmp.format(layer),
],
retain_input=False,
retain_output=True,
edit_output=edit_output_fn,
) as tr:
logits = model(**input_tok).logits
# Compute distribution for KL divergence
kl_logits = torch.stack(
[
logits[i - len(kl_prompts), idx, :]
for i, idx in enumerate(lookup_idxs[-len(kl_prompts) :])
],
dim=0,
)
kl_log_probs = torch.nn.functional.log_softmax(kl_logits, dim=1)
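            # On the first step delta is still zero, so kl_log_probs is the
            # unedited model's distribution; it is frozen below as the fixed
            # reference for the KL penalty.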
if kl_distr_init is None:
kl_distr_init = kl_log_probs.detach().clone()
# Compute loss on rewriting targets
log_probs = torch.log_softmax(logits, dim=2)
loss = torch.gather(
log_probs,
2,
torch.where(rewriting_targets != -100, rewriting_targets, 0).unsqueeze(2),
).squeeze(2)
mask = (rewriting_targets != -100).float()
# Aggregate total losses
nll_loss_each = -(loss * mask).sum(1) / target_ids.size(0)
nll_loss = nll_loss_each.mean()
kl_loss = hparams.kl_factor * torch.nn.functional.kl_div(
kl_distr_init, kl_log_probs, log_target=True, reduction="batchmean"
)
weight_decay = hparams.v_weight_decay * (
torch.norm(delta) / torch.norm(target_init) ** 2
)
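        # The regularizer above penalizes ||delta|| relative to ||v*||^2, keeping
        # the edit in a trust region around the original activation; the
        # commented-out line below is a plain squared-norm alternative.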
# weight_decay = hparams.v_weight_decay * torch.norm(delta) ** 2
loss = nll_loss + kl_loss + weight_decay
print(
f"loss {np.round(loss.item(), 3)} = {np.round(nll_loss.item(), 3)} + {np.round(kl_loss.item(), 3)} + {np.round(weight_decay.item(), 3)} "
f"avg prob of [{request['target_new']}] "
f"{torch.exp(-nll_loss_each).mean().item()}"
)
if loss < 5e-2:
break
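        # On the final iteration, stop before backward/step so the delta in use
        # matches the loss just logged.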
if it == hparams.v_num_grad_steps - 1:
break
# Backpropagate
loss.backward()
opt.step()
# Project within L2 ball
max_norm = hparams.clamp_norm_factor * target_init.norm()
if delta.norm() > max_norm:
with torch.no_grad():
delta[...] = delta * max_norm / delta.norm()
target = target_init + delta.to(target_init.dtype)
# Retrieve cur_input, the current input to the 2nd MLP layer, and
# cur_output, the original output of the 2nd MLP layer.
cur_input, cur_output = get_module_input_output_at_word(
model,
tok,
layer,
context_template=request["prompt"],
word=request["subject"],
module_template=hparams.rewrite_module_tmp,
fact_token_strategy=hparams.fact_token,
)
# Solving the linear system to compute the right vector
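    # The rank-1 update adds outer(right_vector, left_vector) to the weight W of
    # the rewrite module, so the edited layer maps the key k = cur_input to
    #     W'k = Wk + (left_vector . k) * right_vector.
    # Requiring W'k = target and noting Wk = cur_output yields the division below.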
right_vector = (target - cur_output) / torch.dot(cur_input, left_vector)
print(f"Delta norm: {(target - cur_output).norm().item()}")
print(
f"Change in target norm: {target_init.norm().item()} to {target.norm().item()} => {(target.norm() - target_init.norm()).item()}"
)
print(f"Division Factor: {torch.dot(cur_input, left_vector).item()}")
print(f"Right vector norm: {right_vector.norm()}")
    return right_vector


def get_module_input_output_at_word(
model: AutoModelForCausalLM,
tok: AutoTokenizer,
layer: int,
context_template: str,
word: str,
module_template: str,
fact_token_strategy: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Retrieves detached representations for a word at the input and
output of a particular layer module.
"""
word_repr_args = dict(
model=model,
tok=tok,
layer=layer,
module_template=module_template,
)
if "subject_" in fact_token_strategy and fact_token_strategy.index("subject_") == 0:
subtoken = fact_token_strategy[len("subject_") :]
l_input, l_output = repr_tools.get_reprs_at_word_tokens(
track="both",
subtoken=subtoken,
context_templates=[context_template],
words=[word],
**word_repr_args,
)
elif fact_token_strategy == "last":
l_input, l_output = repr_tools.get_reprs_at_idxs(
track="both",
contexts=[context_template.format(word)],
idxs=[[-1]],
**word_repr_args,
)
else:
raise ValueError(f"fact_token={fact_token_strategy} not recognized")
l_input, l_output = l_input[0], l_output[0]
    return l_input.detach(), l_output.detach()


def find_fact_lookup_idx(
prompt: str,
subject: str,
tok: AutoTokenizer,
fact_token_strategy: str,
    verbose: bool = True,
    input_prompt: Optional[str] = None,
) -> int:
"""
Computes hypothesized fact lookup index given a sentence and subject.
"""
ret = None
if fact_token_strategy == "last":
ret = len(tok.encode(input_prompt)) - 1
    elif fact_token_strategy.startswith("subject_"):
ret = repr_tools.get_words_idxs_in_templates(
tok=tok,
context_templates=[prompt],
words=[subject],
subtoken=fact_token_strategy[len("subject_") :],
)[0][0]
else:
raise ValueError(f"fact_token={fact_token_strategy} not recognized")
sentence = prompt.format(subject)
if verbose:
print(
f"Lookup index found: {ret} | Sentence: {sentence} | Token:",
tok.decode(tok(sentence)["input_ids"][ret]),
)
return ret