Commit 90c428d · zetavg committed · 1 Parent(s): 2f0a0ce

change model loading mechanism

Changed files:
- LLaMA_LoRA.ipynb +2 -7
- app.py +1 -5
- llama_lora/globals.py +3 -5
- llama_lora/models.py +206 -65
- llama_lora/ui/finetune_ui.py +23 -15
- llama_lora/ui/inference_ui.py +24 -21
- llama_lora/ui/main_page.py +1 -2
- llama_lora/ui/tokenizer_ui.py +4 -2
- llama_lora/utils/lru_cache.py +4 -0
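
In short, the base model is no longer loaded eagerly at startup; the app now records only `Global.default_base_model_name` and loads models lazily through `get_model()` / `get_tokenizer()`, caching the results in `Global.loaded_models` / `Global.loaded_tokenizers` (LRU caches of size 1). A rough sketch of the new flow, based on the diffs below; the base-model value is a made-up placeholder, not something this commit ships:

    # Sketch only; assumes the llama_lora package from this repo is importable.
    from llama_lora.globals import Global
    from llama_lora.models import get_model, get_tokenizer

    Global.default_base_model_name = "path/to/base-model"  # placeholder value
    Global.data_dir = "./data"

    # Nothing is loaded until the UI actually needs a model, e.g. on "Generate":
    tokenizer = get_tokenizer(Global.default_base_model_name)
    model = get_model(Global.default_base_model_name, "tloen/alpaca-lora-7b")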
LLaMA_LoRA.ipynb
CHANGED
@@ -281,7 +281,7 @@
     "\n",
     "# Set Configs\n",
     "from llama_lora.llama_lora.globals import Global\n",
-    "Global.base_model = base_model\n",
+    "Global.default_base_model_name = base_model\n",
     "data_dir_realpath = !realpath ./data\n",
     "Global.data_dir = data_dir_realpath[0]\n",
     "Global.load_8bit = True\n",
@@ -289,12 +289,7 @@
     "# Prepare Data Dir\n",
     "import os\n",
     "from llama_lora.llama_lora.utils.data import init_data_dir\n",
-    "init_data_dir()\n",
-    "\n",
-    "# Load the Base Model\n",
-    "from llama_lora.llama_lora.models import load_base_model\n",
-    "load_base_model()\n",
-    "print(f\"Base model loaded: '{Global.base_model}'.\")"
+    "init_data_dir()"
     ],
     "metadata": {
       "id": "Yf6g248ylteP"
app.py
CHANGED
@@ -7,7 +7,6 @@ import gradio as gr
 from llama_lora.globals import Global
 from llama_lora.ui.main_page import main_page, get_page_title, main_page_custom_css
 from llama_lora.utils.data import init_data_dir
-from llama_lora.models import load_base_model


 def main(
@@ -31,7 +30,7 @@ def main(
         data_dir
     ), "Please specify a --data_dir, e.g. --data_dir='./data'"

-    Global.base_model = base_model
+    Global.default_base_model_name = base_model
     Global.data_dir = os.path.abspath(data_dir)
     Global.load_8bit = load_8bit

@@ -41,9 +40,6 @@ def main(
     os.makedirs(data_dir, exist_ok=True)
     init_data_dir()

-    if not skip_loading_base_model:
-        load_base_model()
-
     with gr.Blocks(title=get_page_title(), css=main_page_custom_css()) as demo:
         main_page()

llama_lora/globals.py
CHANGED
@@ -13,12 +13,10 @@ from .lib.finetune import train
 class Global:
     version = None

-    base_model: str = ""
     data_dir: str = ""
     load_8bit: bool = False

-
-    loaded_base_model: Any = None
+    default_base_model_name: str = ""

     # Functions
     train_fn: Any = train
@@ -31,8 +29,8 @@ class Global:
     generation_force_stopped_at = None

     # Model related
-    ...
-    ...
+    loaded_models = LRUCache(1)
+    loaded_tokenizers = LRUCache(1)

     # GPU Info
     gpu_cc = None  # GPU compute capability
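
Because `loaded_models` and `loaded_tokenizers` are `LRUCache(1)` instances, at most one base-model/adapter combination stays resident; asking for a different combination evicts the previous one first. A tiny illustration of that capacity-1 behaviour, assuming the `LRUCache` class from llama_lora/utils/lru_cache.py below (the key format mirrors the `base//peft` keys used in models.py):

    from llama_lora.utils.lru_cache import LRUCache

    cache = LRUCache(1)                 # same capacity the Global caches use
    cache.set("base-a//lora-x", "model object A")

    cache.prepare_to_set()              # capacity is full, so A is evicted here
    cache.set("base-a//lora-y", "model object B")

    assert cache.get("base-a//lora-y") == "model object B"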
llama_lora/models.py
CHANGED
@@ -3,9 +3,8 @@ import sys
 import gc

 import torch
-import ...
+from transformers import LlamaForCausalLM, LlamaTokenizer
 from peft import PeftModel
-from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer

 from .globals import Global

@@ -23,96 +22,120 @@ def get_device():
     pass


-def get_base_model():
-    load_base_model()
-    return Global.loaded_base_model
-
-
-def get_model_with_lora(lora_weights_name_or_path: str = "tloen/alpaca-lora-7b"):
-    Global.model_has_been_used = True
-    ...
-
-
-def get_tokenizer():
-    load_base_model()
-    return Global.loaded_tokenizer
-
-
-def load_base_model():
-    ...
+def get_new_base_model(base_model_name):
+    if Global.ui_dev_mode:
+        return
+
+    device = get_device()
+
+    if device == "cuda":
+        model = LlamaForCausalLM.from_pretrained(
+            base_model_name,
+            load_in_8bit=Global.load_8bit,
+            torch_dtype=torch.float16,
+            # device_map="auto",
+            device_map={'': 0},  # ? https://github.com/tloen/alpaca-lora/issues/21
+        )
+    elif device == "mps":
+        model = LlamaForCausalLM.from_pretrained(
+            base_model_name,
+            device_map={"": device},
+            torch_dtype=torch.float16,
+        )
+    else:
+        model = LlamaForCausalLM.from_pretrained(
+            base_model_name, device_map={"": device}, low_cpu_mem_usage=True
+        )
+
+    model.config.pad_token_id = get_tokenizer(base_model_name).pad_token_id = 0
+    model.config.bos_token_id = 1
+    model.config.eos_token_id = 2
+
+    return model
+
+
+def get_tokenizer(base_model_name):
+    if Global.ui_dev_mode:
+        return
+
+    loaded_tokenizer = Global.loaded_tokenizers.get(base_model_name)
+    if loaded_tokenizer:
+        return loaded_tokenizer
+
+    tokenizer = LlamaTokenizer.from_pretrained(base_model_name)
+    Global.loaded_tokenizers.set(base_model_name, tokenizer)
+
+    return tokenizer
+
+
+def get_model(
+        base_model_name,
+        peft_model_name=None):
+    if Global.ui_dev_mode:
+        return
+
+    if peft_model_name == "None":
+        peft_model_name = None
+
+    model_key = base_model_name
+    if peft_model_name:
+        model_key = f"{base_model_name}//{peft_model_name}"
+
+    loaded_model = Global.loaded_models.get(model_key)
+    if loaded_model:
+        return loaded_model
+
+    peft_model_name_or_path = peft_model_name
+
+    lora_models_directory_path = os.path.join(Global.data_dir, "lora_models")
+    possible_lora_model_path = os.path.join(lora_models_directory_path, peft_model_name)
+    if os.path.isdir(possible_lora_model_path):
+        peft_model_name_or_path = possible_lora_model_path
+
+    Global.loaded_models.prepare_to_set()
+    clear_cache()
+
+    model = get_new_base_model(base_model_name)
+
+    if peft_model_name:
+        device = get_device()
+
+        if device == "cuda":
+            model = PeftModel.from_pretrained(
+                model,
+                peft_model_name_or_path,
+                torch_dtype=torch.float16,
+                device_map={'': 0},  # ? https://github.com/tloen/alpaca-lora/issues/21
+            )
+        elif device == "mps":
+            model = PeftModel.from_pretrained(
+                model,
+                peft_model_name_or_path,
+                device_map={"": device},
+                torch_dtype=torch.float16,
+            )
+        else:
+            model = PeftModel.from_pretrained(
+                model,
+                peft_model_name_or_path,
+                device_map={"": device},
+            )
+
+    model.config.pad_token_id = get_tokenizer(base_model_name).pad_token_id = 0
+    model.config.bos_token_id = 1
+    model.config.eos_token_id = 2
+
+    if not Global.load_8bit:
+        model.half()  # seems to fix bugs for some users.
+
+    model.eval()
+    if torch.__version__ >= "2" and sys.platform != "win32":
+        model = torch.compile(model)
+
+    Global.loaded_models.set(model_key, model)
+    clear_cache()
+
+    return model


 def clear_cache():
@@ -124,19 +147,137 @@ def clear_cache():


 def unload_models():
-    del Global.loaded_base_model
-    Global.loaded_base_model = None
-
-    del Global.loaded_tokenizer
-    Global.loaded_tokenizer = None
-
-    Global.cached_lora_models.clear()
-
-    clear_cache()
-
-    Global.model_has_been_used = False
-
-
-def unload_models_if_already_used():
-    if Global.model_has_been_used:
-        unload_models()
+    Global.loaded_models.clear()
+    Global.loaded_tokenizers.clear()
+    clear_cache()
+
+
+########
+
+# def get_base_model():
+#     load_base_model()
+#     return Global.loaded_base_model
+
+
+# def get_model_with_lora(lora_weights_name_or_path: str = "tloen/alpaca-lora-7b"):
+#     # Global.model_has_been_used = True
+#     #
+#     #
+#     if Global.loaded_tokenizer is None:
+#         Global.loaded_tokenizer = LlamaTokenizer.from_pretrained(
+#             Global.base_model
+#         )
+
+#     if Global.cached_lora_models:
+#         model_from_cache = Global.cached_lora_models.get(lora_weights_name_or_path)
+#         if model_from_cache:
+#             return model_from_cache
+
+#     Global.cached_lora_models.prepare_to_set()
+
+#     if device == "cuda":
+#         model = PeftModel.from_pretrained(
+#             get_new_base_model(),
+#             lora_weights_name_or_path,
+#             torch_dtype=torch.float16,
+#             device_map={'': 0},  # ? https://github.com/tloen/alpaca-lora/issues/21
+#         )
+#     elif device == "mps":
+#         model = PeftModel.from_pretrained(
+#             get_new_base_model(),
+#             lora_weights_name_or_path,
+#             device_map={"": device},
+#             torch_dtype=torch.float16,
+#         )
+#     else:
+#         model = PeftModel.from_pretrained(
+#             get_new_base_model(),
+#             lora_weights_name_or_path,
+#             device_map={"": device},
+#         )
+
+#     model.config.pad_token_id = get_tokenizer().pad_token_id = 0
+#     model.config.bos_token_id = 1
+#     model.config.eos_token_id = 2
+
+#     if not Global.load_8bit:
+#         model.half()  # seems to fix bugs for some users.
+
+#     model.eval()
+#     if torch.__version__ >= "2" and sys.platform != "win32":
+#         model = torch.compile(model)
+
+#     if Global.cached_lora_models:
+#         Global.cached_lora_models.set(lora_weights_name_or_path, model)
+
+#     clear_cache()
+
+#     return model
+
+
+# def load_base_model():
+#     return;
+
+#     if Global.ui_dev_mode:
+#         return
+
+#     if Global.loaded_tokenizer is None:
+#         Global.loaded_tokenizer = LlamaTokenizer.from_pretrained(
+#             Global.base_model
+#         )
+#     if Global.loaded_base_model is None:
+#         if device == "cuda":
+#             Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
+#                 Global.base_model,
+#                 load_in_8bit=Global.load_8bit,
+#                 torch_dtype=torch.float16,
+#                 # device_map="auto",
+#                 device_map={'': 0},  # ? https://github.com/tloen/alpaca-lora/issues/21
+#             )
+#         elif device == "mps":
+#             Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
+#                 Global.base_model,
+#                 device_map={"": device},
+#                 torch_dtype=torch.float16,
+#             )
+#         else:
+#             Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
+#                 Global.base_model, device_map={"": device}, low_cpu_mem_usage=True
+#             )
+
+#     Global.loaded_base_model.config.pad_token_id = get_tokenizer().pad_token_id = 0
+#     Global.loaded_base_model.config.bos_token_id = 1
+#     Global.loaded_base_model.config.eos_token_id = 2
+
+
+# def clear_cache():
+#     gc.collect()
+
+#     # if not shared.args.cpu:  # will not be running on CPUs anyway
+#     with torch.no_grad():
+#         torch.cuda.empty_cache()
+
+
+# def unload_models():
+#     del Global.loaded_base_model
+#     Global.loaded_base_model = None
+
+#     del Global.loaded_tokenizer
+#     Global.loaded_tokenizer = None
+
+#     Global.cached_lora_models.clear()
+
+#     clear_cache()
+
+#     Global.model_has_been_used = False


+# def unload_models_if_already_used():
+#     if Global.model_has_been_used:
+#         unload_models()
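
Two details of the new `get_model()` worth noting: cached entries are keyed as `base_model_name//peft_model_name`, and an adapter name is first resolved against `<data_dir>/lora_models/<name>` before being treated as a Hugging Face Hub id. A small standalone sketch of that resolution and key logic, mirroring the code above rather than importing it:

    import os

    def resolve_peft_model(data_dir, peft_model_name):
        # Mirrors the lookup in get_model(): prefer a locally trained adapter
        # under <data_dir>/lora_models/, otherwise fall back to the given name
        # (e.g. a Hugging Face Hub id such as "tloen/alpaca-lora-7b").
        possible_path = os.path.join(data_dir, "lora_models", peft_model_name)
        return possible_path if os.path.isdir(possible_path) else peft_model_name

    def cache_key(base_model_name, peft_model_name=None):
        # Mirrors the model_key construction in get_model().
        if peft_model_name and peft_model_name != "None":
            return f"{base_model_name}//{peft_model_name}"
        return base_model_name

    print(cache_key("some-base-model", "tloen/alpaca-lora-7b"))
    # -> some-base-model//tloen/alpaca-lora-7b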
llama_lora/ui/finetune_ui.py
CHANGED
@@ -10,8 +10,8 @@ from transformers import TrainerCallback

 from ..globals import Global
 from ..models import (
-    get_base_model, get_tokenizer,
-    clear_cache, unload_models_if_already_used)
+    get_new_base_model, get_tokenizer,
+    clear_cache, unload_models)
 from ..utils.data import (
     get_available_template_names,
     get_available_dataset_names,
@@ -269,14 +269,16 @@ def do_train(
     progress=gr.Progress(track_tqdm=should_training_progress_track_tqdm),
 ):
     try:
+        base_model_name = Global.default_base_model_name
+        output_dir = os.path.join(Global.data_dir, "lora_models", model_name)
+        if os.path.exists(output_dir):
+            if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
+                raise ValueError(f"The output directory already exists and is not empty. ({output_dir})")
+
         if not should_training_progress_track_tqdm:
             progress(0, desc="Preparing train data...")

-        # If model has been used in inference, we need to unload it first.
-        # Otherwise, we'll get a 'Function MmBackward0 returned an invalid
-        # gradient at index 1 - expected device meta but got cuda:0' error.
-        unload_models_if_already_used()
+        unload_models()  # Need RAM for training

         prompter = Prompter(template)
         variable_names = prompter.get_variable_names()
@@ -415,17 +417,12 @@ Train data (first 10):

         Global.should_stop_training = False

-        clear_cache()
-
-        base_model = get_base_model()
-        tokenizer = get_tokenizer()
+        base_model = get_new_base_model(base_model_name)
+        tokenizer = get_tokenizer(base_model_name)

         # Do not let other tqdm iterations interfere the progress reporting after training starts.
         # progress.track_tqdm = False # setting this dynamically is not working, determining if track_tqdm should be enabled based on GPU cores at start instead.

-        output_dir = os.path.join(Global.data_dir, "lora_models", model_name)
         if not os.path.exists(output_dir):
             os.makedirs(output_dir)

@@ -435,10 +432,11 @@ Train data (first 10):
             dataset_name = dataset_from_data_dir

         info = {
-            'base_model': Global.base_model,
+            'base_model': base_model_name,
             'prompt_template': template,
             'dataset_name': dataset_name,
             'dataset_rows': len(train_data),
+            'timestamp': time.time()
         }
         json.dump(info, info_json_file, indent=2)

@@ -472,7 +470,11 @@ Train data (first 10):

         result_message = f"Training ended:\n{str(train_output)}\n\nLogs:\n{logs_str}"
         print(result_message)
+
+        del base_model
+        del tokenizer
         clear_cache()
+
         return result_message

     except Exception as e:
@@ -837,6 +839,12 @@ def finetune_ui():
                 document.getElementById('finetune_confirm_stop_btn').style.display =
                     'none';
             }, 5000);
+            document.getElementById('finetune_confirm_stop_btn').style['pointer-events'] =
+                'none';
+            setTimeout(function () {
+                document.getElementById('finetune_confirm_stop_btn').style['pointer-events'] =
+                    'inherit';
+            }, 300);
             document.getElementById('finetune_stop_btn').style.display = 'none';
             document.getElementById('finetune_confirm_stop_btn').style.display =
                 'block';
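
The training path now frees the inference caches up front (`unload_models()  # Need RAM for training`), loads a fresh, uncached base model with `get_new_base_model()`, and refuses to overwrite an existing adapter by checking for `adapter_config.json` in the output directory. A standalone sketch of that guard, written as a hypothetical helper for illustration (same logic as the diff above):

    import os

    def ensure_fresh_output_dir(data_dir, model_name):
        # Raise if the target directory already holds a trained adapter
        # (indicated by adapter_config.json) or is not a directory at all.
        output_dir = os.path.join(data_dir, "lora_models", model_name)
        if os.path.exists(output_dir):
            if (not os.path.isdir(output_dir)) or os.path.exists(
                    os.path.join(output_dir, 'adapter_config.json')):
                raise ValueError(
                    f"The output directory already exists and is not empty. ({output_dir})")
        return output_dir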
llama_lora/ui/inference_ui.py
CHANGED
@@ -7,11 +7,10 @@ import transformers
 from transformers import GenerationConfig

 from ..globals import Global
-from ..models import ...
+from ..models import get_model, get_tokenizer, get_device
 from ..utils.data import (
     get_available_template_names,
     get_available_lora_model_names,
-    get_path_of_available_lora_model,
     get_info_of_available_lora_model)
 from ..utils.prompter import Prompter
 from ..utils.callbacks import Iteratorize, Stream
@@ -22,6 +21,18 @@ default_show_raw = True
 inference_output_lines = 12


+def prepare_inference(lora_model_name, progress=gr.Progress(track_tqdm=True)):
+    base_model_name = Global.default_base_model_name
+
+    try:
+        get_tokenizer(base_model_name)
+        get_model(base_model_name, lora_model_name)
+        return ("", "")
+
+    except Exception as e:
+        raise gr.Error(e)
+
+
 def do_inference(
     lora_model_name,
     prompt_template,
@@ -37,6 +48,8 @@ def do_inference(
     show_raw=False,
     progress=gr.Progress(track_tqdm=True),
 ):
+    base_model_name = Global.default_base_model_name
+
     try:
         if Global.generation_force_stopped_at is not None:
             required_elapsed_time_after_forced_stop = 1
@@ -52,16 +65,8 @@ def do_inference(
         prompter = Prompter(prompt_template)
         prompt = prompter.generate_prompt(variables)

-        if not lora_model_name:
-            lora_model_name = "None"
-        if "/" not in lora_model_name and lora_model_name != "None":
-            path_of_available_lora_model = get_path_of_available_lora_model(
-                lora_model_name)
-            if path_of_available_lora_model:
-                lora_model_name = path_of_available_lora_model
-
         if Global.ui_dev_mode:
-            message = f"Hi, I’m currently in UI-development mode and do not have access to resources to process your request. However, this behavior is similar to what will actually happen, so you can try and see how it will work!\n\nBase model: {...
+            message = f"Hi, I’m currently in UI-development mode and do not have access to resources to process your request. However, this behavior is similar to what will actually happen, so you can try and see how it will work!\n\nBase model: {base_model_name}\nLoRA model: {lora_model_name}\n\nThe following is your prompt:\n\n{prompt}"
             print(message)

             if stream_output:
@@ -90,18 +95,13 @@ def do_inference(
                     return
                 time.sleep(1)
             yield (
-                gr.Textbox.update(value=message, lines=1),
+                gr.Textbox.update(value=message, lines=1),  # TODO
                 json.dumps(list(range(len(message.split()))), indent=2)
             )
             return

-            model = get_model_with_lora(lora_model_name)
-        else:
-            raise ValueError("No LoRA model selected.")
-
-        tokenizer = get_tokenizer()
+        tokenizer = get_tokenizer(base_model_name)
+        model = get_model(base_model_name, lora_model_name)

         inputs = tokenizer(prompt, return_tensors="pt")
         input_ids = inputs["input_ids"].to(device)
@@ -210,7 +210,6 @@ def do_inference(
             gr.Textbox.update(value=response, lines=inference_output_lines),
             raw_output)

-
     except Exception as e:
         raise gr.Error(e)

@@ -232,7 +231,7 @@ def reload_selections(current_lora_model, current_prompt_template):

     default_lora_models = ["tloen/alpaca-lora-7b"]
     available_lora_models = default_lora_models + get_available_lora_model_names()
-    available_lora_models = available_lora_models ...
+    available_lora_models = available_lora_models + ["None"]

     current_lora_model = current_lora_model or next(
         iter(available_lora_models), None)
@@ -462,6 +461,10 @@ def inference_ui():
     things_that_might_timeout.append(lora_model_change_event)

     generate_event = generate_btn.click(
+        fn=prepare_inference,
+        inputs=[lora_model],
+        outputs=[inference_output, inference_raw_output],
+    ).then(
         fn=do_inference,
         inputs=[
             lora_model,
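
Because model loading can now take a while on first use, the Generate button is wired as a two-step chain: `prepare_inference` loads (and caches) the tokenizer and model, then `do_inference` runs generation against the already-cached objects. A stripped-down sketch of that `.click().then()` pattern in generic Gradio (placeholder functions, not the full UI above):

    import gradio as gr

    def prepare(name):
        # Step 1: slow setup work (e.g. loading a model into a cache).
        return f"(loading {name}...)"

    def generate(name):
        # Step 2: runs after prepare() finishes, using whatever it cached.
        return f"output generated with {name}"

    with gr.Blocks() as demo:
        name = gr.Textbox(value="tloen/alpaca-lora-7b")
        out = gr.Textbox()
        btn = gr.Button("Generate")
        # Same chaining pattern as generate_btn.click(...).then(...) above.
        btn.click(fn=prepare, inputs=[name], outputs=[out]).then(
            fn=generate, inputs=[name], outputs=[out])

    # demo.launch()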
llama_lora/ui/main_page.py
CHANGED
@@ -1,7 +1,6 @@
 import gradio as gr

 from ..globals import Global
-from ..models import get_model_with_lora

 from .inference_ui import inference_ui
 from .finetune_ui import finetune_ui
@@ -31,7 +30,7 @@ def main_page():
     info = []
     if Global.version:
         info.append(f"LLaMA-LoRA Tuner `{Global.version}`")
-    info.append(f"Base model: `{Global.base_model}`")
+    info.append(f"Base model: `{Global.default_base_model_name}`")
     if Global.ui_show_sys_info:
         info.append(f"Data dir: `{Global.data_dir}`")
     gr.Markdown(f"""
llama_lora/ui/tokenizer_ui.py
CHANGED
@@ -7,11 +7,12 @@ from ..models import get_tokenizer


 def handle_decode(encoded_tokens_json):
+    base_model_name = Global.default_base_model_name
     try:
         encoded_tokens = json.loads(encoded_tokens_json)
         if Global.ui_dev_mode:
             return f"Not actually decoding tokens in UI dev mode.", gr.Markdown.update("", visible=False)
-        tokenizer = get_tokenizer()
+        tokenizer = get_tokenizer(base_model_name)
         decoded_tokens = tokenizer.decode(encoded_tokens)
         return decoded_tokens, gr.Markdown.update("", visible=False)
     except Exception as e:
@@ -19,10 +20,11 @@ def handle_decode(encoded_tokens_json):


 def handle_encode(decoded_tokens):
+    base_model_name = Global.default_base_model_name
     try:
         if Global.ui_dev_mode:
             return f"[\"Not actually encoding tokens in UI dev mode.\"]", gr.Markdown.update("", visible=False)
-        tokenizer = get_tokenizer()
+        tokenizer = get_tokenizer(base_model_name)
         result = tokenizer(decoded_tokens)
         encoded_tokens_json = json.dumps(result['input_ids'], indent=2)
         return encoded_tokens_json, gr.Markdown.update("", visible=False)
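
The tokenizer tab now resolves its tokenizer from `Global.default_base_model_name` on each call. The round trip it exposes is the usual Hugging Face one; a minimal sketch (the model name is a placeholder):

    from transformers import LlamaTokenizer

    tokenizer = LlamaTokenizer.from_pretrained("path/to/base-model")  # placeholder

    encoded = tokenizer("Hello, world!")["input_ids"]  # encode, as handle_encode does
    decoded = tokenizer.decode(encoded)                # decode, as handle_decode does
    print(encoded, decoded)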
llama_lora/utils/lru_cache.py
CHANGED
@@ -25,3 +25,7 @@ class LRUCache:

     def clear(self):
         self.cache.clear()
+
+    def prepare_to_set(self):
+        if len(self.cache) >= self.capacity:
+            self.cache.popitem(last=False)
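
`prepare_to_set()` evicts the least-recently-used entry before a new one is inserted, so the previous model can be released (and `clear_cache()` run) before the next model is loaded, instead of both being in memory at once. A self-contained approximation of the class with the new method, for illustration only; the real implementation lives in this file, and details such as `get()`'s exact miss behaviour are assumptions:

    from collections import OrderedDict

    class LRUCacheSketch:
        def __init__(self, capacity=5):
            self.cache = OrderedDict()
            self.capacity = capacity

        def get(self, key):
            if key not in self.cache:
                return None
            self.cache.move_to_end(key)  # mark as recently used
            return self.cache[key]

        def set(self, key, value):
            self.cache[key] = value
            self.cache.move_to_end(key)
            if len(self.cache) > self.capacity:
                self.cache.popitem(last=False)

        def clear(self):
            self.cache.clear()

        def prepare_to_set(self):
            # New in this commit: make room *before* the caller builds the new
            # value, so the evicted model can be freed ahead of the next load.
            if len(self.cache) >= self.capacity:
                self.cache.popitem(last=False)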
|