|
from typing import Any, Dict, List, Tuple |
|
from copy import deepcopy |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
from .WISE import WISE |
|
from .utils import tokenize, get_context_templates |
|
from .wise_hparams import WISEHyperParams |
|
import gradio as gr |
|
|
|
def apply_wise_to_model(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    request: Dict,
    hparams: WISEHyperParams,
    num_steps: int,
    edit_lr: float,
    copy=False,
    return_orig_weights=False,
    keep_original_weight=False,
    **kwargs: Any,
) -> WISE:
    """Apply a single WISE knowledge edit to *model* and return the WISE editor.

    Args:
        model: The causal-LM to edit.
        tok: Tokenizer matching *model*.
        request: A single edit request dict; must contain at least the keys
            ``'prompt'`` and ``'target_new'`` (it is indexed directly below).
        hparams: WISE hyperparameters; ``n_iter`` and ``edit_lr`` are
            overwritten from *num_steps* / *edit_lr* before editing.
        num_steps: Number of optimization iterations for the edit.
        edit_lr: Learning rate for the edit.
        copy: If True, deep-copy the model first so the original is untouched.
        return_orig_weights: Accepted for interface compatibility; currently
            ignored (no original weights are collected or returned).
        keep_original_weight: Accepted for interface compatibility; currently
            ignored.
        **kwargs: Unused; accepted for interface compatibility.

    Returns:
        The ``WISE`` editor wrapping the edited model (moved to CPU).
    """
    if copy:
        model = deepcopy(model)

    # Override the hyperparameters with the caller-supplied values.
    hparams.n_iter = num_steps
    hparams.edit_lr = edit_lr

    context_templates = get_context_templates(model, tok, length_params=[[5,5], [10,5]], device=hparams.device)
    editor = WISE(model=model, config=hparams, device=hparams.device)
    print(
        f"Executing WISE algorithm for the update: "
        f"[{request['prompt']}] -> [{request['target_new']}]"
    )
    tokens, act_mask, deact_mask = tokenize(request, tokenizer=tok, device=hparams.device, context_templates=context_templates, hparams=hparams)
    editor.edit(config=hparams, tokens=tokens, act_mask=act_mask, deact_mask=deact_mask)

    # Free GPU memory once editing is done; callers use the editor from CPU.
    editor.to('cpu')
    gr.Info("Completed editing via WISE!")

    return editor