File size: 1,898 Bytes
3494c6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
from transformers import GPT2TokenizerFast, GPT2Tokenizer
from easyeditor import apply_grace_to_model, GraceHyperParams,nethook
import torch
import gradio as gr



def edit(prompt, target_new, num_steps, replacement):
    """Apply a GRACE edit to a freshly loaded GPT-2 and publish it as ``edit_model``.

    Loads the GRACE hyperparameters from ``./hparams/GRACE/gpt2.yaml`` and the
    GPT-2 weights and tokenizer from ``./models/gpt2`` (on CPU). It then passes
    them to ``apply_grace_to_model`` together with the edit request and the
    ``num_steps`` / ``replacement`` settings. The returned model is bound to
    the module-level name ``edit_model``, which ``generate`` reads.

    Args:
        prompt: Prompt text of the edit request.
        target_new: Desired new completion for ``prompt``.
        num_steps: Forwarded unchanged to ``apply_grace_to_model``.
        replacement: Forwarded unchanged to ``apply_grace_to_model``.

    Returns:
        The unchanged ``prompt``.
    """
    global edit_model

    edit_request = dict(prompt=prompt, target_new=target_new)
    grace_hparams = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2.yaml")

    base_model = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map='cpu')
    tokenizer = GPT2Tokenizer.from_pretrained("./models/gpt2")
    # Reuse the EOS id as the pad id so padding has a defined token.
    tokenizer.pad_token_id = tokenizer.eos_token_id

    edit_model = apply_grace_to_model(
        base_model, tokenizer, edit_request, grace_hparams, num_steps, replacement
    )
    return prompt

def _highlight_continuation(text, prompt_len):
    """Return ``text`` as per-character ``(char, label)`` pairs.

    Characters at index ``prompt_len`` and beyond are labelled ``'output'``.
    The first ``prompt_len`` characters (the prompt) are labelled ``None``.
    NOTE(review): presumably rendered by a Gradio ``HighlightedText``
    component; confirm against the interface wiring.
    """
    return [(ch, 'output' if i >= prompt_len else None) for i, ch in enumerate(text)]

def generate(input_text, target_new=None):
    """Continue ``input_text`` with both the original and the edited GPT-2.

    Args:
        input_text: Prompt to continue.
        target_new: Optional edit target. When non-empty, generation is capped
            at as many new tokens as ``target_new`` encodes to. When ``None``
            or empty, 25 new tokens are generated.

    Returns:
        ``(ori_reply, edit_reply)``: the continuation from a freshly loaded
        GPT-2 and from the module-level ``edit_model`` (set by ``edit``).
        Each is a list of ``(char, label)`` pairs in which characters after
        the prompt are labelled ``'output'`` and prompt characters ``None``.

    Raises:
        NameError: if ``edit`` has not been called yet, so ``edit_model`` is
            not defined.
    """
    tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
    tok.pad_token_id = tok.eos_token_id

    # An empty target (e.g. a blank textbox) would encode to zero tokens and
    # request max_new_tokens=0; treat it like "no target" instead.
    if target_new:
        max_new_tokens = len(tok.encode(target_new))
    else:
        max_new_tokens = 25

    # Character index where generated text starts in the decoded reply.
    # NOTE(review): assumes tok.decode reproduces input_text verbatim as the
    # reply's prefix; confirm for prompts whose spacing decode may normalise.
    prompt_len = len(input_text)
    input_ids = tok.encode(input_text, return_tensors='pt').to('cpu')

    # edit_model is the module-level model produced by edit(); only read here.
    edit_output = edit_model.generate(input_ids, max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id)
    edit_reply = tok.decode(edit_output[0], skip_special_tokens=True)
    torch.cuda.empty_cache()

    ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
    ori_output = ori_model.generate(input_ids, max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id)
    ori_reply = tok.decode(ori_output[0], skip_special_tokens=True)

    return (
        _highlight_continuation(ori_reply, prompt_len),
        _highlight_continuation(edit_reply, prompt_len),
    )