John6666 committed on
Commit
7847c2d
1 Parent(s): c26428e

Upload 8 files

Browse files
Files changed (8) hide show
  1. README.md +5 -4
  2. app.py +48 -0
  3. packages.txt +1 -0
  4. requirements.txt +4 -0
  5. sdxl_keys.txt +0 -0
  6. stkey.py +121 -0
  7. stkey_gr.py +104 -0
  8. utils.py +165 -0
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
- title: Safetensors Key Checker
3
- emoji: 💻
4
- colorFrom: green
5
- colorTo: gray
6
  sdk: gradio
7
  sdk_version: 5.0.2
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Safetensors file key checker
3
+ emoji: 🐶
4
+ colorFrom: yellow
5
+ colorTo: blue
6
  sdk: gradio
7
  sdk_version: 5.0.2
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Gradio entry point for the Safetensors key checker Space.
import gradio as gr
from stkey_gr import stkey_gr, KEYS_FILES

# Custom styling for the title, the info blurbs and the result panel.
css = """
.title { font-size: 3em; align-items: center; text-align: center; }
.info { align-items: center; text-align: center; }
.block.result { margin: 1em 0; padding: 1em; box-shadow: 0 0 3px 3px #664422, 0 0 3px 2px #664422 inset; border-radius: 6px; background: #665544; }
.desc [src$='#float'] { float: right; margin: 20px; }
"""

with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", fill_width=True, css=css, delete_cache=(60, 3600)) as demo:
    with gr.Column():
        gr.Markdown("# Safetensors file key checker.", elem_classes="title")
        with gr.Group():
            # One URL per line; each is downloaded and inspected in turn.
            dl_url = gr.Textbox(label="Download URL(s)", placeholder="https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors\n...", value="", lines=2, max_lines=255)
            with gr.Row():
                is_validate = gr.Checkbox(label="Validate safetensors file keys", value=True)
                # Reference key list to diff against (see sdxl_keys.txt).
                rfile = gr.Radio(label="Reference file for validation", choices=KEYS_FILES, value=KEYS_FILES[0])
            with gr.Accordion("Advanced", open=False):
                with gr.Row():
                    with gr.Column():
                        hf_token = gr.Textbox(label="Your HF read token (Optional)", placeholder="hf_...", value="", max_lines=1)
                        gr.Markdown("Your token is available at [hf.co/settings/tokens](https://huggingface.co/settings/tokens).", elem_classes="info")
                    with gr.Column():
                        civitai_key = gr.Textbox(label="Your Civitai Key (Optional)", value="", max_lines=1)
                        gr.Markdown("Your Civitai API key is available at [https://civitai.com/user/account](https://civitai.com/user/account).", elem_classes="info")
            run_button = gr.Button(value="Check", variant="primary")
        # Hidden state: URLs already processed in this session.
        uploaded_urls = gr.CheckboxGroup(visible=False, choices=[], value=[]) # hidden
        urls_md = gr.Markdown("<br><br>", elem_classes="result")
        out_files = gr.Files(label="Output", interactive=False, value=[])
        with gr.Group():
            with gr.Row():
                missing = gr.JSON(value=[], label="Missing keys")
                added = gr.JSON(value=[], label="Added keys")
            with gr.Row():
                keys = gr.JSON(value=[], label="All keys")
                metadata = gr.JSON(value={}, label="Metadata")
        gr.DuplicateButton(value="Duplicate Space")

    # Single action: download each URL, extract keys, optionally validate.
    gr.on(
        triggers=[run_button.click],
        fn=stkey_gr,
        inputs=[dl_url, civitai_key, hf_token, uploaded_urls, out_files, is_validate, rfile],
        outputs=[uploaded_urls, out_files, urls_md, metadata, keys, missing, added],
    )

demo.queue()
demo.launch()
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ git-lfs aria2
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ huggingface-hub
2
+ gdown
3
+ safetensors
4
+ torch
sdxl_keys.txt ADDED
The diff for this file is too large to render. See raw diff
 
stkey.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+ import json
4
+ import re
5
+ import gc
6
+ from safetensors.torch import load_file, save_file
7
+ import torch
8
+
9
+
10
+ SDXL_KEYS_FILE = "sdxl_keys.txt"
11
+
12
+
13
def list_uniq(l):
    """Return the unique elements of *l*, keeping first-occurrence order."""
    unique = set(l)
    # Sorting by each element's first index in the original list preserves
    # the order in which the elements first appeared.
    return sorted(unique, key=l.index)
15
+
16
+
17
def read_safetensors_metadata(path: str):
    """Read the ``__metadata__`` dict from a safetensors file header.

    A safetensors file begins with an 8-byte little-endian header length
    followed by that many bytes of JSON.  Returns ``{}`` when the header
    carries no ``__metadata__`` entry.
    """
    with open(path, 'rb') as fp:
        size = int.from_bytes(fp.read(8), 'little')
        header = json.loads(fp.read(size).decode('utf-8'))
    return header.get('__metadata__', {})
24
+
25
+
26
def keys_from_file(path: str):
    """Read one key per line from the text file at *path*.

    Each line is stripped of surrounding whitespace.  Returns ``[]`` when
    the file cannot be read (the error is printed; best-effort behavior is
    kept to match the rest of this module).

    Fix: the original returned from ``finally``, which silently swallowed
    even non-``Exception`` raises such as ``KeyboardInterrupt``.
    """
    try:
        with open(str(Path(path)), encoding='utf-8', mode='r') as f:
            return [line.strip() for line in f]
    except Exception as e:
        print(e)
        return []
37
+
38
+
39
def validate_keys(keys: list[str], rfile: str=SDXL_KEYS_FILE):
    """Compare *keys* against the reference key list stored in *rfile*.

    Returns ``(missing, added)``: keys present only in the reference file,
    and keys present only in *keys*.  Both lists keep first-occurrence
    order.  On error the lists collected so far are returned (best-effort).

    Fix: the original rebuilt ``set(keys)`` / ``set(rkeys)`` on every loop
    iteration (O(n^2)) and returned from ``finally``, which swallowed
    non-``Exception`` raises.
    """
    missing = []
    added = []
    try:
        rkeys = keys_from_file(rfile)
        # Build the membership sets once, outside the loop.
        keys_set = set(keys)
        rkeys_set = set(rkeys)
        for key in list_uniq(keys + rkeys):
            if key in rkeys_set and key not in keys_set: missing.append(key)
            elif key in keys_set and key not in rkeys_set: added.append(key)
    except Exception as e:
        print(e)
    return missing, added
52
+
53
+
54
def read_safetensors_key(path: str):
    """Return the list of tensor keys stored in a safetensors file.

    Loads the state dict via safetensors and collects its keys.  On
    failure the error is printed and ``[]`` is returned instead of
    raising.  The loaded tensors are released before returning.

    Fix: when ``load_file`` raised, ``state_dict`` was unbound and the
    ``del state_dict`` in ``finally`` raised ``NameError``, masking the
    original error.
    """
    keys = []
    state_dict = None
    try:
        state_dict = load_file(str(Path(path)))
        keys = list(state_dict.keys())
    except Exception as e:
        print(e)
    finally:
        # Drop the (potentially huge) tensor dict before returning.
        if state_dict is not None:
            del state_dict
        torch.cuda.empty_cache()
        gc.collect()
    return keys
67
+
68
+
69
def write_safetensors_key(keys: list[str], path: str, is_validate: bool=True, rpath: str=SDXL_KEYS_FILE):
    """Write *keys* (one per line) to *path*.

    When *is_validate* is true, the keys are also compared against the
    reference list in *rpath* and ``<stem>_missing.txt`` /
    ``<stem>_added.txt`` are written to the working directory.  Returns
    True on success, False on failure or when *keys* is empty.
    """
    if not keys: return False
    try:
        Path(path).write_text("\n".join(keys), encoding='utf-8')
        if is_validate:
            missing, added = validate_keys(keys, rpath)
            stem = Path(path).stem
            Path(stem + "_missing.txt").write_text("\n".join(missing), encoding='utf-8')
            Path(stem + "_added.txt").write_text("\n".join(added), encoding='utf-8')
        return True
    except Exception as e:
        print(e)
        return False
84
+
85
+
86
def stkey(input: str, out_filename: str="", is_validate: bool=True, rfile: str=SDXL_KEYS_FILE):
    """Print (and optionally save) the keys of a safetensors file.

    Reads the tensor keys from *input*; when *out_filename* is given the
    keys (plus validation files) are written to disk as well.  With
    *is_validate*, the keys are diffed against the reference list in
    *rfile* and the missing / added keys are printed too.
    """
    keys = read_safetensors_key(input)
    if not keys: return  # nothing readable: stay silent like the original
    if out_filename:
        write_safetensors_key(keys, out_filename, is_validate, rfile)
    print("Metadata:")
    print(read_safetensors_metadata(input))
    print("\nKeys:")
    print("\n".join(keys))
    if is_validate:
        missing, added = validate_keys(keys, rfile)
        print("\nMissing Keys:")
        print("\n".join(missing))
        print("\nAdded Keys:")
        print("\n".join(added))
100
+
101
+
102
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=str, help="Input safetensors file.")
    parser.add_argument("-s", "--save", action="store_true", default=False, help="Output to text file.")
    parser.add_argument("-o", "--output", default="", type=str, help="Output to specific text file.")
    parser.add_argument("-v", "--val", action="store_false", default=True, help="Disable key validation.")
    parser.add_argument("-r", "--rfile", default=SDXL_KEYS_FILE, type=str, help="Specify reference file to validate keys.")

    args = parser.parse_args()

    # BUG FIX: out_filename was only assigned when --save was given, so
    # running without -s/-o raised NameError.  Default to "" (print only).
    out_filename = Path(args.input).stem + ".txt" if args.save else ""
    if args.output: out_filename = args.output

    stkey(args.input, out_filename, args.val, args.rfile)
116
+
117
+
118
+ # Usage:
119
+ # python stkey.py sd_xl_base_1.0_0.9vae.safetensors
120
+ # python stkey.py sd_xl_base_1.0_0.9vae.safetensors -s
121
+ # python stkey.py sd_xl_base_1.0_0.9vae.safetensors -o key.txt
stkey_gr.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import HfApi, hf_hub_url
3
+ import os
4
+ from pathlib import Path
5
+ import gc
6
+ import re
7
+ import json
8
+ from utils import get_token, set_token, is_repo_exists, get_user_agent, get_download_file, list_uniq
9
+ from stkey import read_safetensors_key, read_safetensors_metadata, validate_keys, write_safetensors_key
10
+
11
+
12
+ TEMP_DIR = "."
13
+ KEYS_FILES = ["sdxl_keys.txt"]
14
+
15
+
16
def parse_urls(s):
    """Extract every http(s) URL found in the text *s* (may be empty)."""
    pattern = re.compile(r"https?://[\w/:%#\$&\?\(\)~\.=\+\-]+")
    try:
        return pattern.findall(s)
    except Exception:
        # Non-string input etc.: degrade to "no URLs".
        return []
23
+
24
+
25
def to_urls(l: list[str]):
    """Render the URL list *l* as one URL per line."""
    separator = "\n"
    return separator.join(l)
27
+
28
+
29
def uniq_urls(s):
    """Parse URLs out of *s* and return them deduplicated, one per line."""
    urls = parse_urls(s)
    return to_urls(list_uniq(urls))
31
+
32
+
33
def upload_safetensors_to_repo(filename, repo_id, repo_type, is_private, progress=gr.Progress(track_tqdm=True)):
    """Upload *filename* to the given HF repo, creating the repo if needed.

    The local file is always deleted afterwards.  Returns the hub URL of
    the uploaded file, or None when the upload failed.
    """
    output_filename = Path(filename).name
    hf_token = get_token()
    api = HfApi(token=hf_token)
    url = None
    try:
        if not is_repo_exists(repo_id, repo_type): api.create_repo(repo_id=repo_id, repo_type=repo_type, token=hf_token, private=is_private)
        progress(0, desc=f"Start uploading... (unknown) to {repo_id}")
        api.upload_file(path_or_fileobj=filename, path_in_repo=output_filename, repo_type=repo_type, revision="main", token=hf_token, repo_id=repo_id)
        progress(1, desc="Uploaded.")
        url = hf_hub_url(repo_id=repo_id, repo_type=repo_type, filename=output_filename)
    except Exception as e:
        print(f"Error: Failed to upload to {repo_id}. {e}")
        gr.Warning(f"Error: Failed to upload to {repo_id}. {e}")
        return None
    finally:
        # BUG FIX: a bare unlink() raised FileNotFoundError when the file
        # was already gone, masking the real error from the try block.
        Path(filename).unlink(missing_ok=True)
    return url
50
+
51
+
52
def download_file(dl_url, civitai_key, progress=gr.Progress(track_tqdm=True)):
    """Download *dl_url* into TEMP_DIR and return the local path ("" on failure)."""
    progress(0, desc=f"Start downloading... {dl_url}")
    return get_download_file(TEMP_DIR, dl_url, civitai_key)
57
+
58
+
59
def get_stkey(filename: str, is_validate: bool=True, rfile: str=KEYS_FILES[0], progress=gr.Progress(track_tqdm=True)):
    """Extract keys and metadata from the safetensors file *filename*.

    Writes the key list (and, when *is_validate*, the missing/added key
    lists) as text files in the working directory and returns
    ``(paths, metadata, keys, missing, added)``.  The input file is always
    deleted afterwards.
    """
    paths = []
    metadata = {}
    keys = []
    missing = []
    added = []
    try:
        progress(0, desc=f"Loading keys... (unknown)")
        keys = read_safetensors_key(filename)
        if len(keys) == 0: raise Exception("No keys found.")
        progress(0.5, desc=f"Checking keys... (unknown)")
        stem = Path(filename).stem
        if write_safetensors_key(keys, stem + ".txt", is_validate, rfile):
            paths.append(stem + ".txt")
            # BUG FIX: *_missing.txt / *_added.txt are only created when
            # validation runs; reporting them otherwise handed gradio
            # nonexistent file paths.
            if is_validate:
                paths.append(stem + "_missing.txt")
                paths.append(stem + "_added.txt")
                missing, added = validate_keys(keys, rfile)
        metadata = read_safetensors_metadata(filename)
    except Exception as e:
        print(f"Error: Failed check (unknown). {e}")
        gr.Warning(f"Error: Failed check (unknown). {e}")
    finally:
        # BUG FIX: don't raise from finally when the file is already gone.
        Path(filename).unlink(missing_ok=True)
    return paths, metadata, keys, missing, added
82
+
83
+
84
def stkey_gr(dl_url: str, civitai_key: str, hf_token: str, urls: list[str], files: list[str],
             is_validate=True, rfile=KEYS_FILES[0], progress=gr.Progress(track_tqdm=True)):
    """Gradio handler: download every URL in *dl_url* and check its keys.

    *urls* and *files* carry the accumulated state of previous runs (the
    hidden CheckboxGroup and the Files component).  Returns gradio updates
    for those components plus the metadata / keys / missing / added of the
    LAST processed file (earlier results are overwritten each iteration).
    """
    if hf_token: set_token(hf_token)
    else: set_token(os.environ.get("HF_TOKEN")) # default huggingface token
    if not civitai_key: civitai_key = os.environ.get("CIVITAI_API_KEY") # default Civitai API key
    dl_urls = parse_urls(dl_url)
    if not urls: urls = []
    if not files: files = []
    metadata = {}
    keys = []
    missing = []
    added = []
    for u in dl_urls:
        file = download_file(u, civitai_key)
        # Skip entries that did not resolve to a real local file
        # (download failures return "", repo ids are not files).
        if not Path(file).exists() or not Path(file).is_file(): continue
        paths, metadata, keys, missing, added = get_stkey(file, is_validate, rfile)
        if len(paths) != 0: files.extend(paths)
    progress(1, desc="Processing...")
    gc.collect()
    return gr.update(value=urls, choices=urls), gr.update(value=files), gr.update(visible=True), metadata, keys, missing, added
104
+
utils.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import HfApi, HfFolder, hf_hub_download
3
+ import os
4
+ from pathlib import Path
5
+ import shutil
6
+ import gc
7
+ import re
8
+ import urllib.parse
9
+
10
+
11
def get_token():
    """Return the Hugging Face token cached by huggingface_hub, or "" on failure."""
    try:
        return HfFolder.get_token()
    except Exception:
        return ""
17
+
18
+
19
def set_token(token):
    """Persist *token* as the Hugging Face token (best-effort).

    Failures are reported but not raised so callers can continue
    unauthenticated.
    """
    try:
        HfFolder.save_token(token)
    except Exception:
        # Fix: the original used an f-string with no placeholders.
        print("Error: Failed to save token.")
24
+
25
+
26
def get_user_agent():
    """Return a desktop-browser User-Agent string for plain HTTP downloads."""
    return (
        'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) '
        'Gecko/20100101 Firefox/127.0'
    )
28
+
29
+
30
def is_repo_exists(repo_id: str, repo_type: str="model"):
    """Return True when the HF repo exists — or when the check itself fails."""
    hf_token = get_token()
    api = HfApi(token=hf_token)
    try:
        return bool(api.repo_exists(repo_id=repo_id, repo_type=repo_type, token=hf_token))
    except Exception as e:
        print(f"Error: Failed to connect {repo_id} ({repo_type}). {e}")
        # Claim existence on error so callers don't try to create a repo
        # over an unknown state.
        return True # for safe
39
+
40
+
41
+ MODEL_TYPE_CLASS = {
42
+ "diffusers:StableDiffusionPipeline": "SD 1.5",
43
+ "diffusers:StableDiffusionXLPipeline": "SDXL",
44
+ "diffusers:FluxPipeline": "FLUX",
45
+ }
46
+
47
+
48
def get_model_type(repo_id: str):
    """Classify an HF repo as "LoRA", "SD 1.5", "SDXL", "FLUX" or "None".

    Falls back to "SDXL" when the repo cannot be inspected or no known
    pipeline tag is present.
    """
    default = "SDXL"
    hf_token = get_token()
    api = HfApi(token=hf_token)
    lora_filename = "pytorch_lora_weights.safetensors"
    diffusers_filename = "model_index.json"
    try:
        # A bare LoRA weight file marks the repo as a LoRA.
        if api.file_exists(repo_id=repo_id, filename=lora_filename, token=hf_token): return "LoRA"
        # Without a model_index.json this is not a diffusers repo.
        if not api.file_exists(repo_id=repo_id, filename=diffusers_filename, token=hf_token): return "None"
        info = api.model_info(repo_id=repo_id, token=hf_token)
        for tag in info.tags:
            if tag in MODEL_TYPE_CLASS: return MODEL_TYPE_CLASS[tag]
    except Exception:
        return default
    return default
64
+
65
+
66
def list_uniq(l):
    """Deduplicate *l* while preserving first-occurrence order."""
    seen = set()
    ordered = []
    for item in l:
        if item not in seen:
            seen.add(item)
            ordered.append(item)
    return ordered
68
+
69
+
70
def list_sub(a, b):
    """Return the elements of *a* not present in *b*, preserving order."""
    result = []
    for item in a:
        if item not in b:
            result.append(item)
    return result
72
+
73
+
74
def is_repo_name(s):
    """Return a match object when *s* looks like an "owner/repo" id, else None."""
    repo_pattern = r'^[^/,\s\"\']+/[^/,\s\"\']+$'
    return re.fullmatch(repo_pattern, s)
76
+
77
+
78
def split_hf_url(url: str):
    """Split a huggingface.co file URL into (repo_id, filename, subfolder, repo_type).

    *subfolder* is None when the file sits at the repo root; *repo_type*
    is "dataset" for dataset URLs, otherwise "model".  Returns
    ("", "", "", "") when *url* does not look like a HF file URL.

    Fix: the original implicitly returned None on failure, which crashed
    callers that tuple-unpack the result; the host dots are now escaped so
    e.g. "huggingfaceXco" no longer matches.
    """
    try:
        # NOTE(review): the filename part `.+?.\w+` has an unescaped dot
        # (any char before the extension); kept as-is to avoid rejecting
        # inputs the original accepted.
        s = list(re.findall(r'^(?:https?://huggingface\.co/)(?:(datasets)/)?(.+?/.+?)/\w+?/.+?/(?:(.+)/)?(.+?.\w+)(?:\?download=true)?$', url)[0])
        if len(s) < 4: return "", "", "", ""
        repo_id = s[1]
        repo_type = "dataset" if s[0] == "datasets" else "model"
        subfolder = urllib.parse.unquote(s[2]) if s[2] else None
        filename = urllib.parse.unquote(s[3])
        return repo_id, filename, subfolder, repo_type
    except Exception as e:
        print(e)
        return "", "", "", ""
89
+
90
+
91
def download_hf_file(directory, url, progress=gr.Progress(track_tqdm=True)):
    """Download one file from a huggingface.co URL into *directory* via hf_hub_download."""
    hf_token = get_token()
    repo_id, filename, subfolder, repo_type = split_hf_url(url)
    kwargs = {}
    if subfolder is not None:
        kwargs["subfolder"] = subfolder
    try:
        hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type, local_dir=directory, token=hf_token, **kwargs)
    except Exception as e:
        print(f"Failed to download: {e}")
99
+
100
+
101
def download_thing(directory, url, civitai_api_key="", progress=gr.Progress(track_tqdm=True)): # requires aria2, gdown
    """Download *url* into *directory*, dispatching on the host.

    Google Drive links go through gdown, huggingface.co through
    hf_hub_download (token) or aria2c (anonymous), civitai.com through
    aria2c with an API key, everything else through plain aria2c.

    NOTE(review): security — *url* is interpolated unescaped into
    os.system() shell commands; a crafted URL could inject shell
    commands.  Consider subprocess.run([...], shell=False).
    """
    hf_token = get_token()
    url = url.strip()
    if "drive.google.com" in url:
        # gdown resolves into the current directory, so chdir around it.
        original_dir = os.getcwd()
        os.chdir(directory)
        os.system(f"gdown --fuzzy {url}")
        os.chdir(original_dir)
    elif "huggingface.co" in url:
        url = url.replace("?download=true", "")
        # Blob (viewer) URLs must be rewritten to resolve (raw) URLs.
        if "/blob/" in url:
            url = url.replace("/blob/", "/resolve/")
        #user_header = f'"Authorization: Bearer {hf_token}"'
        if hf_token:
            # Authenticated path: go through the hub client.
            download_hf_file(directory, url)
            #os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
        else:
            os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
    elif "civitai.com" in url:
        # Drop any existing query string before appending the API token.
        if "?" in url:
            url = url.split("?")[0]
        if civitai_api_key:
            url = url + f"?token={civitai_api_key}"
            os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
        else:
            print("You need an API key to download Civitai models.")
    else:
        os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
129
+
130
+
131
def get_local_model_list(dir_path):
    """Recursively list .safetensors files under *dir_path* as path strings."""
    # BUG FIX: ('.safetensors') is just a parenthesized string, so
    # `suffix in valid_extensions` was a substring test that also accepted
    # e.g. '.safe' or '.s'.  A one-element tuple restores exact matching.
    valid_extensions = ('.safetensors',)
    model_list = []
    for file in Path(dir_path).glob("**/*.*"):
        if file.is_file() and file.suffix in valid_extensions:
            model_list.append(str(file))
    return model_list
139
+
140
+
141
def get_download_file(temp_dir, url, civitai_key, progress=gr.Progress(track_tqdm=True)):
    """Resolve *url* to a local file path, downloading it when necessary.

    Accepts an HF repo id (returned as-is), an existing local path, or an
    http(s) URL (downloaded into *temp_dir* via download_thing).  Returns
    the local path / repo id, or "" when the download failed.
    """
    if not "http" in url and is_repo_name(url) and not Path(url).exists():
        print(f"Use HF Repo: {url}")
        new_file = url
    elif not "http" in url and Path(url).exists():
        print(f"Use local file: {url}")
        new_file = url
    elif Path(f"{temp_dir}/{url.split('/')[-1]}").exists():
        # Typo fix: "alreday" -> "already".
        print(f"File to download already exists: {url}")
        new_file = f"{temp_dir}/{url.split('/')[-1]}"
    else:
        print(f"Start downloading: {url}")
        # The downloaded file is detected as the difference between the
        # directory listing before and after the download.
        before = get_local_model_list(temp_dir)
        try:
            download_thing(temp_dir, url.strip(), civitai_key)
        except Exception:
            print(f"Download failed: {url}")
            return ""
        # Compute the diff once; the original evaluated list_sub twice.
        new_files = list_sub(get_local_model_list(temp_dir), before)
        new_file = new_files[0] if new_files else ""
        if not new_file:
            print(f"Download failed: {url}")
            return ""
    print(f"Download completed: {url}")
    return new_file