Wauplin HF staff commited on
Commit
64e99f5
1 Parent(s): c30542b

run in thread

Browse files
Files changed (1) hide show
  1. app.py +44 -13
app.py CHANGED
@@ -1,22 +1,35 @@
 
1
  import pathlib
2
  import tempfile
3
  from typing import Generator
4
 
5
  import gradio as gr
 
6
  import torch
7
  import yaml
8
- from gradio_logsview import LogsView
 
 
 
9
 
10
  has_gpu = torch.cuda.is_available()
11
 
12
- cli = "mergekit-yaml config.yaml merge --copy-tokenizer" + (
13
- " --cuda --low-cpu-memory"
 
 
 
 
 
 
14
  if has_gpu
15
- else " --allow-crimes --out-shard-size 1B --lazy-unpickle"
 
 
 
 
 
16
  )
17
-
18
- print(cli)
19
-
20
  ## This Space is heavily inspired by LazyMergeKit by Maxime Labonne
21
  ## https://colab.research.google.com/drive/1obulZ1ROXHjYLn6PPZJwRR6GzgQogxxb
22
 
@@ -71,19 +84,31 @@ def merge(
71
  if not yaml_config:
72
  raise gr.Error("Empty yaml, pick an example below")
73
  try:
74
- _ = yaml.safe_load(yaml_config)
75
  except Exception as e:
76
  raise gr.Error(f"Invalid yaml {e}")
77
 
78
  with tempfile.TemporaryDirectory() as tmpdirname:
79
  tmpdir = pathlib.Path(tmpdirname)
80
-
81
- config_path = tmpdir / "config.yaml"
 
82
  config_path.write_text(yaml_config)
83
 
84
- yield from LogsView.run_process(cli.split(), cwd=tmpdir)
 
 
 
 
 
 
 
85
 
86
  ## TODO(implement upload at the end of the merge, and display the repo URL)
 
 
 
 
87
 
88
 
89
  with gr.Blocks() as demo:
@@ -111,8 +136,14 @@ with gr.Blocks() as demo:
111
  )
112
  button = gr.Button("Merge", variant="primary")
113
  logs = LogsView()
114
- gr.Examples(examples, fn=lambda s: (s,), run_on_click=True,
115
- label="Examples", inputs=[filename], outputs=[config])
 
 
 
 
 
 
116
  gr.Markdown(MARKDOWN_ARTICLE)
117
 
118
  button.click(fn=merge, inputs=[filename, config, token, repo_name], outputs=[logs])
 
1
+ import logging
2
  import pathlib
3
  import tempfile
4
  from typing import Generator
5
 
6
  import gradio as gr
7
+ import huggingface_hub
8
  import torch
9
  import yaml
10
+ from gradio_logsview.logsview import Log, LogsView
11
+ from mergekit.common import parse_kmb
12
+ from mergekit.merge import run_merge
13
+ from mergekit.options import MergeOptions
14
 
15
  has_gpu = torch.cuda.is_available()
16
 
17
+ # Inspired by https://github.com/arcee-ai/mergekit/blob/main/mergekit/scripts/run_yaml.py
18
+ merge_options = (
19
+ MergeOptions(
20
+ copy_tokenizer=True,
21
+ cuda=True,
22
+ low_cpu_memory=True,
23
+ write_model_card=True,
24
+ )
25
  if has_gpu
26
+ else MergeOptions(
27
+ allow_crimes=True,
28
+ out_shard_size=parse_kmb("1B"),
29
+ lazy_unpickle=True,
30
+ write_model_card=True,
31
+ )
32
  )
 
 
 
33
  ## This Space is heavily inspired by LazyMergeKit by Maxime Labonne
34
  ## https://colab.research.google.com/drive/1obulZ1ROXHjYLn6PPZJwRR6GzgQogxxb
35
 
 
84
  if not yaml_config:
85
  raise gr.Error("Empty yaml, pick an example below")
86
  try:
87
+ merge_config = yaml.safe_load(yaml_config)
88
  except Exception as e:
89
  raise gr.Error(f"Invalid yaml {e}")
90
 
91
  with tempfile.TemporaryDirectory() as tmpdirname:
92
  tmpdir = pathlib.Path(tmpdirname)
93
+ merged_path = tmpdir / "merged"
94
+ merged_path.mkdir(parents=True, exist_ok=True)
95
+ config_path = merged_path / "config.yaml"
96
  config_path.write_text(yaml_config)
97
 
98
+ yield from LogsView.run_thread(
99
+ run_merge,
100
+ log_level=logging.INFO,
101
+ merge_config=merge_config,
102
+ out_path=merged_path,
103
+ options=merge_options,
104
+ config_source=config_path,
105
+ )
106
 
107
  ## TODO(implement upload at the end of the merge, and display the repo URL)
108
+ api = huggingface_hub.HfApi(token=hf_token)
109
+ repo_url = api.create_repo(repo_name, exist_ok=True)
110
+ api.upload_folder(repo_id=repo_url.repo_id, folder_path=merged_path)
111
+ print(repo_url)
112
 
113
 
114
  with gr.Blocks() as demo:
 
136
  )
137
  button = gr.Button("Merge", variant="primary")
138
  logs = LogsView()
139
+ gr.Examples(
140
+ examples,
141
+ fn=lambda s: (s,),
142
+ run_on_click=True,
143
+ label="Examples",
144
+ inputs=[filename],
145
+ outputs=[config],
146
+ )
147
  gr.Markdown(MARKDOWN_ARTICLE)
148
 
149
  button.click(fn=merge, inputs=[filename, config, token, repo_name], outputs=[logs])