mrcuddle commited on
Commit
dbb2714
·
verified ·
1 Parent(s): 1efbb9e

Update hf_merge.py

Browse files
Files changed (1) hide show
  1. hf_merge.py +130 -83
hf_merge.py CHANGED
@@ -1,84 +1,131 @@
1
- import gradio as gr
2
- import subprocess
3
  import os
4
- import logging
5
- from pathlib import Path
6
- import spaces
7
-
8
- @spaces.GPU
9
- def merge_and_upload(base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token, commit_message):
10
- # Define a fixed output path
11
- outpath = Path('/tmp/output')
12
-
13
- # Construct the command to run hf_merge.py
14
- command = [
15
- "python3", "hf_merge.py",
16
- base_model,
17
- model_to_merge,
18
- "-p", str(weight_drop_prob),
19
- "-lambda", str(scaling_factor),
20
- "--token", token,
21
- "--repo", repo_name,
22
- "--commit-message", commit_message,
23
- "-U"
24
- ]
25
-
26
- # Set up logging
27
- logging.basicConfig(level=logging.INFO)
28
- log_output = ""
29
-
30
- # Run the command and capture the output
31
- result = subprocess.run(command, capture_output=True, text=True)
32
-
33
- # Log the output
34
- log_output += result.stdout + "\n"
35
- log_output += result.stderr + "\n"
36
- logging.info(result.stdout)
37
- logging.error(result.stderr)
38
-
39
- # Check if the merge was successful
40
- if result.returncode != 0:
41
- return None, f"Error in merging models: {result.stderr}", log_output
42
-
43
- # Assuming the script handles the upload and returns the repo URL
44
- repo_url = f"https://huggingface.co/{repo_name}"
45
- return repo_url, "Model merged and uploaded successfully!", log_output
46
-
47
- # Define the Gradio interface
48
- with gr.Blocks(theme="Ytheme/Minecraft", fill_width=True, delete_cache=(60, 3600)) as demo:
49
- gr.Markdown("# SuperMario Safetensors Merger")
50
- gr.Markdown("Combine any two models using a Super Mario merge(DARE)")
51
- gr.Markdown("Based on: https://github.com/martyn/safetensors-merge-supermario")
52
- gr.Markdown("Works with:")
53
- gr.Markdown("* Stable Diffusion (1.5, XL/XL Turbo)")
54
- gr.Markdown("* LLMs (Mistral, Llama, etc)")
55
- gr.Markdown("* LoRas (must be same size)")
56
- gr.Markdown("* Any two homologous models")
57
-
58
- with gr.Column():
59
- with gr.Row():
60
- token = gr.Textbox(label="Your HF write token", placeholder="hf_...", value="", max_lines=1)
61
- with gr.Row():
62
- base_model = gr.Textbox(label="Base Model", placeholder=".safetensors")
63
- with gr.Row():
64
- model_to_merge = gr.Textbox(label="Merge Model", placeholder=".bin/.safetensors")
65
- with gr.Row():
66
- repo_name = gr.Textbox(label="New Model", placeholder="SDXL-", info="If empty, auto-complete", value="", max_lines=1)
67
- with gr.Row():
68
- scaling_factor = gr.Slider(minimum=0, maximum=10, value=3.0, label="Scaling Factor")
69
- with gr.Row():
70
- weight_drop_prob = gr.Slider(minimum=0, maximum=1, value=0.3, label="Weight Drop Probability")
71
- with gr.Row():
72
- commit_message = gr.Textbox(label="Commit Message", value="Upload merged model", max_lines=1)
73
-
74
- progress = gr.Progress()
75
- repo_url = gr.Markdown(label="Repository URL")
76
- output = gr.Textbox(label="Output")
77
-
78
- gr.Button("Merge").click(
79
- merge_and_upload,
80
- inputs=[base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token, commit_message],
81
- outputs=[repo_url, output]
82
- )
83
-
84
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import shutil
3
+ import argparse
4
+ import requests
5
+ from tqdm import tqdm
6
+ from huggingface_hub import HfApi, Repository, hf_hub_download, upload_folder
7
+ from merge import merge_folder, map_tensors_to_files, copy_nontensor_files, save_tensor_map
8
+
9
+ class RepositoryManager:
10
+ base_model_path = os.path.join(os.getcwd(), "base_model")
11
+
12
+ def __init__(self, repo_id=None, token=None):
13
+ self.repo_id = repo_id
14
+ self.token = token
15
+ self.api = HfApi(token=token) if token else HfApi()
16
+
17
+ def download_repo(self, repo_name, path):
18
+ if os.path.isdir(repo_name):
19
+ if not os.path.exists(path):
20
+ os.makedirs(path)
21
+ shutil.copytree(repo_name, path, dirs_exist_ok=True)
22
+ else:
23
+ if not os.path.exists(path):
24
+ os.makedirs(path)
25
+
26
+ repo_files = self.api.list_repo_files(repo_name)
27
+
28
+ for file_path in tqdm(repo_files, desc=f"Downloading {repo_name}"):
29
+ file_url = f"https://huggingface.co/{repo_name}/resolve/main/{file_path}"
30
+ hf_hub_download(repo_id=repo_name, filename=file_path, cache_dir=path, local_dir=path)
31
+
32
+ def delete_repo(self, path):
33
+ shutil.rmtree(path, ignore_errors=True)
34
+
35
+ class ModelMerger:
36
+
37
+ def __init__(self, repo_id=None, token=None):
38
+ self.repo_id = repo_id
39
+ self.token = token
40
+ self.api = HfApi(token=token) if token else HfApi()
41
+ self.tensor_map = None
42
+
43
+ def prepare_base_model(self, base_model_name, base_model_path):
44
+ repo_manager = RepositoryManager(self.repo_id, self.token)
45
+ repo_manager.download_repo(base_model_name, base_model_path)
46
+ self.tensor_map = map_tensors_to_files(base_model_path)
47
+
48
+ def merge_repo(self, repo_name, repo_path, p, lambda_val):
49
+ repo_manager = RepositoryManager(self.repo_id, self.token)
50
+ repo_manager.delete_repo(repo_path)
51
+ repo_manager.download_repo(repo_name, repo_path)
52
+
53
+ try:
54
+ self.tensor_map = merge_folder(self.tensor_map, repo_path, p, lambda_val)
55
+ print(f"Merged {repo_name}")
56
+ except Exception as e:
57
+ print(f"Error merging {repo_name}: {e}")
58
+
59
+ def finalize_merge(self, output_dir):
60
+ base_model_path = os.path.join(os.getcwd(), "base_model")
61
+ copy_nontensor_files(base_model_path, output_dir)
62
+ save_tensor_map(self.tensor_map, output_dir)
63
+
64
+ def upload_model(self, output_dir, repo_name, commit_message):
65
+ repo = Repository(repo_id=self.repo_id, token=self.token)
66
+ repo.create_branch("main", "main")
67
+ repo.upload_folder(output_dir, repo_path=repo_name, commit_message=commit_message)
68
+ print(f"Model uploaded to {repo_name}")
69
+
70
+ def get_max_vocab_size(repo_list):
71
+ max_vocab_size = 0
72
+ repo_with_max_vocab = None
73
+ base_url = "https://huggingface.co/{}/raw/main/config.json"
74
+
75
+ for repo_name, _, _ in repo_list:
76
+ url = base_url.format(repo_name)
77
+ try:
78
+ response = requests.get(url)
79
+ config = response.json()
80
+ vocab_size = config.get('vocab_size', 0)
81
+ if vocab_size > max_vocab_size:
82
+ max_vocab_size = vocab_size
83
+ repo_with_max_vocab = repo_name
84
+ except requests.RequestException as e:
85
+ print(f"Error fetching vocab size from {repo_name}: {e}")
86
+
87
+ return max_vocab_size, repo_with_max_vocab
88
+
89
+ def download_json_files(repo_name, file_paths, output_dir):
90
+ base_url = f"https://huggingface.co/{repo_name}/raw/main/"
91
+ for file_path in file_paths:
92
+ url = base_url + file_path
93
+ response = requests.get(url)
94
+ if response.status_code == 200:
95
+ with open(os.path.join(output_dir, os.path.basename(file_path)), 'wb') as file:
96
+ file.write(response.content)
97
+ else:
98
+ print(f"Failed to download {file_path} from {repo_name}")
99
+
100
+ def main():
101
+ parser = argparse.ArgumentParser(description="Merge and upload HuggingFace models")
102
+ parser.add_argument('base_model', type=str, help='Base model safetensors file')
103
+ parser.add_argument('model_to_merge', type=str, help='Model to merge (.safetensors or .bin)')
104
+ parser.add_argument('-p', type=float, default=0.5, help='Dropout probability')
105
+ parser.add_argument('-lambda', '--lambda_value', type=float, default=3.0, help='Scaling factor (optional)')
106
+ parser.add_argument('--token', type=str, help='HuggingFace token (required for uploading)')
107
+ parser.add_argument('--repo', type=str, help='HuggingFace repo to upload to (required for uploading)')
108
+ parser.add_argument('--commit-message', type=str, default='Upload merged model', help='Commit message for model upload')
109
+ parser.add_argument('-U', '--upload', action='store_true', help='Upload the merged model to HuggingFace Hub')
110
+ args = parser.parse_args()
111
+
112
+ base_model_path = os.path.join(os.getcwd(), "base_model")
113
+ model_to_merge_path = os.path.join(os.getcwd(), "model_to_merge")
114
+ output_dir = os.path.join(os.getcwd(), "output")
115
+
116
+ model_merger = ModelMerger(args.repo, args.token)
117
+ model_merger.prepare_base_model(args.base_model, base_model_path)
118
+
119
+ model_merger.merge_repo(args.model_to_merge, model_to_merge_path, args.p, args.lambda_value)
120
+
121
+ model_merger.finalize_merge(output_dir)
122
+
123
+ # Upload model only if --upload parameter is provided
124
+ if args.upload:
125
+ if not args.token or not args.repo:
126
+ print("Error: HuggingFace token and repo name are required for uploading.")
127
+ else:
128
+ model_merger.upload_model(output_dir, args.repo, args.commit_message)
129
+
130
+ if __name__ == "__main__":
131
+ main()