File size: 3,475 Bytes
7a10940
9e51263
 
7a10940
 
 
 
 
2e598e3
 
 
 
 
7a10940
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e598e3
7a10940
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bcb3e7
e3eb48d
7a10940
8bcb3e7
7a10940
a299003
7a10940
 
 
8bcb3e7
e3eb48d
7a10940
 
 
 
9be8f93
e3eb48d
7a10940
9be8f93
7a10940
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# %%
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import gradio as gr
from transformers import LlamaTokenizer
from transformers import LlamaForCausalLM, GenerationConfig
from peft import PeftModel
import torch
# Device that evaluate() moves the tokenized prompt onto. Note: the
# CUDA_VISIBLE_DEVICES="" assignment near the top of this file runs before
# torch is imported, which hides all GPUs, so this resolves to "cpu" unless
# that assignment is removed.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Device map in the form accepted by `from_pretrained(device_map=...)`:
# the empty-string key assigns the whole model to GPU index 0.
device_map = {'': 0}
def generate_instruction_prompt(instruction, input=None):
    if input:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input}

### Response:"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:"""


def evaluate(
    model,
    tokenizer,
    instruction,
    input=None,
    temperature=0.1,
    top_p=0.75,
    num_beams=4,
    max_token=256,
):
    """Generate a response to an Alpaca-style instruction.

    Args:
        model: causal LM exposing ``generate`` (here the LoRA-wrapped LLaMA).
        tokenizer: tokenizer matching ``model``.
        instruction: task text placed under ``### Instruction:``.
        input: optional context placed under ``### Input:``; falsy values
            omit that section.
        temperature: stored in the GenerationConfig. ``do_sample`` is not
            enabled, so with beam search this presumably has no effect.
            NOTE(review): confirm whether sampling was intended.
        top_p: stored in the GenerationConfig; same caveat as ``temperature``.
        num_beams: beam width.
        max_token: maximum number of newly generated tokens.

    Returns:
        The generated response text, with special tokens removed and
        surrounding whitespace stripped. It is also printed to stdout.
    """
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        num_beams=num_beams,
        top_k=40,
        no_repeat_ngram_size=3,
    )
    prompt = generate_instruction_prompt(instruction, input)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    generation_output = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        generation_config=generation_config,
        # Needed for `.sequences`. Per-step scores are not requested because
        # nothing reads them and they only cost memory.
        return_dict_in_generate=True,
        max_new_tokens=max_token,
    )

    # A decoder-only model returns prompt + continuation. Decode only the
    # continuation so a "### Response:" string inside the user's instruction
    # cannot confuse the parsing, and drop special tokens such as EOS.
    new_tokens = generation_output.sequences[0][input_ids.shape[-1]:]
    generated = tokenizer.decode(new_tokens, skip_special_tokens=True)
    # Cut at a repeated response header, matching the earlier split-based parsing.
    res = generated.split("### Response:")[0].strip()
    print("Response:", res)
    return res


def load_lora(lora_path, base_model="decapoda-research/llama-7b-hf"):
    """Load a LLaMA base model and wrap it with a LoRA adapter for inference.

    Args:
        lora_path: Hub id or local path of the PEFT/LoRA adapter weights.
        base_model: Hub id or local path of the base LLaMA checkpoint.

    Returns:
        The ``PeftModel`` wrapping the base model, switched to eval mode.
    """
    # Keep model placement consistent with `device`, which evaluate() uses
    # for the input tensors.
    on_gpu = device == "cuda"
    model = LlamaForCausalLM.from_pretrained(
        base_model,
        # bitsandbytes int8 weights require CUDA; on CPU load regular weights.
        load_in_8bit=on_gpu,
        # int8 loading needs an explicit device map. On CPU keep the default
        # (no map), as before.
        device_map=device_map if on_gpu else None,
        low_cpu_mem_usage=True,
        # The keyword is `torch_dtype`; an unknown kwarg such as `torch_type`
        # is silently ignored. Half precision is used only on GPU, because
        # fp16 matmuls are poorly supported on CPU in many torch versions.
        torch_dtype=torch.float16 if on_gpu else torch.float32,
    )
    print("Loading LoRA...")
    lora = PeftModel.from_pretrained(model, lora_path)
    # Put dropout layers, including any added by the adapter, into inference mode.
    lora.eval()
    return lora


base_model = "decapoda-research/llama-7b-hf"
tokenizer = LlamaTokenizer.from_pretrained(base_model)
# Example prompt for manual testing (originally written in Chinese):
# "If today is Friday, what day of the week is the day after tomorrow?"
model = load_lora(lora_path="facat/alpaca-lora-cn", base_model=base_model)


def predict(question, context, temperature, beams, max_token):
    """Gradio callback: forward the widget values to ``evaluate``.

    The arguments arrive positionally, in the order of the Interface
    ``inputs`` list below. The beam and token sliders are cast to ``int``
    because ``num_beams`` / ``max_new_tokens`` must be integers, and the
    slider values may arrive as floats.
    """
    return evaluate(
        model,
        tokenizer,
        question,
        input=context,
        temperature=temperature,
        num_beams=int(beams),
        max_token=int(max_token),
    )


# Top-p and top-k are not exposed in the UI, so evaluate() uses its own
# values (top_p=0.75 default, top_k=40 fixed).
gr.Interface(
    fn=predict,
    inputs=[
        gr.components.Textbox(
            lines=2, label="Instruction", placeholder="Tell me about alpacas."
        ),
        gr.components.Textbox(lines=2, label="Input", placeholder="none"),
        gr.components.Slider(minimum=0, maximum=1, value=0.1, label="Temperature"),
        gr.components.Slider(minimum=1, maximum=4, step=1, value=4, label="Beams"),
        gr.components.Slider(
            minimum=1, maximum=512, step=1, value=256, label="Max tokens"
        ),
    ],
    outputs=[
        # Same component namespace as the inputs; `gr.inputs` is deprecated.
        gr.components.Textbox(
            lines=8,
            label="Output",
        )
    ],
    title="Alpaca-LoRA",
    description="Alpaca-LoRA",
).launch()