djuna committed on
Commit
c5377ed
1 Parent(s): c963f2b

move to djuna-test-lab

Browse files
Files changed (1) hide show
  1. app.py +145 -34
app.py CHANGED
@@ -11,6 +11,7 @@ import gradio as gr
11
  import huggingface_hub
12
  import torch
13
  import yaml
 
14
  from gradio_logsview.logsview import Log, LogsView, LogsViewRunner
15
  from mergekit.config import MergeConfiguration
16
 
@@ -43,7 +44,7 @@ has_gpu = torch.cuda.is_available()
43
  # )
44
 
45
  cli = "mergekit-yaml config.yaml merge --copy-tokenizer" + (
46
- " --cuda --low-cpu-memory --allow-crimes" if has_gpu else " --allow-crimes --out-shard-size 1B --lazy-unpickle"
47
  )
48
 
49
  MARKDOWN_DESCRIPTION = """
@@ -106,33 +107,35 @@ This Space is heavily inspired by LazyMergeKit by Maxime Labonne (see [Colab](ht
106
  examples = [[str(f)] for f in pathlib.Path("examples").glob("*.yaml")]
107
 
108
  # Do not set community token as `HF_TOKEN` to avoid accidentally using it in merge scripts.
109
- # `COMMUNITY_HF_TOKEN` is used to upload models to the community organization (https://huggingface.co/mergekit-community)
110
  # when the user does not provide a token.
111
  COMMUNITY_HF_TOKEN = os.getenv("COMMUNITY_HF_TOKEN")
112
 
113
 
114
- def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]]:
115
  runner = LogsViewRunner()
116
 
117
  if not yaml_config:
118
  yield runner.log("Empty yaml, pick an example below", level="ERROR")
119
  return
120
- try:
121
- merge_config = MergeConfiguration.model_validate(yaml.safe_load(yaml_config))
122
- except Exception as e:
123
- yield runner.log(f"Invalid yaml {e}", level="ERROR")
124
- return
 
 
125
 
126
  is_community_model = False
127
  if not hf_token:
128
- if "/" in repo_name and not repo_name.startswith("mergekit-community/"):
129
  yield runner.log(
130
  f"Cannot upload merge model to namespace {repo_name.split('/')[0]}: you must provide a valid token.",
131
  level="ERROR",
132
  )
133
  return
134
  yield runner.log(
135
- "No HF token provided. Your merged model will be uploaded to the https://huggingface.co/mergekit-community organization."
136
  )
137
  is_community_model = True
138
  if not COMMUNITY_HF_TOKEN:
@@ -156,8 +159,8 @@ def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]
156
  repo_name += "-" + "".join(random.choices(string.ascii_lowercase, k=7))
157
  repo_name = repo_name.replace("/", "-").strip("-")
158
 
159
- if is_community_model and not repo_name.startswith("mergekit-community/"):
160
- repo_name = f"mergekit-community/{repo_name}"
161
 
162
  try:
163
  yield runner.log(f"Creating repo {repo_name}")
@@ -170,7 +173,7 @@ def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]
170
  # Set tmp HF_HOME to avoid filling up disk Space
171
  tmp_env = os.environ.copy() # taken from https://stackoverflow.com/a/4453495
172
  tmp_env["HF_HOME"] = f"{tmpdirname}/.cache"
173
- full_cli = cli + f" --lora-merge-cache {tmpdirname}/.lora_cache"
174
  yield from runner.run_command(full_cli.split(), cwd=merged_path, env=tmp_env)
175
 
176
  if runner.exit_code != 0:
@@ -187,27 +190,137 @@ def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]
187
  yield runner.log(f"Model successfully uploaded to HF: {repo_url.repo_id}")
188
 
189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  with gr.Blocks() as demo:
191
  gr.Markdown(MARKDOWN_DESCRIPTION)
192
 
193
- with gr.Row():
194
- filename = gr.Textbox(visible=False, label="filename")
195
- config = gr.Code(language="yaml", lines=10, label="config.yaml")
196
- with gr.Column():
197
- token = gr.Textbox(
198
- lines=1,
199
- label="HF Write Token",
200
- info="https://hf.co/settings/token",
201
- type="password",
202
- placeholder="Optional. Will upload merged model to MergeKit Community if empty.",
203
- )
204
- repo_name = gr.Textbox(
205
- lines=1,
206
- label="Repo name",
207
- placeholder="Optional. Will create a random name if empty.",
208
- )
209
- button = gr.Button("Merge", variant="primary")
210
- logs = LogsView(label="Terminal output")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  gr.Examples(
212
  examples,
213
  fn=lambda s: (s,),
@@ -218,11 +331,9 @@ with gr.Blocks() as demo:
218
  )
219
  gr.Markdown(MARKDOWN_ARTICLE)
220
 
221
- button.click(fn=merge, inputs=[config, token, repo_name], outputs=[logs])
222
-
223
 
224
  # Run garbage collection every hour to keep the community org clean.
225
- # Empty models might exists if the merge fails abruptly (e.g. if user leaves the Space).
226
  def _garbage_collect_every_hour():
227
  while True:
228
  try:
 
11
  import huggingface_hub
12
  import torch
13
  import yaml
14
+ import bitsandbytes
15
  from gradio_logsview.logsview import Log, LogsView, LogsViewRunner
16
  from mergekit.config import MergeConfiguration
17
 
 
44
  # )
45
 
46
# Arguments shared by every mergekit entry point (config file, output folder,
# common flags). The executable name is deliberately NOT part of this string:
# `merge()` prepends the user-selected program when it builds the final
# command (`f"{program} {cli} ..."`), so keeping "mergekit-yaml" here would
# duplicate it and shift every positional argument by one.
cli = "config.yaml merge --copy-tokenizer" + (
    " --cuda --low-cpu-memory --allow-crimes" if has_gpu else " --allow-crimes --lazy-unpickle"
)
49
 
50
  MARKDOWN_DESCRIPTION = """
 
107
  examples = [[str(f)] for f in pathlib.Path("examples").glob("*.yaml")]
108
 
109
  # Do not set community token as `HF_TOKEN` to avoid accidentally using it in merge scripts.
110
+ # `COMMUNITY_HF_TOKEN` is used to upload models to the community organization (https://huggingface.co/djuna-test-lab)
111
  # when the user does not provide a token.
112
  COMMUNITY_HF_TOKEN = os.getenv("COMMUNITY_HF_TOKEN")
113
 
114
 
115
+ def merge(program: str, yaml_config: str, out_shard_size: str, hf_token: str, repo_name: str) -> Iterable[List[Log]]:
116
  runner = LogsViewRunner()
117
 
118
  if not yaml_config:
119
  yield runner.log("Empty yaml, pick an example below", level="ERROR")
120
  return
121
+ # TODO: validate moe config and mega config?
122
+ if program not in ("mergekit-moe", "mergekit-mega"):
123
+ try:
124
+ merge_config = MergeConfiguration.model_validate(yaml.safe_load(yaml_config))
125
+ except Exception as e:
126
+ yield runner.log(f"Invalid yaml {e}", level="ERROR")
127
+ return
128
 
129
  is_community_model = False
130
  if not hf_token:
131
+ if "/" in repo_name and not repo_name.startswith("djuna-test-lab/"):
132
  yield runner.log(
133
  f"Cannot upload merge model to namespace {repo_name.split('/')[0]}: you must provide a valid token.",
134
  level="ERROR",
135
  )
136
  return
137
  yield runner.log(
138
+ "No HF token provided. Your merged model will be uploaded to the https://huggingface.co/djuna-test-lab organization."
139
  )
140
  is_community_model = True
141
  if not COMMUNITY_HF_TOKEN:
 
159
  repo_name += "-" + "".join(random.choices(string.ascii_lowercase, k=7))
160
  repo_name = repo_name.replace("/", "-").strip("-")
161
 
162
+ if is_community_model and not repo_name.startswith("djuna-test-lab/"):
163
+ repo_name = f"djuna-test-lab/{repo_name}"
164
 
165
  try:
166
  yield runner.log(f"Creating repo {repo_name}")
 
173
  # Set tmp HF_HOME to avoid filling up disk Space
174
  tmp_env = os.environ.copy() # taken from https://stackoverflow.com/a/4453495
175
  tmp_env["HF_HOME"] = f"{tmpdirname}/.cache"
176
+ full_cli = f"{program} {cli} --lora-merge-cache {tmpdirname}/.lora_cache --out-shard-size {out_shard_size}"
177
  yield from runner.run_command(full_cli.split(), cwd=merged_path, env=tmp_env)
178
 
179
  if runner.exit_code != 0:
 
190
  yield runner.log(f"Model successfully uploaded to HF: {repo_url.repo_id}")
191
 
192
 
193
def extract(finetuned_model: str, base_model: str, rank: int, hf_token: str, repo_name: str) -> Iterable[List[Log]]:
    """Extract a LoRA adapter from a fine-tuned model and upload it to the Hub.

    Runs ``mergekit-extract-lora`` inside a temporary directory, streaming the
    command output as log entries, then uploads the produced ``lora`` folder
    to the target repository.

    Args:
        finetuned_model: Identifier of the fine-tuned model (required).
        base_model: Identifier of the base model (required).
        rank: LoRA rank forwarded to the tool as ``--rank``.
        hf_token: Hub write token. When empty, the community token is used and
            the repository is created under the community organization.
        repo_name: Target repository name. When empty, a random ``lora-xxxxxxx``
            name is generated.

    Yields:
        Log entries produced by the ``LogsViewRunner``.

    Raises:
        gr.Error: If no token was provided and the Space owner did not
            configure ``COMMUNITY_HF_TOKEN``.
    """
    community_org = "djuna-test-lab"
    runner = LogsViewRunner()

    # Normalise text inputs so that None or stray whitespace behaves like "empty".
    finetuned_model = (finetuned_model or "").strip()
    base_model = (base_model or "").strip()
    hf_token = (hf_token or "").strip()
    repo_name = (repo_name or "").strip()

    if not finetuned_model or not base_model:
        # Stop here: without both models there is nothing to extract, and
        # continuing would create a repo and run the tool with blank arguments.
        yield runner.log("All field should be filled", level="ERROR")
        return

    is_community_model = False
    if not hf_token:
        if "/" in repo_name and not repo_name.startswith(f"{community_org}/"):
            yield runner.log(
                f"Cannot upload merge model to namespace {repo_name.split('/')[0]}: you must provide a valid token.",
                level="ERROR",
            )
            return
        yield runner.log(
            f"No HF token provided. Your lora will be uploaded to the https://huggingface.co/{community_org} organization."
        )
        is_community_model = True
        if not COMMUNITY_HF_TOKEN:
            raise gr.Error("Cannot upload to community org: community token not set by Space owner.")
        hf_token = COMMUNITY_HF_TOKEN

    api = huggingface_hub.HfApi(token=hf_token)

    with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
        tmpdir = pathlib.Path(tmpdirname)
        merged_path = tmpdir / "merged"
        merged_path.mkdir(parents=True, exist_ok=True)

        if not repo_name:
            yield runner.log("No repo name provided. Generating a random one.")
            # Make repo_name "unique" (no need to be extra careful on uniqueness)
            suffix = "".join(random.choices(string.ascii_lowercase, k=7))
            repo_name = f"lora-{suffix}".replace("/", "-").strip("-")

        if is_community_model and not repo_name.startswith(f"{community_org}/"):
            repo_name = f"{community_org}/{repo_name}"

        try:
            yield runner.log(f"Creating repo {repo_name}")
            repo_url = api.create_repo(repo_name, exist_ok=True)
            yield runner.log(f"Repo created: {repo_url}")
        except Exception as e:
            yield runner.log(f"Error creating repo {e}", level="ERROR")
            return

        # Set tmp HF_HOME to avoid filling up disk Space
        tmp_env = os.environ.copy()  # taken from https://stackoverflow.com/a/4453495
        tmp_env["HF_HOME"] = f"{tmpdirname}/.cache"

        # Build argv as a list instead of formatting a string and splitting it:
        # each user-supplied value stays a single argument even if it contains
        # whitespace, so it cannot smuggle extra options into the command.
        command = [
            "mergekit-extract-lora",
            finetuned_model,
            base_model,
            "lora",
            f"--rank={rank}",
        ]
        yield from runner.run_command(command, cwd=merged_path, env=tmp_env)

        if runner.exit_code != 0:
            yield runner.log("Lora extraction failed. Deleting repo as no lora is uploaded.", level="ERROR")
            # NOTE(review): the repo was created with exist_ok=True, so this may
            # also delete a repo that already existed under that name — confirm
            # this is the intended behaviour.
            api.delete_repo(repo_url.repo_id)
            return

        yield runner.log("Lora extracted successfully. Uploading to HF.")
        yield from runner.run_python(
            api.upload_folder,
            repo_id=repo_url.repo_id,
            folder_path=merged_path / "lora",
        )
        yield runner.log(f"Lora successfully uploaded to HF: {repo_url.repo_id}")
257
+
258
+
259
  with gr.Blocks() as demo:
260
  gr.Markdown(MARKDOWN_DESCRIPTION)
261
 
262
+ with gr.Tabs():
263
# "Merge Model" tab: YAML config editor plus the options forwarded to `merge`.
with gr.TabItem("Merge Model"):
    with gr.Row():
        filename = gr.Textbox(visible=False, label="filename")
        config = gr.Code(language="yaml", lines=10, label="config.yaml")
        with gr.Column():
            # Explicit defaults: without `value`, the dropdowns can submit None,
            # which `merge` would interpolate into the command line verbatim.
            program = gr.Dropdown(
                ["mergekit-yaml", "mergekit-mega", "mergekit-moe"],
                label="Mergekit Command",
                info="Choose CLI",
                value="mergekit-yaml",
            )
            # CPU runs previously used 1B shards; GPU runs passed no flag.
            # NOTE(review): 5B is assumed to match mergekit's own default
            # shard size — confirm against the installed mergekit version.
            out_shard_size = gr.Dropdown(
                ["500M", "1B", "2B", "3B", "4B", "5B"],
                label="Output shard size",
                value="5B" if has_gpu else "1B",
            )
            token = gr.Textbox(
                lines=1,
                label="HF Write Token",
                info="https://hf.co/settings/token",
                type="password",
                placeholder="Optional. Will upload merged model to djuna-test-lab if empty.",
            )
            repo_name = gr.Textbox(
                lines=1,
                label="Repo name",
                placeholder="Optional. Will create a random name if empty.",
            )
    button = gr.Button("Merge", variant="primary")
    logs = LogsView(label="Terminal output")
    # Input order must match the positional parameters of `merge`.
    button.click(fn=merge, inputs=[program, config, out_shard_size, token, repo_name], outputs=[logs])
291
+
292
# "LORA Extraction" tab: model identifiers and rank on the left, upload
# settings on the right; the button streams `extract` output into the log view.
with gr.TabItem("LORA Extraction"):
    # The layout class is `gr.Row` (capitalised); `gr.row` does not exist and
    # would raise AttributeError while the UI is being built.
    with gr.Row():
        with gr.Column():
            finetuned_model = gr.Textbox(
                lines=1,
                label="Finetuned Model",
            )
            base_model = gr.Textbox(
                lines=1,
                label="Base Model",
            )
            rank = gr.Dropdown(
                [32, 64, 128],
                label="Rank level",
                value=32,
            )
        with gr.Column():
            token = gr.Textbox(
                lines=1,
                label="HF Write Token",
                info="https://hf.co/settings/token",
                type="password",
                placeholder="Optional. Will upload the extracted LoRA to djuna-test-lab if empty.",
            )
            repo_name = gr.Textbox(
                lines=1,
                label="Repo name",
                placeholder="Optional. Will create a random name if empty.",
            )
    button = gr.Button("Extract LORA", variant="primary")
    logs = LogsView(label="Terminal output")
    # Input order must match the positional parameters of `extract`.
    button.click(fn=extract, inputs=[finetuned_model, base_model, rank, token, repo_name], outputs=[logs])
324
  gr.Examples(
325
  examples,
326
  fn=lambda s: (s,),
 
331
  )
332
  gr.Markdown(MARKDOWN_ARTICLE)
333
 
 
 
334
 
335
  # Run garbage collection every hour to keep the community org clean.
336
+ # Empty models might exist if the merge fails abruptly (e.g. if user leaves the Space).
337
  def _garbage_collect_every_hour():
338
  while True:
339
  try: