boapps commited on
Commit
bb93dbf
·
verified ·
1 Parent(s): 847ca8c

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -0
app.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from peft import PeftModel
3
+ import transformers
4
+ import gradio as gr
5
+
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
7
+
8
+ tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
9
+
10
+ BASE_MODEL = "mistralai/Mistral-7B-v0.1"
11
+ LORA_WEIGHTS = "./qlora-out.mistral.0.9978/"
12
+
13
+ if torch.cuda.is_available():
14
+ device = "cuda"
15
+ else:
16
+ device = "cpu"
17
+
18
+ try:
19
+ if torch.backends.mps.is_available():
20
+ device = "mps"
21
+ except:
22
+ pass
23
+
24
+ if device == "cuda":
25
+ from transformers import BitsAndBytesConfig
26
+
27
+ nf4_config = BitsAndBytesConfig(
28
+ load_in_4bit=True,
29
+ bnb_4bit_quant_type="nf4",
30
+ bnb_4bit_use_double_quant=True,
31
+ bnb_4bit_compute_dtype=torch.bfloat16
32
+ )
33
+
34
+ model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, quantization_config=nf4_config)
35
+
36
+ model = PeftModel.from_pretrained(
37
+ model, LORA_WEIGHTS, torch_dtype=torch.float16, force_download=True
38
+ )
39
+ elif device == "mps":
40
+ model = AutoModelForCausalLM.from_pretrained(
41
+ BASE_MODEL,
42
+ device_map={"": device},
43
+ torch_dtype=torch.float16,
44
+ )
45
+ model = PeftModel.from_pretrained(
46
+ model,
47
+ LORA_WEIGHTS,
48
+ device_map={"": device},
49
+ torch_dtype=torch.float16,
50
+ )
51
+ else:
52
+ model = AutoModelForCausalLM.from_pretrained(
53
+ BASE_MODEL, device_map={"": device}, low_cpu_mem_usage=True
54
+ )
55
+ model = PeftModel.from_pretrained(
56
+ model,
57
+ LORA_WEIGHTS,
58
+ device_map={"": device},
59
+ )
60
+
61
+
62
+ def generate_prompt(instruction, input=None):
63
+ if input:
64
+ return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
65
+
66
+ ### Instruction:
67
+ {instruction}
68
+
69
+ ### Input:
70
+ {input}
71
+
72
+ ### Response:"""
73
+ else:
74
+ return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
75
+
76
+ ### Instruction:
77
+ {instruction}
78
+
79
+ ### Response:"""
80
+
81
+ if device != "cpu":
82
+ model.half()
83
+ model.eval()
84
+ if torch.__version__ >= "2":
85
+ model = torch.compile(model)
86
+
87
+
88
+ def evaluate(
89
+ instruction,
90
+ input=None,
91
+ temperature=0.1,
92
+ top_p=0.75,
93
+ top_k=40,
94
+ num_beams=4,
95
+ max_new_tokens=128,
96
+ **kwargs,
97
+ ):
98
+ prompt = generate_prompt(instruction, input)
99
+ inputs = tokenizer(prompt, return_tensors="pt")
100
+ input_ids = inputs["input_ids"].to(device)
101
+ generation_config = GenerationConfig(
102
+ temperature=temperature,
103
+ top_p=top_p,
104
+ top_k=top_k,
105
+ num_beams=num_beams,
106
+ **kwargs,
107
+ )
108
+ with torch.no_grad():
109
+ generation_output = model.generate(
110
+ input_ids=input_ids,
111
+ generation_config=generation_config,
112
+ return_dict_in_generate=True,
113
+ output_scores=True,
114
+ max_new_tokens=max_new_tokens,
115
+ )
116
+ s = generation_output.sequences[0]
117
+ output = tokenizer.decode(s)
118
+ return output.split("### Response:")[1].strip()
119
+
120
+
121
+ g = gr.Interface(
122
+ fn=evaluate,
123
+ inputs=[
124
+ gr.components.Textbox(
125
+ lines=2, label="Utasítás", placeholder="Mesélj kicsit a szürkemarháról!"
126
+ ),
127
+ gr.components.Textbox(lines=2, label="Input", placeholder="üres"),
128
+ gr.components.Slider(minimum=0, maximum=1, value=0.1, label="Temperature"),
129
+ gr.components.Slider(minimum=0, maximum=1, value=0.75, label="Top p"),
130
+ gr.components.Slider(minimum=0, maximum=100, step=1, value=40, label="Top k"),
131
+ gr.components.Slider(minimum=1, maximum=4, step=1, value=4, label="Beams"),
132
+ gr.components.Slider(
133
+ minimum=1, maximum=512, step=1, value=128, label="Max tokens"
134
+ ),
135
+ ],
136
+ outputs=["text"],
137
+ title="szürkemarha-mistral-v1",
138
+ description="A szürkemarha-mistral egy fejlesztés alatt álló 7 milliárd paraméteres Mistral-0.1 alapú model LoRA finomhangolva instrukciókövetésre.",
139
+ )
140
+ g.queue(concurrency_count=1)
141
+ g.launch()