from typing import Any, Dict, List, Tuple
from copy import deepcopy

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

from .GRACE import GRACE
from .grace_hparams import GraceHyperParams
from .utils import tokenize
from ...util import nethook


def apply_grace_to_model(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: GraceHyperParams,
    num_steps: int,
    replacement: str,
    copy: bool = False,
    return_orig_weights: bool = False,
    keep_original_weight: bool = False,
    **kwargs: Any,
) -> GRACE:
    """Apply a GRACE edit to ``model`` and return the resulting GRACE editor.

    Args:
        model: Causal LM to edit. When ``copy`` is true, a deep copy is
            edited instead of this object.
        tok: Tokenizer passed to ``tokenize`` to encode ``requests``.
        requests: Edit requests, passed as-is to ``tokenize``.
            NOTE(review): presumably a list of request dicts in the format
            ``tokenize`` expects; confirm against ``.utils.tokenize``.
        hparams: GRACE hyperparameters. Mutated in place: ``n_iter`` is set
            to ``num_steps`` and ``replacement`` is set to ``replacement``.
        num_steps: Value written to ``hparams.n_iter``.
        replacement: Value written to ``hparams.replacement``.
        copy: If true, deep-copy ``model`` before wrapping it in GRACE.
        return_orig_weights: Accepted for interface compatibility; unused.
        keep_original_weight: Accepted for interface compatibility; unused.
        **kwargs: Accepted for interface compatibility; unused.

    Returns:
        The ``GRACE`` editor wrapping the edited model, moved to the CPU.
        (An earlier annotation claimed a ``(model, dict)`` tuple, but the
        function has always returned the editor object.)
    """
    if copy:
        # Edit a private copy rather than the caller's model object.
        model = deepcopy(model)

    # This entry point always builds the editor and its inputs on the CPU.
    device = torch.device("cpu")

    # NOTE: these assignments mutate the caller's ``hparams`` object in place.
    hparams.n_iter = num_steps
    hparams.replacement = replacement

    editor = GRACE(model=model, config=hparams, device=device)
    tokens = tokenize(requests, tokenizer=tok, device=device)
    editor.edit(config=hparams, tokens=tokens)
    editor.to("cpu")

    # Surface completion to the user through the Gradio UI.
    gr.Info("Completed editing via GRACE!")
    return editor