Upload combine.py
Browse files- combine.py +51 -0
combine.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
from safetensors import safe_open
|
4 |
+
from safetensors.torch import save_file
|
5 |
+
|
6 |
+
|
7 |
+
def merge_safetensors(input_dir, output_file, config_file):
|
8 |
+
# Dictionary to store all tensors
|
9 |
+
merged_tensors = {}
|
10 |
+
|
11 |
+
# Load config
|
12 |
+
with open(config_file, 'r') as f:
|
13 |
+
config = json.load(f)
|
14 |
+
|
15 |
+
# Prepare metadata
|
16 |
+
metadata = {
|
17 |
+
"format": "pt",
|
18 |
+
"total_size": "", #str(total_size), # Notice we stringify this!
|
19 |
+
"_diffusers_version": config.get("_diffusers_version", ""),
|
20 |
+
"_class_name": config.get("_class_name", ""),
|
21 |
+
# Add other fields at this level
|
22 |
+
}
|
23 |
+
|
24 |
+
total_size = 0
|
25 |
+
|
26 |
+
# Iterate through all files in the input directory
|
27 |
+
for filename in os.listdir(input_dir):
|
28 |
+
if filename.endswith('.safetensors'):
|
29 |
+
file_path = os.path.join(input_dir, filename)
|
30 |
+
|
31 |
+
# Load tensors and metadata from each file
|
32 |
+
with safe_open(file_path, framework="pt", device="cpu") as f:
|
33 |
+
file_metadata = f.metadata()
|
34 |
+
if file_metadata and "__metadata__" in file_metadata:
|
35 |
+
total_size += int(file_metadata["__metadata__"].get("total_size", 0))
|
36 |
+
|
37 |
+
for key in f.keys():
|
38 |
+
tensor = f.get_tensor(key)
|
39 |
+
merged_tensors[key] = tensor
|
40 |
+
|
41 |
+
# Add total size to metadata
|
42 |
+
metadata["total_size"] = str(total_size)
|
43 |
+
|
44 |
+
# Save the merged tensors to a single file with metadata
|
45 |
+
save_file(merged_tensors, output_file, metadata)
|
46 |
+
|
47 |
+
|
48 |
+
input_directory = './10_1'
|
49 |
+
output_file = './10_1/flux1-merge-S10_D1.safetensors'
|
50 |
+
config_file = './10_1/config.json'
|
51 |
+
merge_safetensors(input_directory, output_file, config_file)
|