from typing import Any, Dict, List, Tuple
import torch
from copy import deepcopy
from transformers import AutoModelForCausalLM, AutoTokenizer
from .GRACE import GRACE
from .grace_hparams import GraceHyperParams
from .utils import tokenize
from ...util import nethook
import gradio as gr


def apply_grace_to_model(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: GraceHyperParams,
    num_steps: int,
    replacement: str,
    copy=False,
    return_orig_weights=False,
    keep_original_weight=False,
    **kwargs: Any,
) -> Tuple[AutoModelForCausalLM, Dict[str, Any]]:
    request = requests
    if copy:
        model = deepcopy(model)
    weights_copy = {}
    device = torch.device('cpu')

    # Override the loaded hyperparameters with the caller-supplied settings.
    hparams.n_iter = num_steps
    hparams.replacement = replacement

    # Wrap the model in a GRACE editor, tokenize the request, and apply the edit.
    editor = GRACE(model=model, config=hparams, device=device)
    tokens = tokenize(request, tokenizer=tok, device=device)
    editor.edit(config=hparams, tokens=tokens)
    editor.to('cpu')

    gr.Info("Completed editing via GRACE!")
    # The returned GRACE editor wraps the edited model.
    return editor
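

# Usage sketch (illustrative only, kept as comments): the model name, config
# path, request fields, and replacement value below are assumptions about how
# the surrounding demo calls this function, not taken from this file.
#
#   model = AutoModelForCausalLM.from_pretrained("gpt2-xl")
#   tok = AutoTokenizer.from_pretrained("gpt2-xl")
#   hparams = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2-xl")  # assumed config path
#   requests = [{"prompt": "The capital of France is",
#                "target_new": "Lyon"}]  # assumed request schema
#   editor = apply_grace_to_model(model, tok, requests, hparams,
#                                 num_steps=40, replacement="replace_last")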