Spaces:
Running
Running
Update hf_merge.py
Browse files- hf_merge.py +130 -83
hf_merge.py
CHANGED
@@ -1,84 +1,131 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
import subprocess
|
3 |
import os
|
4 |
-
import
|
5 |
-
|
6 |
-
import
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
+
import shutil
|
3 |
+
import argparse
|
4 |
+
import requests
|
5 |
+
from tqdm import tqdm
|
6 |
+
from huggingface_hub import HfApi, Repository, hf_hub_download, upload_folder
|
7 |
+
from merge import merge_folder, map_tensors_to_files, copy_nontensor_files, save_tensor_map
|
8 |
+
|
9 |
+
class RepositoryManager:
|
10 |
+
base_model_path = os.path.join(os.getcwd(), "base_model")
|
11 |
+
|
12 |
+
def __init__(self, repo_id=None, token=None):
|
13 |
+
self.repo_id = repo_id
|
14 |
+
self.token = token
|
15 |
+
self.api = HfApi(token=token) if token else HfApi()
|
16 |
+
|
17 |
+
def download_repo(self, repo_name, path):
|
18 |
+
if os.path.isdir(repo_name):
|
19 |
+
if not os.path.exists(path):
|
20 |
+
os.makedirs(path)
|
21 |
+
shutil.copytree(repo_name, path, dirs_exist_ok=True)
|
22 |
+
else:
|
23 |
+
if not os.path.exists(path):
|
24 |
+
os.makedirs(path)
|
25 |
+
|
26 |
+
repo_files = self.api.list_repo_files(repo_name)
|
27 |
+
|
28 |
+
for file_path in tqdm(repo_files, desc=f"Downloading {repo_name}"):
|
29 |
+
file_url = f"https://huggingface.co/{repo_name}/resolve/main/{file_path}"
|
30 |
+
hf_hub_download(repo_id=repo_name, filename=file_path, cache_dir=path, local_dir=path)
|
31 |
+
|
32 |
+
def delete_repo(self, path):
|
33 |
+
shutil.rmtree(path, ignore_errors=True)
|
34 |
+
|
35 |
+
class ModelMerger:
|
36 |
+
|
37 |
+
def __init__(self, repo_id=None, token=None):
|
38 |
+
self.repo_id = repo_id
|
39 |
+
self.token = token
|
40 |
+
self.api = HfApi(token=token) if token else HfApi()
|
41 |
+
self.tensor_map = None
|
42 |
+
|
43 |
+
def prepare_base_model(self, base_model_name, base_model_path):
|
44 |
+
repo_manager = RepositoryManager(self.repo_id, self.token)
|
45 |
+
repo_manager.download_repo(base_model_name, base_model_path)
|
46 |
+
self.tensor_map = map_tensors_to_files(base_model_path)
|
47 |
+
|
48 |
+
def merge_repo(self, repo_name, repo_path, p, lambda_val):
|
49 |
+
repo_manager = RepositoryManager(self.repo_id, self.token)
|
50 |
+
repo_manager.delete_repo(repo_path)
|
51 |
+
repo_manager.download_repo(repo_name, repo_path)
|
52 |
+
|
53 |
+
try:
|
54 |
+
self.tensor_map = merge_folder(self.tensor_map, repo_path, p, lambda_val)
|
55 |
+
print(f"Merged {repo_name}")
|
56 |
+
except Exception as e:
|
57 |
+
print(f"Error merging {repo_name}: {e}")
|
58 |
+
|
59 |
+
def finalize_merge(self, output_dir):
|
60 |
+
base_model_path = os.path.join(os.getcwd(), "base_model")
|
61 |
+
copy_nontensor_files(base_model_path, output_dir)
|
62 |
+
save_tensor_map(self.tensor_map, output_dir)
|
63 |
+
|
64 |
+
def upload_model(self, output_dir, repo_name, commit_message):
|
65 |
+
repo = Repository(repo_id=self.repo_id, token=self.token)
|
66 |
+
repo.create_branch("main", "main")
|
67 |
+
repo.upload_folder(output_dir, repo_path=repo_name, commit_message=commit_message)
|
68 |
+
print(f"Model uploaded to {repo_name}")
|
69 |
+
|
70 |
+
def get_max_vocab_size(repo_list):
|
71 |
+
max_vocab_size = 0
|
72 |
+
repo_with_max_vocab = None
|
73 |
+
base_url = "https://huggingface.co/{}/raw/main/config.json"
|
74 |
+
|
75 |
+
for repo_name, _, _ in repo_list:
|
76 |
+
url = base_url.format(repo_name)
|
77 |
+
try:
|
78 |
+
response = requests.get(url)
|
79 |
+
config = response.json()
|
80 |
+
vocab_size = config.get('vocab_size', 0)
|
81 |
+
if vocab_size > max_vocab_size:
|
82 |
+
max_vocab_size = vocab_size
|
83 |
+
repo_with_max_vocab = repo_name
|
84 |
+
except requests.RequestException as e:
|
85 |
+
print(f"Error fetching vocab size from {repo_name}: {e}")
|
86 |
+
|
87 |
+
return max_vocab_size, repo_with_max_vocab
|
88 |
+
|
89 |
+
def download_json_files(repo_name, file_paths, output_dir):
|
90 |
+
base_url = f"https://huggingface.co/{repo_name}/raw/main/"
|
91 |
+
for file_path in file_paths:
|
92 |
+
url = base_url + file_path
|
93 |
+
response = requests.get(url)
|
94 |
+
if response.status_code == 200:
|
95 |
+
with open(os.path.join(output_dir, os.path.basename(file_path)), 'wb') as file:
|
96 |
+
file.write(response.content)
|
97 |
+
else:
|
98 |
+
print(f"Failed to download {file_path} from {repo_name}")
|
99 |
+
|
100 |
+
def main():
|
101 |
+
parser = argparse.ArgumentParser(description="Merge and upload HuggingFace models")
|
102 |
+
parser.add_argument('base_model', type=str, help='Base model safetensors file')
|
103 |
+
parser.add_argument('model_to_merge', type=str, help='Model to merge (.safetensors or .bin)')
|
104 |
+
parser.add_argument('-p', type=float, default=0.5, help='Dropout probability')
|
105 |
+
parser.add_argument('-lambda', '--lambda_value', type=float, default=3.0, help='Scaling factor (optional)')
|
106 |
+
parser.add_argument('--token', type=str, help='HuggingFace token (required for uploading)')
|
107 |
+
parser.add_argument('--repo', type=str, help='HuggingFace repo to upload to (required for uploading)')
|
108 |
+
parser.add_argument('--commit-message', type=str, default='Upload merged model', help='Commit message for model upload')
|
109 |
+
parser.add_argument('-U', '--upload', action='store_true', help='Upload the merged model to HuggingFace Hub')
|
110 |
+
args = parser.parse_args()
|
111 |
+
|
112 |
+
base_model_path = os.path.join(os.getcwd(), "base_model")
|
113 |
+
model_to_merge_path = os.path.join(os.getcwd(), "model_to_merge")
|
114 |
+
output_dir = os.path.join(os.getcwd(), "output")
|
115 |
+
|
116 |
+
model_merger = ModelMerger(args.repo, args.token)
|
117 |
+
model_merger.prepare_base_model(args.base_model, base_model_path)
|
118 |
+
|
119 |
+
model_merger.merge_repo(args.model_to_merge, model_to_merge_path, args.p, args.lambda_value)
|
120 |
+
|
121 |
+
model_merger.finalize_merge(output_dir)
|
122 |
+
|
123 |
+
# Upload model only if --upload parameter is provided
|
124 |
+
if args.upload:
|
125 |
+
if not args.token or not args.repo:
|
126 |
+
print("Error: HuggingFace token and repo name are required for uploading.")
|
127 |
+
else:
|
128 |
+
model_merger.upload_model(output_dir, args.repo, args.commit_message)
|
129 |
+
|
130 |
+
if __name__ == "__main__":
|
131 |
+
main()
|