import os
from pathlib import Path
from typing import Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from ..rome import repr_tools
from ...util.globals import *

from .layer_stats import layer_stats
from .rome_hparams import ROMEHyperParams

# Cache variables
inv_mom2_cache = {}


def get_inv_cov(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    layer_name: str,
    mom2_dataset: str,
    mom2_n_samples: int,
    mom2_dtype: str,
    hparams=None,
) -> torch.Tensor:
    """
    Retrieves covariance statistics, then computes the algebraic inverse.
    Caches result for future use.
    """

    global inv_mom2_cache

    model_name = model.config._name_or_path.replace("/", "_")
    key = (model_name, layer_name)

    if key not in inv_mom2_cache:
        print(
            f"Retrieving inverse covariance statistics for {model_name} @ {layer_name}. "
            f"The result will be cached to avoid repetitive computation."
        )
        stat = layer_stats(
            model,
            tok,
            layer_name,
            hparams.stats_dir,
            mom2_dataset,
            to_collect=["mom2"],
            sample_size=mom2_n_samples,
            precision=mom2_dtype,
            hparams=hparams,
        )
        inv_mom2_cache[key] = torch.inverse(
            stat.mom2.moment().to(f"cuda:{hparams.device}")
        ).float()  # Cast back to float32

    return inv_mom2_cache[key]


def compute_u(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    request: Dict,
    hparams: ROMEHyperParams,
    layer: int,
    context_templates: List[str],
) -> torch.Tensor:
    """
    Computes the left vector (u) used in constructing the rank-1 update matrix.
    """

    print("Computing left vector (u)...")

    # Compute projection token
    word_repr_args = dict(
        model=model,
        tok=tok,
        layer=layer,
        module_template=hparams.rewrite_module_tmp,
        track="in",
    )
    if "subject_" in hparams.fact_token and hparams.fact_token.index("subject_") == 0:
        # Select the representation at a specific subject subtoken
        # (e.g. "subject_last" -> the last token of the subject).
        word = request["subject"]
        print(f"Selected u projection object {word}")
        cur_repr = repr_tools.get_reprs_at_word_tokens(
            context_templates=[
                templ.format(request["prompt"]) for templ in context_templates
            ],
            words=[word for _ in range(len(context_templates))],
            subtoken=hparams.fact_token[len("subject_"):],
            **word_repr_args,
        ).mean(0)
    elif hparams.fact_token == "last":
        # Heuristic to choose last word. Not a huge deal if there's a minor
        # edge case (e.g. multi-token word) because the function below will
        # take the last token.
        cur_repr = repr_tools.get_reprs_at_idxs(
            contexts=[
                templ.format(request["prompt"].format(request["subject"]))
                for templ in context_templates
            ],
            idxs=[[-1] for _ in range(len(context_templates))],
            **word_repr_args,
        ).mean(0)
        print("Selected u projection token with last token")
    else:
        raise ValueError(f"fact_token={hparams.fact_token} not recognized")

    # Apply inverse second moment adjustment: u = C^{-1} k, where C is the
    # (uncentered) second moment of layer inputs and k is the key representation.
    u = cur_repr
    if hparams.mom2_adjustment:
        u = get_inv_cov(
            model,
            tok,
            hparams.rewrite_module_tmp.format(layer),
            hparams.mom2_dataset,
            hparams.mom2_n_samples,
            hparams.mom2_dtype,
            hparams=hparams,
        ) @ u.unsqueeze(1)
        u = u.squeeze()

    return u / u.norm()
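

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the ROME pipeline): the second-moment
# adjustment above computes u = C^{-1} k, where C is the uncentered covariance
# of the layer's inputs and k is the key representation. The toy example
# below reproduces that algebra on random data using only torch; every
# tensor, size, and seed here is hypothetical and for demonstration only.
if __name__ == "__main__":
    torch.manual_seed(0)
    d = 8  # hypothetical hidden size
    # Fake layer inputs; C = E[x x^T] is the uncentered second moment.
    xs = torch.randn(1000, d)
    C = xs.T @ xs / xs.shape[0]
    # Stand-in for the averaged subject representation (the key k).
    k = torch.randn(d)
    # Same adjustment get_inv_cov(...) @ k performs on real statistics.
    u = torch.inverse(C) @ k
    # As in compute_u, the result is returned with unit norm.
    u = u / u.norm()
    print("u:", u)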