John6666 commited on
Commit
b45ac7a
·
verified ·
1 Parent(s): 7b29d85

Upload 9 files

Browse files
Files changed (9) hide show
  1. README.md +14 -12
  2. app.py +243 -0
  3. convert_url_to_diffusers_multi_gr.py +461 -0
  4. packages.txt +1 -0
  5. presets.py +134 -0
  6. requirements.txt +11 -0
  7. sdutils.py +157 -0
  8. stkey.py +122 -0
  9. utils.py +275 -0
README.md CHANGED
@@ -1,12 +1,14 @@
1
- ---
2
- title: Safetensors To Diffusers
3
- emoji: 👁
4
- colorFrom: yellow
5
- colorTo: pink
6
- sdk: gradio
7
- sdk_version: 5.3.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
1
+ ---
2
+ title: Download safetensors and convert to HF🤗 Diffusers format (SDXL / SD 1.5 / FLUX.1 / SD 3.5) Alpha
3
+ emoji: 🎨➡️🧨
4
+ colorFrom: indigo
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 5.1.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ short_description: Convert SDXL/1.5/3.5/FLUX.1 safetensors to HF🤗 Diffusers
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from convert_url_to_diffusers_multi_gr import convert_url_to_diffusers_repo, get_dtypes, FLUX_BASE_REPOS, SD35_BASE_REPOS
3
+ from presets import (DEFAULT_DTYPE, schedulers, clips, t5s, sdxl_vaes, sdxl_loras, sdxl_preset_dict, sdxl_set_presets,
4
+ sd15_vaes, sd15_loras, sd15_preset_dict, sd15_set_presets, flux_vaes, flux_loras, flux_preset_dict, flux_set_presets,
5
+ sd35_vaes, sd35_loras, sd35_preset_dict, sd35_set_presets)
6
+
7
+ css = """
8
+ .title { font-size: 3em; align-items: center; text-align: center; }
9
+ .info { align-items: center; text-align: center; }
10
+ .block.result { margin: 1em 0; padding: 1em; box-shadow: 0 0 3px 3px #664422, 0 0 3px 2px #664422 inset; border-radius: 6px; background: #665544; }
11
+ """
12
+
13
+ with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", fill_width=True, css=css, delete_cache=(60, 3600)) as demo:
14
+ gr.Markdown("# Download SDXL / SD 1.5 / SD 3.5 / FLUX.1 safetensors and convert to HF🤗 Diffusers format and create your repo", elem_classes="title")
15
+ gr.Markdown(f"""
16
+ ### ⚠️IMPORTANT NOTICE⚠️<br>
17
+ It's dangerous to expose your access token or key to others.
18
+ If you do use it, I recommend that you duplicate this space on your own HF account in advance.
19
+ Keys and tokens could be set to **Secrets** (`HF_TOKEN`, `CIVITAI_API_KEY`) if it's placed in your own space.
20
+ It saves you the trouble of typing them in.<br>
21
+ It barely works in the CPU space, but larger files can be converted if duplicated on the more powerful **Zero GPU** space.
22
+ In particular, conversion of FLUX.1 or SD 3.5 is almost impossible in CPU space.
23
+ ### The steps are the following:
24
+ 1. Paste a write-access token from [hf.co/settings/tokens](https://huggingface.co/settings/tokens).
25
+ 1. Input a model download url of the Hugging Face or Civitai or other sites.
26
+ 1. If you want to download a model from Civitai, paste a Civitai API Key.
27
+ 1. Input your HF user ID. e.g. 'yourid'.
28
+ 1. Input your new repo name. If empty, auto-complete. e.g. 'newrepo'.
29
+ 1. Set the parameters. If not sure, just use the defaults.
30
+ 1. Click "Submit".
31
+ 1. Patiently wait until the output changes. It takes approximately 2 to 3 minutes (on SDXL models downloading from HF).
32
+ """)
33
+ with gr.Column():
34
+ dl_url = gr.Textbox(label="URL to download", placeholder="https://huggingface.co/bluepen5805/blue_pencil-XL/blob/main/blue_pencil-XL-v7.0.0.safetensors", value="", max_lines=1)
35
+ with gr.Group():
36
+ with gr.Row():
37
+ hf_user = gr.Textbox(label="Your HF user ID", placeholder="username", value="", max_lines=1)
38
+ hf_repo = gr.Textbox(label="New repo name", placeholder="reponame", info="If empty, auto-complete", value="", max_lines=1)
39
+ with gr.Row(equal_height=True):
40
+ with gr.Column():
41
+ hf_token = gr.Textbox(label="Your HF write token", placeholder="hf_...", value="", max_lines=1)
42
+ gr.Markdown("Your token is available at [hf.co/settings/tokens](https://huggingface.co/settings/tokens).", elem_classes="info")
43
+ with gr.Column():
44
+ civitai_key = gr.Textbox(label="Your Civitai API Key (Optional)", info="If you download model from Civitai...", placeholder="", value="", max_lines=1)
45
+ gr.Markdown("Your Civitai API key is available at [https://civitai.com/user/account](https://civitai.com/user/account).", elem_classes="info")
46
+ with gr.Row():
47
+ is_upload_sf = gr.Checkbox(label="Upload single safetensors file into new repo", value=False)
48
+ is_private = gr.Checkbox(label="Create private repo", value=True)
49
+ is_overwrite = gr.Checkbox(label="Overwrite repo", value=False)
50
+ with gr.Tab("SDXL"):
51
+ with gr.Group():
52
+ sdxl_presets = gr.Radio(label="Presets", choices=list(sdxl_preset_dict.keys()), value=list(sdxl_preset_dict.keys())[0])
53
+ sdxl_mtype = gr.Textbox(value="SDXL", visible=False)
54
+ with gr.Row():
55
+ sdxl_dtype = gr.Radio(label="Output data type", choices=get_dtypes(), value=DEFAULT_DTYPE)
56
+ sdxl_ema = gr.Checkbox(label="Extract EMA", info="For SD 1.5", value=True, visible=False)
57
+ sdxl_base_repo = gr.Dropdown(label="Base repo ID", choices=FLUX_BASE_REPOS, value=FLUX_BASE_REPOS[0], allow_custom_value=True, visible=False)
58
+ with gr.Accordion("Advanced settings", open=False):
59
+ with gr.Row():
60
+ sdxl_vae = gr.Dropdown(label="VAE", choices=sdxl_vaes, value="", allow_custom_value=True)
61
+ sdxl_clip = gr.Dropdown(label="CLIP", choices=clips, value="", allow_custom_value=True)
62
+ sdxl_t5 = gr.Dropdown(label="T5", choices=t5s, value="", allow_custom_value=True, visible=False)
63
+ sdxl_scheduler = gr.Dropdown(label="Scheduler (Sampler)", choices=schedulers, value="Euler a")
64
+ with gr.Row():
65
+ with gr.Column():
66
+ sdxl_lora1 = gr.Dropdown(label="LoRA1", choices=sdxl_loras, value="", allow_custom_value=True, min_width=320)
67
+ sdxl_lora1s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA1 weight scale")
68
+ with gr.Column():
69
+ sdxl_lora2 = gr.Dropdown(label="LoRA2", choices=sdxl_loras, value="", allow_custom_value=True, min_width=320)
70
+ sdxl_lora2s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA2 weight scale")
71
+ with gr.Column():
72
+ sdxl_lora3 = gr.Dropdown(label="LoRA3", choices=sdxl_loras, value="", allow_custom_value=True, min_width=320)
73
+ sdxl_lora3s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA3 weight scale")
74
+ with gr.Column():
75
+ sdxl_lora4 = gr.Dropdown(label="LoRA4", choices=sdxl_loras, value="", allow_custom_value=True, min_width=320)
76
+ sdxl_lora4s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA4 weight scale")
77
+ with gr.Column():
78
+ sdxl_lora5 = gr.Dropdown(label="LoRA5", choices=sdxl_loras, value="", allow_custom_value=True, min_width=320)
79
+ sdxl_lora5s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA5 weight scale")
80
+ sdxl_run_button = gr.Button(value="Submit", variant="primary")
81
+ with gr.Tab("SD 1.5"):
82
+ with gr.Group():
83
+ sd15_presets = gr.Radio(label="Presets", choices=list(sd15_preset_dict.keys()), value=list(sd15_preset_dict.keys())[0])
84
+ sd15_mtype = gr.Textbox(value="SD 1.5", visible=False)
85
+ with gr.Row():
86
+ sd15_dtype = gr.Radio(label="Output data type", choices=get_dtypes(), value=DEFAULT_DTYPE)
87
+ sd15_ema = gr.Checkbox(label="Extract EMA", info="For SD 1.5", value=True, visible=True)
88
+ sd15_base_repo = gr.Dropdown(label="Base repo ID", choices=FLUX_BASE_REPOS, value=FLUX_BASE_REPOS[0], allow_custom_value=True, visible=False)
89
+ with gr.Accordion("Advanced settings", open=False):
90
+ with gr.Row():
91
+ sd15_vae = gr.Dropdown(label="VAE", choices=sd15_vaes, value="", allow_custom_value=True)
92
+ sd15_clip = gr.Dropdown(label="CLIP", choices=clips, value="", allow_custom_value=True)
93
+ sd15_t5 = gr.Dropdown(label="T5", choices=t5s, value="", allow_custom_value=True, visible=False)
94
+ sd15_scheduler = gr.Dropdown(label="Scheduler (Sampler)", choices=schedulers, value="Euler")
95
+ with gr.Row():
96
+ with gr.Column():
97
+ sd15_lora1 = gr.Dropdown(label="LoRA1", choices=sd15_loras, value="", allow_custom_value=True, min_width=320)
98
+ sd15_lora1s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA1 weight scale")
99
+ with gr.Column():
100
+ sd15_lora2 = gr.Dropdown(label="LoRA2", choices=sd15_loras, value="", allow_custom_value=True, min_width=320)
101
+ sd15_lora2s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA2 weight scale")
102
+ with gr.Column():
103
+ sd15_lora3 = gr.Dropdown(label="LoRA3", choices=sd15_loras, value="", allow_custom_value=True, min_width=320)
104
+ sd15_lora3s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA3 weight scale")
105
+ with gr.Column():
106
+ sd15_lora4 = gr.Dropdown(label="LoRA4", choices=sd15_loras, value="", allow_custom_value=True, min_width=320)
107
+ sd15_lora4s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA4 weight scale")
108
+ with gr.Column():
109
+ sd15_lora5 = gr.Dropdown(label="LoRA5", choices=sd15_loras, value="", allow_custom_value=True, min_width=320)
110
+ sd15_lora5s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA5 weight scale")
111
+ sd15_run_button = gr.Button(value="Submit", variant="primary")
112
+ with gr.Tab("FLUX.1"):
113
+ with gr.Group():
114
+ flux_presets = gr.Radio(label="Presets", choices=list(flux_preset_dict.keys()), value=list(flux_preset_dict.keys())[0])
115
+ flux_mtype = gr.Textbox(value="FLUX", visible=False)
116
+ with gr.Row():
117
+ flux_dtype = gr.Radio(label="Output data type", choices=get_dtypes(), value="bf16")
118
+ flux_ema = gr.Checkbox(label="Extract EMA", info="For SD 1.5", value=True, visible=False)
119
+ flux_base_repo = gr.Dropdown(label="Base repo ID", choices=FLUX_BASE_REPOS, value=FLUX_BASE_REPOS[0], allow_custom_value=True, visible=True)
120
+ with gr.Accordion("Advanced settings", open=False):
121
+ with gr.Row():
122
+ flux_vae = gr.Dropdown(label="VAE", choices=flux_vaes, value="", allow_custom_value=True)
123
+ flux_clip = gr.Dropdown(label="CLIP", choices=clips, value="", allow_custom_value=True)
124
+ flux_t5 = gr.Dropdown(label="T5", choices=t5s, value="", allow_custom_value=True)
125
+ flux_scheduler = gr.Dropdown(label="Scheduler (Sampler)", choices=[""], value="", visible=False)
126
+ with gr.Row():
127
+ with gr.Column():
128
+ flux_lora1 = gr.Dropdown(label="LoRA1", choices=flux_loras, value="", allow_custom_value=True, min_width=320)
129
+ flux_lora1s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA1 weight scale")
130
+ with gr.Column():
131
+ flux_lora2 = gr.Dropdown(label="LoRA2", choices=flux_loras, value="", allow_custom_value=True, min_width=320)
132
+ flux_lora2s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA2 weight scale")
133
+ with gr.Column():
134
+ flux_lora3 = gr.Dropdown(label="LoRA3", choices=flux_loras, value="", allow_custom_value=True, min_width=320)
135
+ flux_lora3s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA3 weight scale")
136
+ with gr.Column():
137
+ flux_lora4 = gr.Dropdown(label="LoRA4", choices=flux_loras, value="", allow_custom_value=True, min_width=320)
138
+ flux_lora4s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA4 weight scale")
139
+ with gr.Column():
140
+ flux_lora5 = gr.Dropdown(label="LoRA5", choices=flux_loras, value="", allow_custom_value=True, min_width=320)
141
+ flux_lora5s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA5 weight scale")
142
+ flux_run_button = gr.Button(value="Submit", variant="primary")
143
+ with gr.Tab("SD 3.5"):
144
+ with gr.Group():
145
+ sd35_presets = gr.Radio(label="Presets", choices=list(sd35_preset_dict.keys()), value=list(sd35_preset_dict.keys())[0])
146
+ sd35_mtype = gr.Textbox(value="SD 3.5", visible=False)
147
+ with gr.Row():
148
+ sd35_dtype = gr.Radio(label="Output data type", choices=get_dtypes(), value="bf16")
149
+ sd35_ema = gr.Checkbox(label="Extract EMA", info="For SD 1.5", value=True, visible=False)
150
+ sd35_base_repo = gr.Dropdown(label="Base repo ID", choices=SD35_BASE_REPOS, value=SD35_BASE_REPOS[0], allow_custom_value=True, visible=True)
151
+ with gr.Accordion("Advanced settings", open=False):
152
+ with gr.Row():
153
+ sd35_vae = gr.Dropdown(label="VAE", choices=sd35_vaes, value="", allow_custom_value=True)
154
+ sd35_clip = gr.Dropdown(label="CLIP", choices=clips, value="", allow_custom_value=True)
155
+ sd35_t5 = gr.Dropdown(label="T5", choices=t5s, value="", allow_custom_value=True)
156
+ sd35_scheduler = gr.Dropdown(label="Scheduler (Sampler)", choices=[""], value="", visible=False)
157
+ with gr.Row():
158
+ with gr.Column():
159
+ sd35_lora1 = gr.Dropdown(label="LoRA1", choices=sd35_loras, value="", allow_custom_value=True, min_width=320)
160
+ sd35_lora1s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA1 weight scale")
161
+ with gr.Column():
162
+ sd35_lora2 = gr.Dropdown(label="LoRA2", choices=sd35_loras, value="", allow_custom_value=True, min_width=320)
163
+ sd35_lora2s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA2 weight scale")
164
+ with gr.Column():
165
+ sd35_lora3 = gr.Dropdown(label="LoRA3", choices=sd35_loras, value="", allow_custom_value=True, min_width=320)
166
+ sd35_lora3s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA3 weight scale")
167
+ with gr.Column():
168
+ sd35_lora4 = gr.Dropdown(label="LoRA4", choices=sd35_loras, value="", allow_custom_value=True, min_width=320)
169
+ sd35_lora4s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA4 weight scale")
170
+ with gr.Column():
171
+ sd35_lora5 = gr.Dropdown(label="LoRA5", choices=sd35_loras, value="", allow_custom_value=True, min_width=320)
172
+ sd35_lora5s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA5 weight scale")
173
+ sd35_run_button = gr.Button(value="Submit", variant="primary")
174
+ with gr.Group():
175
+ repo_urls = gr.CheckboxGroup(visible=False, choices=[], value=[])
176
+ output_md = gr.Markdown(label="Output", value="<br><br>", elem_classes="result")
177
+ clear_button = gr.Button(value="Clear Output", variant="secondary")
178
+ gr.DuplicateButton(value="Duplicate Space")
179
+
180
+ gr.on(
181
+ triggers=[sdxl_run_button.click],
182
+ fn=convert_url_to_diffusers_repo,
183
+ inputs=[dl_url, hf_user, hf_repo, hf_token, civitai_key, is_private, is_overwrite, is_upload_sf, repo_urls,
184
+ sdxl_dtype, sdxl_vae, sdxl_clip, sdxl_t5, sdxl_scheduler, sdxl_ema, sdxl_base_repo, sdxl_mtype,
185
+ sdxl_lora1, sdxl_lora1s, sdxl_lora2, sdxl_lora2s, sdxl_lora3, sdxl_lora3s, sdxl_lora4, sdxl_lora4s, sdxl_lora5, sdxl_lora5s],
186
+ outputs=[repo_urls, output_md],
187
+ )
188
+ sdxl_presets.change(
189
+ fn=sdxl_set_presets,
190
+ inputs=[sdxl_presets],
191
+ outputs=[sdxl_dtype, sdxl_vae, sdxl_scheduler, sdxl_lora1, sdxl_lora1s, sdxl_lora2, sdxl_lora2s, sdxl_lora3, sdxl_lora3s,
192
+ sdxl_lora4, sdxl_lora4s, sdxl_lora5, sdxl_lora5s],
193
+ queue=False,
194
+ )
195
+ gr.on(
196
+ triggers=[sd15_run_button.click],
197
+ fn=convert_url_to_diffusers_repo,
198
+ inputs=[dl_url, hf_user, hf_repo, hf_token, civitai_key, is_private, is_overwrite, is_upload_sf, repo_urls,
199
+ sd15_dtype, sd15_vae, sd15_clip, sd15_t5, sd15_scheduler, sd15_ema, sd15_base_repo, sd15_mtype,
200
+ sd15_lora1, sd15_lora1s, sd15_lora2, sd15_lora2s, sd15_lora3, sd15_lora3s, sd15_lora4, sd15_lora4s, sd15_lora5, sd15_lora5s],
201
+ outputs=[repo_urls, output_md],
202
+ )
203
+ sd15_presets.change(
204
+ fn=sd15_set_presets,
205
+ inputs=[sd15_presets],
206
+ outputs=[sd15_dtype, sd15_vae, sd15_scheduler, sd15_lora1, sd15_lora1s, sd15_lora2, sd15_lora2s, sd15_lora3, sd15_lora3s,
207
+ sd15_lora4, sd15_lora4s, sd15_lora5, sd15_lora5s, sd15_ema],
208
+ queue=False,
209
+ )
210
+ gr.on(
211
+ triggers=[flux_run_button.click],
212
+ fn=convert_url_to_diffusers_repo,
213
+ inputs=[dl_url, hf_user, hf_repo, hf_token, civitai_key, is_private, is_overwrite, is_upload_sf, repo_urls,
214
+ flux_dtype, flux_vae, flux_clip, flux_t5, flux_scheduler, flux_ema, flux_base_repo, flux_mtype,
215
+ flux_lora1, flux_lora1s, flux_lora2, flux_lora2s, flux_lora3, flux_lora3s, flux_lora4, flux_lora4s, flux_lora5, flux_lora5s],
216
+ outputs=[repo_urls, output_md],
217
+ )
218
+ flux_presets.change(
219
+ fn=flux_set_presets,
220
+ inputs=[flux_presets],
221
+ outputs=[flux_dtype, flux_vae, flux_scheduler, flux_lora1, flux_lora1s, flux_lora2, flux_lora2s, flux_lora3, flux_lora3s,
222
+ flux_lora4, flux_lora4s, flux_lora5, flux_lora5s, flux_base_repo],
223
+ queue=False,
224
+ )
225
+ gr.on(
226
+ triggers=[sd35_run_button.click],
227
+ fn=convert_url_to_diffusers_repo,
228
+ inputs=[dl_url, hf_user, hf_repo, hf_token, civitai_key, is_private, is_overwrite, is_upload_sf, repo_urls,
229
+ sd35_dtype, sd35_vae, sd35_clip, sd35_t5, sd35_scheduler, sd35_ema, sd35_base_repo, sd35_mtype,
230
+ sd35_lora1, sd35_lora1s, sd35_lora2, sd35_lora2s, sd35_lora3, sd35_lora3s, sd35_lora4, sd35_lora4s, sd35_lora5, sd35_lora5s],
231
+ outputs=[repo_urls, output_md],
232
+ )
233
+ sd35_presets.change(
234
+ fn=sd35_set_presets,
235
+ inputs=[sd35_presets],
236
+ outputs=[sd35_dtype, sd35_vae, sd35_scheduler, sd35_lora1, sd35_lora1s, sd35_lora2, sd35_lora2s, sd35_lora3, sd35_lora3s,
237
+ sd35_lora4, sd35_lora4s, sd35_lora5, sd35_lora5s, sd35_base_repo],
238
+ queue=False,
239
+ )
240
+ clear_button.click(lambda: ([], "<br><br>"), None, [repo_urls, output_md], queue=False, show_api=False)
241
+
242
+ demo.queue()
243
+ demo.launch()
convert_url_to_diffusers_multi_gr.py ADDED
@@ -0,0 +1,461 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ if os.environ.get("SPACES_ZERO_GPU") is not None:
3
+ import spaces
4
+ else:
5
+ class spaces:
6
+ @staticmethod
7
+ def GPU(func):
8
+ def wrapper(*args, **kwargs):
9
+ return func(*args, **kwargs)
10
+ return wrapper
11
+ import argparse
12
+ from pathlib import Path
13
+ import os
14
+ import torch
15
+ from diffusers import (DiffusionPipeline, AutoencoderKL, FlowMatchEulerDiscreteScheduler, StableDiffusionXLPipeline, StableDiffusionPipeline,
16
+ FluxPipeline, FluxTransformer2DModel, SD3Transformer2DModel, StableDiffusion3Pipeline)
17
+ from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection, AutoTokenizer, T5EncoderModel, BitsAndBytesConfig as TFBitsAndBytesConfig
18
+ from huggingface_hub import save_torch_state_dict, snapshot_download
19
+ from diffusers.loaders.single_file_utils import (convert_flux_transformer_checkpoint_to_diffusers, convert_sd3_transformer_checkpoint_to_diffusers,
20
+ convert_sd3_t5_checkpoint_to_diffusers)
21
+ import safetensors.torch
22
+ import gradio as gr
23
+ import shutil
24
+ import gc
25
+ import tempfile
26
+ # also requires aria, gdown, peft, huggingface_hub, safetensors, transformers, accelerate, pytorch_lightning
27
+ from utils import (get_token, set_token, is_repo_exists, is_repo_name, get_download_file, upload_repo)
28
+ from sdutils import (SCHEDULER_CONFIG_MAP, get_scheduler_config, fuse_loras, DTYPE_DEFAULT, get_dtype, get_dtypes, get_model_type_from_key, get_process_dtype)
29
+
30
+
31
+ @spaces.GPU
32
+ def fake_gpu():
33
+ pass
34
+
35
+
36
+ try:
37
+ from diffusers import BitsAndBytesConfig
38
+ is_nf4 = True
39
+ except Exception:
40
+ is_nf4 = False
41
+
42
+
43
+ FLUX_BASE_REPOS = ["camenduru/FLUX.1-dev-diffusers", "black-forest-labs/FLUX.1-schnell", "John6666/flux1-dev-fp8-flux", "John6666/flux1-schnell-fp8-flux"]
44
+ FLUX_T5_URL = "https://huggingface.co/camenduru/FLUX.1-dev/blob/main/t5xxl_fp8_e4m3fn.safetensors"
45
+ SD35_BASE_REPOS = ["adamo1139/stable-diffusion-3.5-large-ungated", "adamo1139/stable-diffusion-3.5-large-turbo-ungated"]
46
+ SD35_T5_URL = "https://huggingface.co/adamo1139/stable-diffusion-3.5-large-turbo-ungated/blob/main/text_encoders/t5xxl_fp8_e4m3fn.safetensors"
47
+ TEMP_DIR = tempfile.mkdtemp()
48
+ IS_ZERO = os.environ.get("SPACES_ZERO_GPU") is not None
49
+ IS_CUDA = torch.cuda.is_available()
50
+
51
+
52
+ def safe_clean(path: str):
53
+ try:
54
+ if Path(path).exists():
55
+ if Path(path).is_dir(): shutil.rmtree(str(Path(path)))
56
+ else: Path(path).unlink()
57
+ print(f"Deleted: {path}")
58
+ else: print(f"File not found: {path}")
59
+ except Exception as e:
60
+ print(f"Failed to delete: {path} {e}")
61
+
62
+
63
+ def save_readme_md(dir, url):
64
+ orig_url = ""
65
+ orig_name = ""
66
+ if is_repo_name(url):
67
+ orig_name = url
68
+ orig_url = f"https://huggingface.co/{url}/"
69
+ elif "http" in url:
70
+ orig_name = url
71
+ orig_url = url
72
+ if orig_name and orig_url:
73
+ md = f"""---
74
+ license: other
75
+ language:
76
+ - en
77
+ library_name: diffusers
78
+ pipeline_tag: text-to-image
79
+ tags:
80
+ - text-to-image
81
+ ---
82
+ Converted from [{orig_name}]({orig_url}).
83
+ """
84
+ else:
85
+ md = f"""---
86
+ license: other
87
+ language:
88
+ - en
89
+ library_name: diffusers
90
+ pipeline_tag: text-to-image
91
+ tags:
92
+ - text-to-image
93
+ ---
94
+ """
95
+ path = str(Path(dir, "README.md"))
96
+ with open(path, mode='w', encoding="utf-8") as f:
97
+ f.write(md)
98
+
99
+
100
+ def save_module(model, name: str, dir: str, dtype: str="fp8", progress=gr.Progress(track_tqdm=True)): # doesn't work
101
+ if name in ["vae", "transformer", "unet"]: pattern = "diffusion_pytorch_model{suffix}.safetensors"
102
+ else: pattern = "model{suffix}.safetensors"
103
+ if name in ["transformer", "unet"]: size = "10GB"
104
+ else: size = "5GB"
105
+ path = str(Path(f"{dir.removesuffix('/')}/{name}"))
106
+ os.makedirs(path, exist_ok=True)
107
+ progress(0, desc=f"Saving {name} to {dir}...")
108
+ print(f"Saving {name} to {dir}...")
109
+ model.to("cpu")
110
+ sd = dict(model.state_dict())
111
+ new_sd = {}
112
+ for key in list(sd.keys()):
113
+ q = sd.pop(key)
114
+ if dtype == "fp8": new_sd[key] = q if q.dtype == torch.float8_e4m3fn else q.to(torch.float8_e4m3fn)
115
+ else: new_sd[key] = q
116
+ del sd
117
+ gc.collect()
118
+ save_torch_state_dict(state_dict=new_sd, save_directory=path, filename_pattern=pattern, max_shard_size=size)
119
+ del new_sd
120
+ gc.collect()
121
+
122
+
123
+ def save_module_sd(sd: dict, name: str, dir: str, dtype: str="fp8", progress=gr.Progress(track_tqdm=True)):
124
+ if name in ["vae", "transformer", "unet"]: pattern = "diffusion_pytorch_model{suffix}.safetensors"
125
+ else: pattern = "model{suffix}.safetensors"
126
+ if name in ["transformer", "unet"]: size = "10GB"
127
+ else: size = "5GB"
128
+ path = str(Path(f"{dir.removesuffix('/')}/{name}"))
129
+ os.makedirs(path, exist_ok=True)
130
+ progress(0, desc=f"Saving state_dict of {name} to {dir}...")
131
+ print(f"Saving state_dict of {name} to {dir}...")
132
+ new_sd = {}
133
+ for key in list(sd.keys()):
134
+ q = sd.pop(key).to("cpu")
135
+ if dtype == "fp8": new_sd[key] = q if q.dtype == torch.float8_e4m3fn else q.to(torch.float8_e4m3fn)
136
+ else: new_sd[key] = q
137
+ save_torch_state_dict(state_dict=new_sd, save_directory=path, filename_pattern=pattern, max_shard_size=size)
138
+ del new_sd
139
+ gc.collect()
140
+
141
+
142
+ def convert_flux_fp8_cpu(new_file: str, new_dir: str, dtype: str, base_repo: str, civitai_key: str, kwargs: dict, progress=gr.Progress(track_tqdm=True)):
143
+ temp_dir = TEMP_DIR
144
+ down_dir = str(Path(f"{TEMP_DIR}/down"))
145
+ os.makedirs(down_dir, exist_ok=True)
146
+ hf_token = get_token()
147
+ progress(0.25, desc=f"Loading {new_file}...")
148
+ orig_sd = safetensors.torch.load_file(new_file)
149
+ progress(0.3, desc=f"Converting {new_file}...")
150
+ conv_sd = convert_flux_transformer_checkpoint_to_diffusers(orig_sd)
151
+ del orig_sd
152
+ gc.collect()
153
+ progress(0.35, desc=f"Saving {new_file}...")
154
+ save_module_sd(conv_sd, "transformer", new_dir, dtype)
155
+ del conv_sd
156
+ gc.collect()
157
+ progress(0.5, desc=f"Loading text_encoder_2 from {FLUX_T5_URL}...")
158
+ t5_file = get_download_file(temp_dir, FLUX_T5_URL, civitai_key)
159
+ if not t5_file: raise Exception(f"Safetensors file not found: {FLUX_T5_URL}")
160
+ t5_sd = safetensors.torch.load_file(t5_file)
161
+ safe_clean(t5_file)
162
+ save_module_sd(t5_sd, "text_encoder_2", new_dir, dtype)
163
+ del t5_sd
164
+ gc.collect()
165
+ progress(0.6, desc=f"Loading other components from {base_repo}...")
166
+ pipe = FluxPipeline.from_pretrained(base_repo, transformer=None, text_encoder_2=None, use_safetensors=True, **kwargs,
167
+ torch_dtype=torch.bfloat16, token=hf_token)
168
+ pipe.save_pretrained(new_dir)
169
+ progress(0.75, desc=f"Loading nontensor files from {base_repo}...")
170
+ snapshot_download(repo_id=base_repo, local_dir=down_dir, token=hf_token, force_download=True,
171
+ ignore_patterns=["*.safetensors", "*.sft", ".*", "README*", "*.md", "*.index", "*.jpg", "*.jpeg", "*.png", "*.webp"])
172
+ shutil.copytree(down_dir, new_dir, ignore=shutil.ignore_patterns(".*", "README*", "*.md", "*.jpg", "*.jpeg", "*.png", "*.webp"), dirs_exist_ok=True)
173
+ safe_clean(down_dir)
174
+
175
+
176
+ def convert_sd35_fp8_cpu(new_file: str, new_dir: str, dtype: str, base_repo: str, civitai_key: str, kwargs: dict, progress=gr.Progress(track_tqdm=True)):
177
+ temp_dir = TEMP_DIR
178
+ down_dir = str(Path(f"{TEMP_DIR}/down"))
179
+ os.makedirs(down_dir, exist_ok=True)
180
+ hf_token = get_token()
181
+ progress(0.25, desc=f"Loading {new_file}...")
182
+ orig_sd = safetensors.torch.load_file(new_file)
183
+ progress(0.3, desc=f"Converting {new_file}...")
184
+ conv_sd = convert_sd3_transformer_checkpoint_to_diffusers(orig_sd)
185
+ del orig_sd
186
+ gc.collect()
187
+ progress(0.35, desc=f"Saving {new_file}...")
188
+ save_module_sd(conv_sd, "transformer", new_dir, dtype)
189
+ del conv_sd
190
+ gc.collect()
191
+ progress(0.5, desc=f"Loading text_encoder_3 from {SD35_T5_URL}...")
192
+ t5_file = get_download_file(temp_dir, SD35_T5_URL, civitai_key)
193
+ if not t5_file: raise Exception(f"Safetensors file not found: {SD35_T5_URL}")
194
+ t5_sd = safetensors.torch.load_file(t5_file)
195
+ safe_clean(t5_file)
196
+ conv_t5_sd = convert_sd3_t5_checkpoint_to_diffusers(t5_sd)
197
+ del t5_sd
198
+ gc.collect()
199
+ save_module_sd(conv_t5_sd, "text_encoder_3", new_dir, dtype)
200
+ del conv_t5_sd
201
+ gc.collect()
202
+ progress(0.6, desc=f"Loading other components from {base_repo}...")
203
+ pipe = StableDiffusion3Pipeline.from_pretrained(base_repo, transformer=None, text_encoder_3=None, use_safetensors=True, **kwargs,
204
+ torch_dtype=torch.bfloat16, token=hf_token)
205
+ pipe.save_pretrained(new_dir)
206
+ progress(0.75, desc=f"Loading nontensor files from {base_repo}...")
207
+ snapshot_download(repo_id=base_repo, local_dir=down_dir, token=hf_token, force_download=True,
208
+ ignore_patterns=["*.safetensors", "*.sft", ".*", "README*", "*.md", "*.index", "*.jpg", "*.jpeg", "*.png", "*.webp"])
209
+ shutil.copytree(down_dir, new_dir, ignore=shutil.ignore_patterns(".*", "README*", "*.md", "*.jpg", "*.jpeg", "*.png", "*.webp"), dirs_exist_ok=True)
210
+ safe_clean(down_dir)
211
+
212
+
213
+ #@spaces.GPU(duration=60)
214
+ def load_and_save_pipeline(pipe, model_type: str, url: str, new_file: str, new_dir: str, dtype: str,
215
+ scheduler: str, base_repo: str, civitai_key: str, lora_dict: dict,
216
+ my_vae, my_clip_tokenizer, my_clip_encoder, my_t5_tokenizer, my_t5_encoder,
217
+ kwargs: dict, dkwargs: dict, progress=gr.Progress(track_tqdm=True)):
218
+ try:
219
+ hf_token = get_token()
220
+ temp_dir = TEMP_DIR
221
+ qkwargs = {}
222
+ tfqkwargs = {}
223
+ if is_nf4:
224
+ nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
225
+ bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
226
+ nf4_config_tf = TFBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
227
+ bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
228
+ else:
229
+ nf4_config = None
230
+ nf4_config_tf = None
231
+ if dtype == "NF4" and nf4_config is not None and nf4_config_tf is not None:
232
+ qkwargs["quantization_config"] = nf4_config
233
+ tfqkwargs["quantization_config"] = nf4_config_tf
234
+
235
+ #t5 = None
236
+
237
+ if model_type == "SDXL":
238
+ if is_repo_name(url): pipe = StableDiffusionXLPipeline.from_pretrained(url, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
239
+ else: pipe = StableDiffusionXLPipeline.from_single_file(new_file, use_safetensors=True, **kwargs, **dkwargs)
240
+ pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
241
+ sconf = get_scheduler_config(scheduler)
242
+ pipe.scheduler = sconf[0].from_config(pipe.scheduler.config, **sconf[1])
243
+ pipe.save_pretrained(new_dir)
244
+ elif model_type == "SD 1.5":
245
+ if is_repo_name(url): pipe = StableDiffusionPipeline.from_pretrained(url, extract_ema=ema, requires_safety_checker=False,
246
+ use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
247
+ else: pipe = StableDiffusionPipeline.from_single_file(new_file, extract_ema=ema, requires_safety_checker=False, use_safetensors=True, **kwargs, **dkwargs)
248
+ pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
249
+ sconf = get_scheduler_config(scheduler)
250
+ pipe.scheduler = sconf[0].from_config(pipe.scheduler.config, **sconf[1])
251
+ pipe.save_pretrained(new_dir)
252
+ elif model_type == "FLUX":
253
+ if dtype != "fp8":
254
+ if is_repo_name(url):
255
+ transformer = FluxTransformer2DModel.from_pretrained(url, subfolder="transformer", config=base_repo, **dkwargs, **qkwargs)
256
+ #if my_t5_encoder is None:
257
+ # t5 = T5EncoderModel.from_pretrained(url, subfolder="text_encoder_2", config=base_repo, **dkwargs, **tfqkwargs)
258
+ # kwargs["text_encoder_2"] = t5
259
+ pipe = FluxPipeline.from_pretrained(url, transformer=transformer, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
260
+ else:
261
+ transformer = FluxTransformer2DModel.from_single_file(new_file, subfolder="transformer", config=base_repo, **dkwargs, **qkwargs)
262
+ #if my_t5_encoder is None:
263
+ # t5 = T5EncoderModel.from_pretrained(base_repo, subfolder="text_encoder_2", config=base_repo, **dkwargs, **tfqkwargs)
264
+ # kwargs["text_encoder_2"] = t5
265
+ pipe = FluxPipeline.from_pretrained(base_repo, transformer=transformer, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
266
+ pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
267
+ pipe.save_pretrained(new_dir)
268
+ elif not is_repo_name(url): convert_flux_fp8_cpu(new_file, new_dir, dtype, base_repo, civitai_key, kwargs)
269
+ elif model_type == "SD 3.5":
270
+ if dtype != "fp8":
271
+ if is_repo_name(url):
272
+ transformer = SD3Transformer2DModel.from_pretrained(url, subfolder="transformer", config=base_repo, **dkwargs, **qkwargs)
273
+ #if my_t5_encoder is None:
274
+ # t5 = T5EncoderModel.from_pretrained(url, subfolder="text_encoder_3", config=base_repo, **dkwargs, **tfqkwargs)
275
+ # kwargs["text_encoder_3"] = t5
276
+ pipe = StableDiffusion3Pipeline.from_pretrained(url, transformer=transformer, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
277
+ else:
278
+ transformer = SD3Transformer2DModel.from_single_file(new_file, subfolder="transformer", config=base_repo, **dkwargs, **qkwargs)
279
+ #if my_t5_encoder is None:
280
+ # t5 = T5EncoderModel.from_pretrained(base_repo, subfolder="text_encoder_3", config=base_repo, **dkwargs, **tfqkwargs)
281
+ # kwargs["text_encoder_3"] = t5
282
+ pipe = StableDiffusion3Pipeline.from_pretrained(base_repo, transformer=transformer, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
283
+ pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
284
+ pipe.save_pretrained(new_dir)
285
+ elif not is_repo_name(url): convert_sd35_fp8_cpu(new_file, new_dir, dtype, base_repo, civitai_key, kwargs)
286
+ else: # unknown model type
287
+ if is_repo_name(url): pipe = DiffusionPipeline.from_pretrained(url, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
288
+ else: pipe = DiffusionPipeline.from_single_file(new_file, use_safetensors=True, **kwargs, **dkwargs)
289
+ pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
290
+ pipe.save_pretrained(new_dir)
291
+ except Exception as e:
292
+ raise Exception("Failed to load pipeline.") from e
293
+ finally:
294
+ return pipe
295
+
296
+
297
+ def convert_url_to_diffusers(url: str, civitai_key: str="", is_upload_sf: bool=False, dtype: str="fp16", vae: str="", clip: str="", t5: str="",
298
+ scheduler: str="Euler a", ema: bool=True, base_repo: str="", mtype: str="", lora_dict: dict={}, is_local: bool=True, progress=gr.Progress(track_tqdm=True)):
299
+ try:
300
+ hf_token = get_token()
301
+ progress(0, desc="Start converting...")
302
+ temp_dir = TEMP_DIR
303
+
304
+ if is_repo_name(url) and is_repo_exists(url):
305
+ new_file = url
306
+ model_type = mtype
307
+ else:
308
+ new_file = get_download_file(temp_dir, url, civitai_key)
309
+ if not new_file: raise Exception(f"Safetensors file not found: {url}")
310
+ model_type = get_model_type_from_key(new_file)
311
+ new_dir = Path(new_file).stem.replace(" ", "_").replace(",", "_").replace(".", "_") #
312
+
313
+ kwargs = {}
314
+ dkwargs = {}
315
+ if dtype != DTYPE_DEFAULT: dkwargs["torch_dtype"] = get_process_dtype(dtype, model_type)
316
+ pipe = None
317
+
318
+ print(f"Model type: {model_type} / VAE: {vae} / CLIP: {clip} / T5: {t5} / Scheduler: {scheduler} / dtype: {dtype} / EMA: {ema} / Base repo: {base_repo} / LoRAs: {lora_dict}")
319
+
320
+ my_vae = None
321
+ if vae:
322
+ progress(0, desc=f"Loading VAE: {vae}...")
323
+ if is_repo_name(vae): my_vae = AutoencoderKL.from_pretrained(vae, **dkwargs, token=hf_token)
324
+ else:
325
+ new_vae_file = get_download_file(temp_dir, vae, civitai_key)
326
+ my_vae = AutoencoderKL.from_single_file(new_vae_file, **dkwargs) if new_vae_file else None
327
+ safe_clean(new_vae_file)
328
+ if my_vae: kwargs["vae"] = my_vae
329
+
330
+ my_clip_tokenizer = None
331
+ my_clip_encoder = None
332
+ if clip:
333
+ progress(0, desc=f"Loading CLIP: {clip}...")
334
+ if is_repo_name(clip):
335
+ my_clip_tokenizer = CLIPTokenizer.from_pretrained(clip, token=hf_token)
336
+ if model_type == "SD 3.5": my_clip_encoder = CLIPTextModelWithProjection.from_pretrained(clip, **dkwargs, token=hf_token)
337
+ else: my_clip_encoder = CLIPTextModel.from_pretrained(clip, **dkwargs, token=hf_token)
338
+ else:
339
+ new_clip_file = get_download_file(temp_dir, clip, civitai_key)
340
+ if model_type == "SD 3.5": my_clip_encoder = CLIPTextModelWithProjection.from_single_file(new_clip_file, **dkwargs) if new_clip_file else None
341
+ else: my_clip_encoder = CLIPTextModel.from_single_file(new_clip_file, **dkwargs) if new_clip_file else None
342
+ safe_clean(new_clip_file)
343
+ if model_type == "SD 3.5":
344
+ if my_clip_tokenizer:
345
+ kwargs["tokenizer"] = my_clip_tokenizer
346
+ kwargs["tokenizer_2"] = my_clip_tokenizer
347
+ if my_clip_encoder:
348
+ kwargs["text_encoder"] = my_clip_encoder
349
+ kwargs["text_encoder_2"] = my_clip_encoder
350
+ else:
351
+ if my_clip_tokenizer: kwargs["tokenizer"] = my_clip_tokenizer
352
+ if my_clip_encoder: kwargs["text_encoder"] = my_clip_encoder
353
+
354
+ my_t5_tokenizer = None
355
+ my_t5_encoder = None
356
+ if t5:
357
+ progress(0, desc=f"Loading T5: {t5}...")
358
+ if is_repo_name(t5):
359
+ my_t5_tokenizer = AutoTokenizer.from_pretrained(t5, token=hf_token)
360
+ my_t5_encoder = T5EncoderModel.from_pretrained(t5, **dkwargs, token=hf_token)
361
+ else:
362
+ new_t5_file = get_download_file(temp_dir, t5, civitai_key)
363
+ my_t5_encoder = T5EncoderModel.from_single_file(new_t5_file, **dkwargs) if new_t5_file else None
364
+ safe_clean(new_t5_file)
365
+ if model_type == "SD 3.5":
366
+ if my_t5_tokenizer: kwargs["tokenizer_3"] = my_t5_tokenizer
367
+ if my_t5_encoder: kwargs["text_encoder_3"] = my_t5_encoder
368
+ else:
369
+ if my_t5_tokenizer: kwargs["tokenizer_2"] = my_t5_tokenizer
370
+ if my_t5_encoder: kwargs["text_encoder_2"] = my_t5_encoder
371
+
372
+ pipe = load_and_save_pipeline(pipe, model_type, url, new_file, new_dir, dtype, scheduler, base_repo, civitai_key, lora_dict,
373
+ my_vae, my_clip_tokenizer, my_clip_encoder, my_t5_tokenizer, my_t5_encoder, kwargs, dkwargs)
374
+
375
+ if Path(new_dir).exists(): save_readme_md(new_dir, url)
376
+
377
+ if not is_local:
378
+ if not is_repo_name(new_file) and is_upload_sf: shutil.move(str(Path(new_file).resolve()), str(Path(new_dir, Path(new_file).name).resolve()))
379
+ else: safe_clean(new_file)
380
+
381
+ progress(1, desc="Converted.")
382
+ return new_dir
383
+ except Exception as e:
384
+ print(f"Failed to convert. {e}")
385
+ raise Exception("Failed to convert.") from e
386
+ finally:
387
+ del pipe
388
+ torch.cuda.empty_cache()
389
+ gc.collect()
390
+
391
+
392
+ def convert_url_to_diffusers_repo(dl_url: str, hf_user: str, hf_repo: str, hf_token: str, civitai_key="", is_private: bool=True, is_overwrite: bool=False,
393
+ is_upload_sf: bool=False, urls: list=[], dtype: str="fp16", vae: str="", clip: str="", t5: str="", scheduler: str="Euler a", ema: bool=True,
394
+ base_repo: str="", mtype: str="", lora1: str="", lora1s=1.0, lora2: str="", lora2s=1.0, lora3: str="", lora3s=1.0,
395
+ lora4: str="", lora4s=1.0, lora5: str="", lora5s=1.0, progress=gr.Progress(track_tqdm=True)):
396
+ try:
397
+ is_local = False
398
+ if not civitai_key and os.environ.get("CIVITAI_API_KEY"): civitai_key = os.environ.get("CIVITAI_API_KEY") # default Civitai API key
399
+ if not hf_token and os.environ.get("HF_TOKEN"): hf_token = os.environ.get("HF_TOKEN") # default HF write token
400
+ if not hf_user and os.environ.get("HF_USER"): hf_user = os.environ.get("HF_USER") # default username
401
+ if not hf_user: raise gr.Error(f"Invalid user name: {hf_user}")
402
+ if not hf_repo and os.environ.get("HF_REPO"): hf_repo = os.environ.get("HF_REPO") # default reponame
403
+ if not is_overwrite and os.environ.get("HF_OW"): is_overwrite = os.environ.get("HF_OW") # for debugging
404
+ if not dl_url and os.environ.get("HF_URL"): dl_url = os.environ.get("HF_URL") # for debugging
405
+ set_token(hf_token)
406
+ lora_dict = {lora1: lora1s, lora2: lora2s, lora3: lora3s, lora4: lora4s, lora5: lora5s}
407
+ new_path = convert_url_to_diffusers(dl_url, civitai_key, is_upload_sf, dtype, vae, clip, t5, scheduler, ema, base_repo, mtype, lora_dict, is_local)
408
+ if not new_path: return ""
409
+ new_repo_id = f"{hf_user}/{Path(new_path).stem}"
410
+ if hf_repo != "": new_repo_id = f"{hf_user}/{hf_repo}"
411
+ if not is_repo_name(new_repo_id): raise gr.Error(f"Invalid repo name: {new_repo_id}")
412
+ if not is_overwrite and is_repo_exists(new_repo_id): raise gr.Error(f"Repo already exists: {new_repo_id}")
413
+ repo_url = upload_repo(new_repo_id, new_path, is_private)
414
+ safe_clean(new_path)
415
+ if not urls: urls = []
416
+ urls.append(repo_url)
417
+ md = "### Your new repo:\n"
418
+ for u in urls:
419
+ md += f"[{str(u).split('/')[-2]}/{str(u).split('/')[-1]}]({str(u)})<br>"
420
+ return gr.update(value=urls, choices=urls), gr.update(value=md)
421
+ except Exception as e:
422
+ print(f"Error occured. {e}")
423
+ raise gr.Error(f"Error occured. {e}")
424
+
425
+
426
+ if __name__ == "__main__":
427
+ parser = argparse.ArgumentParser()
428
+
429
+ parser.add_argument("--url", type=str, required=True, help="URL of the model to convert.")
430
+ parser.add_argument("--dtype", default="fp16", type=str, choices=get_dtypes(), help='Output data type. (Default: "fp16")')
431
+ parser.add_argument("--scheduler", default="Euler a", type=str, choices=list(SCHEDULER_CONFIG_MAP.keys()), required=False, help="Scheduler name to use.")
432
+ parser.add_argument("--vae", default="", type=str, required=False, help="URL or Repo ID of the VAE to use.")
433
+ parser.add_argument("--clip", default="", type=str, required=False, help="URL or Repo ID of the CLIP to use.")
434
+ parser.add_argument("--t5", default="", type=str, required=False, help="URL or Repo ID of the T5 to use.")
435
+ parser.add_argument("--base", default="", type=str, required=False, help="Repo ID of the base repo.")
436
+ parser.add_argument("--nonema", action="store_true", default=False, help="Don't extract EMA (for SD 1.5).")
437
+ parser.add_argument("--civitai_key", default="", type=str, required=False, help="Civitai API Key (If you want to download file from Civitai).")
438
+ parser.add_argument("--lora1", default="", type=str, required=False, help="URL of the LoRA to use.")
439
+ parser.add_argument("--lora1s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora1.")
440
+ parser.add_argument("--lora2", default="", type=str, required=False, help="URL of the LoRA to use.")
441
+ parser.add_argument("--lora2s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora2.")
442
+ parser.add_argument("--lora3", default="", type=str, required=False, help="URL of the LoRA to use.")
443
+ parser.add_argument("--lora3s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora3.")
444
+ parser.add_argument("--lora4", default="", type=str, required=False, help="URL of the LoRA to use.")
445
+ parser.add_argument("--lora4s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora4.")
446
+ parser.add_argument("--lora5", default="", type=str, required=False, help="URL of the LoRA to use.")
447
+ parser.add_argument("--lora5s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora5.")
448
+ parser.add_argument("--loras", default="", type=str, required=False, help="Folder of the LoRA to use.")
449
+
450
+ args = parser.parse_args()
451
+ assert args.url is not None, "Must provide a URL!"
452
+
453
+ is_local = True
454
+ lora_dict = {args.lora1: args.lora1s, args.lora2: args.lora2s, args.lora3: args.lora3s, args.lora4: args.lora4s, args.lora5: args.lora5s}
455
+ if args.loras and Path(args.loras).exists():
456
+ for p in Path(args.loras).glob('**/*.safetensors'):
457
+ lora_dict[str(p)] = 1.0
458
+ ema = not args.nonema
459
+ mtype = "SDXL"
460
+
461
+ convert_url_to_diffusers(args.url, args.civitai_key, args.dtype, args.vae, args.clip, args.t5, args.scheduler, ema, args.base, mtype, lora_dict, is_local)
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ git-lfs aria2
presets.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sdutils import get_dtypes, SCHEDULER_CONFIG_MAP
2
+
3
+
4
+ DEFAULT_DTYPE = get_dtypes()[0]
5
+ schedulers = list(SCHEDULER_CONFIG_MAP.keys())
6
+
7
+
8
+ clips = [
9
+ "",
10
+ "openai/clip-vit-large-patch14",
11
+ ]
12
+
13
+
14
+ t5s = [
15
+ "",
16
+ "https://huggingface.co/camenduru/FLUX.1-dev/blob/main/t5xxl_fp8_e4m3fn.safetensors",
17
+ ]
18
+
19
+
20
+ sdxl_vaes = [
21
+ "",
22
+ "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl.vae.safetensors",
23
+ "https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/blob/main/sdxl_vae-fp16fix-blessed.safetensors",
24
+ "https://huggingface.co/John6666/safetensors_converting_test/blob/main/xlVAEC_e7.safetensors",
25
+ "https://huggingface.co/John6666/safetensors_converting_test/blob/main/xlVAEC_f1.safetensors",
26
+ ]
27
+
28
+
29
+ sdxl_loras = [
30
+ "",
31
+ "https://huggingface.co/SPO-Diffusion-Models/SPO-SDXL_4k-p_10ep_LoRA/blob/main/spo_sdxl_10ep_4k-data_lora_diffusers.safetensors",
32
+ "https://huggingface.co/wangfuyun/PCM_Weights/blob/main/sdxl/pcm_sdxl_smallcfg_2step_converted.safetensors",
33
+ "https://huggingface.co/wangfuyun/PCM_Weights/blob/main/sdxl/pcm_sdxl_smallcfg_4step_converted.safetensors",
34
+ "https://huggingface.co/wangfuyun/PCM_Weights/blob/main/sdxl/pcm_sdxl_smallcfg_8step_converted.safetensors",
35
+ "https://huggingface.co/wangfuyun/PCM_Weights/blob/main/sdxl/pcm_sdxl_normalcfg_8step_converted.safetensors",
36
+ "https://huggingface.co/wangfuyun/PCM_Weights/blob/main/sdxl/pcm_sdxl_normalcfg_16step_converted.safetensors",
37
+ "https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-1step-lora.safetensors",
38
+ "https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-2steps-lora.safetensors",
39
+ "https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-4steps-lora.safetensors",
40
+ "https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-8steps-CFG-lora.safetensors",
41
+ "https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-12steps-CFG-lora.safetensors",
42
+ "https://huggingface.co/latent-consistency/lcm-lora-sdxl/blob/main/pytorch_lora_weights.safetensors",
43
+ ]
44
+
45
+
46
+ sdxl_preset_dict = {
47
+ "Default": [DEFAULT_DTYPE, "", "Euler a", "", 1.0, "", 1.0, "", 1.0, "", 1.0, "", 1.0],
48
+ "Bake in standard VAE": [DEFAULT_DTYPE, "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl.vae.safetensors",
49
+ "Euler a", "", "", 1.0, "", 1.0, "", 1.0, "", 1.0, "", 1.0],
50
+ "Hyper-SDXL / SPO": [DEFAULT_DTYPE, "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl.vae.safetensors",
51
+ "TCD", "https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-8steps-CFG-lora.safetensors", 1.0,
52
+ "https://huggingface.co/SPO-Diffusion-Models/SPO-SDXL_4k-p_10ep_LoRA/blob/main/spo_sdxl_10ep_4k-data_lora_diffusers.safetensors",
53
+ 1.0, "", 1.0, "", 1.0, "", 1.0],
54
+ }
55
+
56
+
57
+ def sdxl_set_presets(preset: str="Default"):
58
+ p = []
59
+ if preset in sdxl_preset_dict.keys(): p = sdxl_preset_dict[preset]
60
+ else: p = sdxl_preset_dict["Default"]
61
+ return p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], p[8], p[9], p[10], p[11], p[12]
62
+
63
+
64
+ sd15_vaes = [
65
+ "",
66
+ "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt",
67
+ "https://huggingface.co/stabilityai/sd-vae-ft-ema-original/resolve/main/vae-ft-ema-560000-ema-pruned.ckpt",
68
+ ]
69
+
70
+
71
+ sd15_loras = [
72
+ "",
73
+ "https://huggingface.co/SPO-Diffusion-Models/SPO-SD-v1-5_4k-p_10ep_LoRA/blob/main/spo-sd-v1-5_4k-p_10ep_lora_diffusers.safetensors",
74
+ ]
75
+
76
+
77
+ sd15_preset_dict = {
78
+ "Default": [DEFAULT_DTYPE, "", "Euler", "", 1.0, "", 1.0, "", 1.0, "", 1.0, "", 1.0, True],
79
+ "Bake in standard VAE": [DEFAULT_DTYPE, "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt",
80
+ "Euler", "", 1.0, "", 1.0, "", 1.0, "", 1.0, "", 1.0, True],
81
+ }
82
+
83
+
84
+ def sd15_set_presets(preset: str="Default"):
85
+ p = []
86
+ if preset in sd15_preset_dict.keys(): p = sd15_preset_dict[preset]
87
+ else: p = sd15_preset_dict["Default"]
88
+ return p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], p[8], p[9], p[10], p[11], p[12], p[13]
89
+
90
+
91
+ flux_vaes = [
92
+ "",
93
+ ]
94
+
95
+
96
+ flux_loras = [
97
+ "",
98
+ ]
99
+
100
+
101
+ flux_preset_dict = {
102
+ "dev": ["bf16", "", "", "", 1.0, "", 1.0, "", 1.0, "", 1.0, "", 1.0, "camenduru/FLUX.1-dev-diffusers"],
103
+ "schnell": ["bf16", "", "", "", 1.0, "", 1.0, "", 1.0, "", 1.0, "", 1.0, "black-forest-labs/FLUX.1-schnell"],
104
+ }
105
+
106
+
107
+ def flux_set_presets(preset: str="dev"):
108
+ p = []
109
+ if preset in flux_preset_dict.keys(): p = flux_preset_dict[preset]
110
+ else: p = flux_preset_dict["dev"]
111
+ return p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], p[8], p[9], p[10], p[11], p[12], p[13]
112
+
113
+
114
+
115
+ sd35_vaes = [
116
+ "",
117
+ ]
118
+
119
+
120
+ sd35_loras = [
121
+ "",
122
+ ]
123
+
124
+
125
+ sd35_preset_dict = {
126
+ "Default": ["bf16", "", "", "", 1.0, "", 1.0, "", 1.0, "", 1.0, "", 1.0, "adamo1139/stable-diffusion-3.5-large-ungated"],
127
+ }
128
+
129
+
130
+ def sd35_set_presets(preset: str="dev"):
131
+ p = []
132
+ if preset in sd35_preset_dict.keys(): p = sd35_preset_dict[preset]
133
+ else: p = sd35_preset_dict["Default"]
134
+ return p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], p[8], p[9], p[10], p[11], p[12], p[13]
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ huggingface_hub
2
+ safetensors
3
+ transformers==4.44.0
4
+ diffusers==0.30.3
5
+ peft
6
+ sentencepiece
7
+ torch
8
+ pytorch_lightning
9
+ gdown
10
+ bitsandbytes
11
+ accelerate
sdutils.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from pathlib import Path
3
+ from utils import get_download_file
4
+ from stkey import read_safetensors_key
5
+ try:
6
+ from diffusers import BitsAndBytesConfig
7
+ is_nf4 = True
8
+ except Exception:
9
+ is_nf4 = False
10
+
11
+
12
+ DTYPE_DEFAULT = "default"
13
+ DTYPE_DICT = {
14
+ "fp16": torch.float16,
15
+ "bf16": torch.bfloat16,
16
+ "fp32": torch.float32,
17
+ "fp8": torch.float8_e4m3fn,
18
+ }
19
+ #QTYPES = ["NF4"] if is_nf4 else []
20
+ QTYPES = []
21
+
22
+ def get_dtypes():
23
+ return list(DTYPE_DICT.keys()) + [DTYPE_DEFAULT] + QTYPES
24
+
25
+
26
+ def get_dtype(dtype: str):
27
+ if dtype in set(QTYPES): return torch.bfloat16
28
+ return DTYPE_DICT.get(dtype, torch.float16)
29
+
30
+
31
+ from diffusers import (
32
+ DPMSolverMultistepScheduler,
33
+ DPMSolverSinglestepScheduler,
34
+ KDPM2DiscreteScheduler,
35
+ EulerDiscreteScheduler,
36
+ EulerAncestralDiscreteScheduler,
37
+ HeunDiscreteScheduler,
38
+ LMSDiscreteScheduler,
39
+ DDIMScheduler,
40
+ DEISMultistepScheduler,
41
+ UniPCMultistepScheduler,
42
+ LCMScheduler,
43
+ PNDMScheduler,
44
+ KDPM2AncestralDiscreteScheduler,
45
+ DPMSolverSDEScheduler,
46
+ EDMDPMSolverMultistepScheduler,
47
+ DDPMScheduler,
48
+ EDMEulerScheduler,
49
+ TCDScheduler,
50
+ )
51
+
52
+
53
+ SCHEDULER_CONFIG_MAP = {
54
+ "DPM++ 2M": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False}),
55
+ "DPM++ 2M Karras": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
56
+ "DPM++ 2M SDE": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "algorithm_type": "sde-dpmsolver++"}),
57
+ "DPM++ 2M SDE Karras": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++"}),
58
+ "DPM++ 2S": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False}),
59
+ "DPM++ 2S Karras": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}),
60
+ "DPM++ 1S": (DPMSolverMultistepScheduler, {"solver_order": 1}),
61
+ "DPM++ 1S Karras": (DPMSolverMultistepScheduler, {"solver_order": 1, "use_karras_sigmas": True}),
62
+ "DPM++ 3M": (DPMSolverMultistepScheduler, {"solver_order": 3}),
63
+ "DPM++ 3M Karras": (DPMSolverMultistepScheduler, {"solver_order": 3, "use_karras_sigmas": True}),
64
+ "DPM++ SDE": (DPMSolverSDEScheduler, {"use_karras_sigmas": False}),
65
+ "DPM++ SDE Karras": (DPMSolverSDEScheduler, {"use_karras_sigmas": True}),
66
+ "DPM2": (KDPM2DiscreteScheduler, {}),
67
+ "DPM2 Karras": (KDPM2DiscreteScheduler, {"use_karras_sigmas": True}),
68
+ "DPM2 a": (KDPM2AncestralDiscreteScheduler, {}),
69
+ "DPM2 a Karras": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": True}),
70
+ "Euler": (EulerDiscreteScheduler, {}),
71
+ "Euler a": (EulerAncestralDiscreteScheduler, {}),
72
+ "Euler trailing": (EulerDiscreteScheduler, {"timestep_spacing": "trailing", "prediction_type": "sample"}),
73
+ "Euler a trailing": (EulerAncestralDiscreteScheduler, {"timestep_spacing": "trailing"}),
74
+ "Heun": (HeunDiscreteScheduler, {}),
75
+ "Heun Karras": (HeunDiscreteScheduler, {"use_karras_sigmas": True}),
76
+ "LMS": (LMSDiscreteScheduler, {}),
77
+ "LMS Karras": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
78
+ "DDIM": (DDIMScheduler, {}),
79
+ "DDIM trailing": (DDIMScheduler, {"timestep_spacing": "trailing"}),
80
+ "DEIS": (DEISMultistepScheduler, {}),
81
+ "UniPC": (UniPCMultistepScheduler, {}),
82
+ "UniPC Karras": (UniPCMultistepScheduler, {"use_karras_sigmas": True}),
83
+ "PNDM": (PNDMScheduler, {}),
84
+ "Euler EDM": (EDMEulerScheduler, {}),
85
+ "Euler EDM Karras": (EDMEulerScheduler, {"use_karras_sigmas": True}),
86
+ "DPM++ 2M EDM": (EDMDPMSolverMultistepScheduler, {"solver_order": 2, "solver_type": "midpoint", "final_sigmas_type": "zero", "algorithm_type": "dpmsolver++"}),
87
+ "DPM++ 2M EDM Karras": (EDMDPMSolverMultistepScheduler, {"use_karras_sigmas": True, "solver_order": 2, "solver_type": "midpoint", "final_sigmas_type": "zero", "algorithm_type": "dpmsolver++"}),
88
+ "DDPM": (DDPMScheduler, {}),
89
+
90
+ "DPM++ 2M Lu": (DPMSolverMultistepScheduler, {"use_lu_lambdas": True}),
91
+ "DPM++ 2M Ef": (DPMSolverMultistepScheduler, {"euler_at_final": True}),
92
+ "DPM++ 2M SDE Lu": (DPMSolverMultistepScheduler, {"use_lu_lambdas": True, "algorithm_type": "sde-dpmsolver++"}),
93
+ "DPM++ 2M SDE Ef": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "euler_at_final": True}),
94
+
95
+ "LCM": (LCMScheduler, {}),
96
+ "TCD": (TCDScheduler, {}),
97
+ "LCM trailing": (LCMScheduler, {"timestep_spacing": "trailing"}),
98
+ "TCD trailing": (TCDScheduler, {"timestep_spacing": "trailing"}),
99
+ "LCM Auto-Loader": (LCMScheduler, {}),
100
+ "TCD Auto-Loader": (TCDScheduler, {}),
101
+ }
102
+
103
+
104
+ def get_scheduler_config(name: str):
105
+ if not name in SCHEDULER_CONFIG_MAP.keys(): return SCHEDULER_CONFIG_MAP["Euler a"]
106
+ return SCHEDULER_CONFIG_MAP[name]
107
+
108
+
109
+ def fuse_loras(pipe, lora_dict: dict, temp_dir: str, civitai_key: str="", dkwargs: dict={}):
110
+ if not lora_dict or not isinstance(lora_dict, dict): return pipe
111
+ a_list = []
112
+ w_list = []
113
+ for k, v in lora_dict.items():
114
+ if not k: continue
115
+ new_lora_file = get_download_file(temp_dir, k, civitai_key)
116
+ if not new_lora_file or not Path(new_lora_file).exists():
117
+ print(f"LoRA file not found: {k}")
118
+ continue
119
+ w_name = Path(new_lora_file).name
120
+ a_name = Path(new_lora_file).stem
121
+ pipe.load_lora_weights(new_lora_file, weight_name=w_name, adapter_name=a_name, low_cpu_mem_usage=False, **dkwargs)
122
+ a_list.append(a_name)
123
+ w_list.append(v)
124
+ if Path(new_lora_file).exists(): Path(new_lora_file).unlink()
125
+ if len(a_list) == 0: return pipe
126
+ pipe.set_adapters(a_list, adapter_weights=w_list)
127
+ pipe.fuse_lora(adapter_names=a_list, lora_scale=1.0)
128
+ pipe.unload_lora_weights()
129
+ return pipe
130
+
131
+
132
+ MODEL_TYPE_KEY = {
133
+ "model.diffusion_model.output_blocks.1.1.norm.bias": "SDXL",
134
+ "model.diffusion_model.input_blocks.11.0.out_layers.3.weight": "SD 1.5",
135
+ "double_blocks.0.img_attn.norm.key_norm.scale": "FLUX",
136
+ "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale": "FLUX",
137
+ "model.diffusion_model.joint_blocks.9.x_block.attn.ln_k.weight": "SD 3.5",
138
+ }
139
+
140
+
141
+ def get_model_type_from_key(path: str):
142
+ default = "SDXL"
143
+ try:
144
+ keys = read_safetensors_key(path)
145
+ for k, v in MODEL_TYPE_KEY.items():
146
+ if k in set(keys):
147
+ print(f"Model type is {v}.")
148
+ return v
149
+ print("Model type could not be identified.")
150
+ except Exception:
151
+ return default
152
+ return default
153
+
154
+
155
+ def get_process_dtype(dtype: str, model_type: str):
156
+ if dtype in set(["fp8"] + QTYPES): return torch.bfloat16 if model_type in ["FLUX", "SD 3.5"] else torch.float16
157
+ return DTYPE_DICT.get(dtype, torch.float16)
stkey.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+ import json
4
+ import re
5
+ import gc
6
+ from safetensors.torch import load_file, save_file
7
+ import torch
8
+
9
+
10
+ SDXL_KEYS_FILE = "keys/sdxl_keys.txt"
11
+
12
+
13
+ def list_uniq(l):
14
+ return sorted(set(l), key=l.index)
15
+
16
+
17
+ def read_safetensors_metadata(path: str):
18
+ with open(path, 'rb') as f:
19
+ header_size = int.from_bytes(f.read(8), 'little')
20
+ header_json = f.read(header_size).decode('utf-8')
21
+ header = json.loads(header_json)
22
+ metadata = header.get('__metadata__', {})
23
+ return metadata
24
+
25
+
26
+ def keys_from_file(path: str):
27
+ keys = []
28
+ try:
29
+ with open(str(Path(path)), encoding='utf-8', mode='r') as f:
30
+ lines = f.readlines()
31
+ for line in lines:
32
+ keys.append(line.strip())
33
+ except Exception as e:
34
+ print(e)
35
+ finally:
36
+ return keys
37
+
38
+
39
+ def validate_keys(keys: list[str], rfile: str=SDXL_KEYS_FILE):
40
+ missing = []
41
+ added = []
42
+ try:
43
+ rkeys = keys_from_file(rfile)
44
+ all_keys = list_uniq(keys + rkeys)
45
+ for key in all_keys:
46
+ if key in set(rkeys) and key not in set(keys): missing.append(key)
47
+ if key in set(keys) and key not in set(rkeys): added.append(key)
48
+ except Exception as e:
49
+ print(e)
50
+ finally:
51
+ return missing, added
52
+
53
+
54
+ def read_safetensors_key(path: str):
55
+ try:
56
+ keys = []
57
+ state_dict = load_file(str(Path(path)))
58
+ for k in list(state_dict.keys()):
59
+ keys.append(k)
60
+ state_dict.pop(k)
61
+ except Exception as e:
62
+ print(e)
63
+ finally:
64
+ del state_dict
65
+ torch.cuda.empty_cache()
66
+ gc.collect()
67
+ return keys
68
+
69
+
70
+ def write_safetensors_key(keys: list[str], path: str, is_validate: bool=True, rpath: str=SDXL_KEYS_FILE):
71
+ if len(keys) == 0: return False
72
+ try:
73
+ with open(str(Path(path)), encoding='utf-8', mode='w') as f:
74
+ f.write("\n".join(keys))
75
+ if is_validate:
76
+ missing, added = validate_keys(keys, rpath)
77
+ with open(str(Path(path).stem + "_missing.txt"), encoding='utf-8', mode='w') as f:
78
+ f.write("\n".join(missing))
79
+ with open(str(Path(path).stem + "_added.txt"), encoding='utf-8', mode='w') as f:
80
+ f.write("\n".join(added))
81
+ return True
82
+ except Exception as e:
83
+ print(e)
84
+ return False
85
+
86
+
87
+ def stkey(input: str, out_filename: str="", is_validate: bool=True, rfile: str=SDXL_KEYS_FILE):
88
+ keys = read_safetensors_key(input)
89
+ if len(keys) != 0 and out_filename: write_safetensors_key(keys, out_filename, is_validate, rfile)
90
+ if len(keys) != 0:
91
+ print("Metadata:")
92
+ print(read_safetensors_metadata(input))
93
+ print("\nKeys:")
94
+ print("\n".join(keys))
95
+ if is_validate:
96
+ missing, added = validate_keys(keys, rfile)
97
+ print("\nMissing Keys:")
98
+ print("\n".join(missing))
99
+ print("\nAdded Keys:")
100
+ print("\n".join(added))
101
+
102
+
103
+ if __name__ == "__main__":
104
+ parser = argparse.ArgumentParser()
105
+ parser.add_argument("input", type=str, help="Input safetensors file.")
106
+ parser.add_argument("-s", "--save", action="store_true", default=False, help="Output to text file.")
107
+ parser.add_argument("-o", "--output", default="", type=str, help="Output to specific text file.")
108
+ parser.add_argument("-v", "--val", action="store_false", default=True, help="Disable key validation.")
109
+ parser.add_argument("-r", "--rfile", default=SDXL_KEYS_FILE, type=str, help="Specify reference file to validate keys.")
110
+
111
+ args = parser.parse_args()
112
+
113
+ if args.save: out_filename = Path(args.input).stem + ".txt"
114
+ out_filename = args.output if args.output else out_filename
115
+
116
+ stkey(args.input, out_filename, args.val, args.rfile)
117
+
118
+
119
+ # Usage:
120
+ # python stkey.py sd_xl_base_1.0_0.9vae.safetensors
121
+ # python stkey.py sd_xl_base_1.0_0.9vae.safetensors -s
122
+ # python stkey.py sd_xl_base_1.0_0.9vae.safetensors -o key.txt
utils.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import HfApi, HfFolder, hf_hub_download, snapshot_download
3
+ import os
4
+ from pathlib import Path
5
+ import shutil
6
+ import gc
7
+ import re
8
+ import urllib.parse
9
+ import subprocess
10
+ import time
11
+
12
+
13
+ def get_token():
14
+ try:
15
+ token = HfFolder.get_token()
16
+ except Exception:
17
+ token = ""
18
+ return token
19
+
20
+
21
+ def set_token(token):
22
+ try:
23
+ HfFolder.save_token(token)
24
+ except Exception:
25
+ print(f"Error: Failed to save token.")
26
+
27
+
28
+ def get_user_agent():
29
+ return 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
30
+
31
+
32
+ def is_repo_exists(repo_id: str, repo_type: str="model"):
33
+ hf_token = get_token()
34
+ api = HfApi(token=hf_token)
35
+ try:
36
+ if api.repo_exists(repo_id=repo_id, repo_type=repo_type, token=hf_token): return True
37
+ else: return False
38
+ except Exception as e:
39
+ print(f"Error: Failed to connect {repo_id} ({repo_type}). {e}")
40
+ return True # for safe
41
+
42
+
43
+ MODEL_TYPE_CLASS = {
44
+ "diffusers:StableDiffusionPipeline": "SD 1.5",
45
+ "diffusers:StableDiffusionXLPipeline": "SDXL",
46
+ "diffusers:FluxPipeline": "FLUX",
47
+ }
48
+
49
+
50
+ def get_model_type(repo_id: str):
51
+ hf_token = get_token()
52
+ api = HfApi(token=hf_token)
53
+ lora_filename = "pytorch_lora_weights.safetensors"
54
+ diffusers_filename = "model_index.json"
55
+ default = "SDXL"
56
+ try:
57
+ if api.file_exists(repo_id=repo_id, filename=lora_filename, token=hf_token): return "LoRA"
58
+ if not api.file_exists(repo_id=repo_id, filename=diffusers_filename, token=hf_token): return "None"
59
+ model = api.model_info(repo_id=repo_id, token=hf_token)
60
+ tags = model.tags
61
+ for tag in tags:
62
+ if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
63
+ except Exception:
64
+ return default
65
+ return default
66
+
67
+
68
+ def list_uniq(l):
69
+ return sorted(set(l), key=l.index)
70
+
71
+
72
+ def list_sub(a, b):
73
+ return [e for e in a if e not in b]
74
+
75
+
76
+ def is_repo_name(s):
77
+ return re.fullmatch(r'^[\w_\-\.]+/[\w_\-\.]+$', s)
78
+
79
+
80
+ def get_hf_url(repo_id: str, repo_type: str="model"):
81
+ if repo_type == "dataset": url = f"https://huggingface.co/datasets/{repo_id}"
82
+ elif repo_type == "space": url = f"https://huggingface.co/spaces/{repo_id}"
83
+ else: url = f"https://huggingface.co/{repo_id}"
84
+ return url
85
+
86
+
87
+ def split_hf_url(url: str):
88
+ try:
89
+ s = list(re.findall(r'^(?:https?://huggingface.co/)(?:(datasets|spaces)/)?(.+?/.+?)/\w+?/.+?/(?:(.+)/)?(.+?.\w+)(?:\?download=true)?$', url)[0])
90
+ if len(s) < 4: return "", "", "", ""
91
+ repo_id = s[1]
92
+ if s[0] == "datasets": repo_type = "dataset"
93
+ elif s[0] == "spaces": repo_type = "space"
94
+ else: repo_type = "model"
95
+ subfolder = urllib.parse.unquote(s[2]) if s[2] else None
96
+ filename = urllib.parse.unquote(s[3])
97
+ return repo_id, filename, subfolder, repo_type
98
+ except Exception as e:
99
+ print(e)
100
+
101
+
102
+ def download_hf_file(directory, url, progress=gr.Progress(track_tqdm=True)):
103
+ hf_token = get_token()
104
+ repo_id, filename, subfolder, repo_type = split_hf_url(url)
105
+ try:
106
+ print(f"Downloading {url} to {directory}")
107
+ if subfolder is not None: path = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type, local_dir=directory, token=hf_token)
108
+ else: path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type, local_dir=directory, token=hf_token)
109
+ return path
110
+ except Exception as e:
111
+ print(f"Failed to download: {e}")
112
+ return None
113
+
114
+
115
+ def download_thing(directory, url, civitai_api_key="", progress=gr.Progress(track_tqdm=True)): # requires aria2, gdown
116
+ url = url.strip()
117
+ if "drive.google.com" in url:
118
+ original_dir = os.getcwd()
119
+ os.chdir(directory)
120
+ os.system(f"gdown --fuzzy {url}")
121
+ os.chdir(original_dir)
122
+ elif "huggingface.co" in url:
123
+ url = url.replace("?download=true", "")
124
+ if "/blob/" in url: url = url.replace("/blob/", "/resolve/")
125
+ download_hf_file(directory, url)
126
+ elif "civitai.com" in url:
127
+ if "?" in url:
128
+ url = url.split("?")[0]
129
+ if civitai_api_key:
130
+ url = url + f"?token={civitai_api_key}"
131
+ os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
132
+ else:
133
+ print("You need an API key to download Civitai models.")
134
+ else:
135
+ os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
136
+
137
+
138
+ def get_local_file_list(dir_path):
139
+ file_list = []
140
+ for file in Path(dir_path).glob("**/*.*"):
141
+ if file.is_file():
142
+ file_path = str(file)
143
+ file_list.append(file_path)
144
+ return file_list
145
+
146
+
147
+ def get_download_file(temp_dir, url, civitai_key, progress=gr.Progress(track_tqdm=True)):
148
+ if not "http" in url and is_repo_name(url) and not Path(url).exists():
149
+ print(f"Use HF Repo: {url}")
150
+ new_file = url
151
+ elif not "http" in url and Path(url).exists():
152
+ print(f"Use local file: {url}")
153
+ new_file = url
154
+ elif Path(f"{temp_dir}/{url.split('/')[-1]}").exists():
155
+ print(f"File to download alreday exists: {url}")
156
+ new_file = f"{temp_dir}/{url.split('/')[-1]}"
157
+ else:
158
+ print(f"Start downloading: {url}")
159
+ before = get_local_file_list(temp_dir)
160
+ try:
161
+ download_thing(temp_dir, url.strip(), civitai_key)
162
+ except Exception:
163
+ print(f"Download failed: {url}")
164
+ return ""
165
+ after = get_local_file_list(temp_dir)
166
+ new_file = list_sub(after, before)[0] if list_sub(after, before) else ""
167
+ if not new_file:
168
+ print(f"Download failed: {url}")
169
+ return ""
170
+ print(f"Download completed: {url}")
171
+ return new_file
172
+
173
+
174
+ def download_repo(repo_id: str, dir_path: str, progress=gr.Progress(track_tqdm=True)): # for diffusers repo
175
+ hf_token = get_token()
176
+ try:
177
+ snapshot_download(repo_id=repo_id, local_dir=dir_path, token=hf_token, allow_patterns=["*.safetensors", "*.bin"],
178
+ ignore_patterns=["*.fp16.*", "/*.safetensors", "/*.bin"], force_download=True)
179
+ return True
180
+ except Exception as e:
181
+ print(f"Error: Failed to download {repo_id}. {e}")
182
+ gr.Warning(f"Error: Failed to download {repo_id}. {e}")
183
+ return False
184
+
185
+
186
+ def upload_repo(repo_id: str, dir_path: str, is_private: bool, progress=gr.Progress(track_tqdm=True)): # for diffusers repo
187
+ hf_token = get_token()
188
+ api = HfApi(token=hf_token)
189
+ try:
190
+ progress(0, desc="Start uploading...")
191
+ api.create_repo(repo_id=repo_id, token=hf_token, private=is_private, exist_ok=True)
192
+ for path in Path(dir_path).glob("*"):
193
+ if path.is_dir():
194
+ api.upload_folder(repo_id=repo_id, folder_path=str(path), path_in_repo=path.name, token=hf_token)
195
+ elif path.is_file():
196
+ api.upload_file(repo_id=repo_id, path_or_fileobj=str(path), path_in_repo=path.name, token=hf_token)
197
+ progress(1, desc="Uploaded.")
198
+ return get_hf_url(repo_id, "model")
199
+ except Exception as e:
200
+ print(f"Error: Failed to upload to {repo_id}. {e}")
201
+ return ""
202
+
203
+
204
+ HF_SUBFOLDER_NAME = ["None", "user_repo"]
205
+
206
+
207
+ def duplicate_hf_repo(src_repo: str, dst_repo: str, src_repo_type: str, dst_repo_type: str,
208
+ is_private: bool, subfolder_type: str=HF_SUBFOLDER_NAME[1], progress=gr.Progress(track_tqdm=True)):
209
+ hf_token = get_token()
210
+ api = HfApi(token=hf_token)
211
+ try:
212
+ if subfolder_type == "user_repo": subfolder = src_repo.replace("/", "_")
213
+ else: subfolder = ""
214
+ progress(0, desc="Start duplicating...")
215
+ api.create_repo(repo_id=dst_repo, repo_type=dst_repo_type, private=is_private, exist_ok=True, token=hf_token)
216
+ for path in api.list_repo_files(repo_id=src_repo, repo_type=src_repo_type, token=hf_token):
217
+ file = hf_hub_download(repo_id=src_repo, filename=path, repo_type=src_repo_type, token=hf_token)
218
+ if not Path(file).exists(): continue
219
+ if Path(file).is_dir(): # unused for now
220
+ api.upload_folder(repo_id=dst_repo, folder_path=file, path_in_repo=f"{subfolder}/{path}" if subfolder else path,
221
+ repo_type=dst_repo_type, token=hf_token)
222
+ elif Path(file).is_file():
223
+ api.upload_file(repo_id=dst_repo, path_or_fileobj=file, path_in_repo=f"{subfolder}/{path}" if subfolder else path,
224
+ repo_type=dst_repo_type, token=hf_token)
225
+ if Path(file).exists(): Path(file).unlink()
226
+ progress(1, desc="Duplicated.")
227
+ return f"{get_hf_url(dst_repo, dst_repo_type)}/tree/main/{subfolder}" if subfolder else get_hf_url(dst_repo, dst_repo_type)
228
+ except Exception as e:
229
+ print(f"Error: Failed to duplicate repo {src_repo} to {dst_repo}. {e}")
230
+ return ""
231
+
232
+
233
+ BASE_DIR = str(Path(__file__).resolve().parent.resolve())
234
+ CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
235
+
236
+
237
+ def get_file(url: str, path: str): # requires aria2, gdown
238
+ print(f"Downloading {url} to {path}...")
239
+ get_download_file(path, url, CIVITAI_API_KEY)
240
+
241
+
242
+ def git_clone(url: str, path: str, pip: bool=False, addcmd: str=""): # requires git
243
+ os.makedirs(str(Path(BASE_DIR, path)), exist_ok=True)
244
+ os.chdir(Path(BASE_DIR, path))
245
+ print(f"Cloning {url} to {path}...")
246
+ cmd = f'git clone {url}'
247
+ print(f'Running {cmd} at {Path.cwd()}')
248
+ i = subprocess.run(cmd, shell=True).returncode
249
+ if i != 0: print(f'Error occured at running {cmd}')
250
+ p = url.split("/")[-1]
251
+ if not Path(p).exists: return
252
+ if pip:
253
+ os.chdir(Path(BASE_DIR, path, p))
254
+ cmd = f'pip install -r requirements.txt'
255
+ print(f'Running {cmd} at {Path.cwd()}')
256
+ i = subprocess.run(cmd, shell=True).returncode
257
+ if i != 0: print(f'Error occured at running {cmd}')
258
+ if addcmd:
259
+ os.chdir(Path(BASE_DIR, path, p))
260
+ cmd = addcmd
261
+ print(f'Running {cmd} at {Path.cwd()}')
262
+ i = subprocess.run(cmd, shell=True).returncode
263
+ if i != 0: print(f'Error occured at running {cmd}')
264
+
265
+
266
+ def run(cmd: str, timeout: float=0):
267
+ print(f'Running {cmd} at {Path.cwd()}')
268
+ if timeout == 0:
269
+ i = subprocess.run(cmd, shell=True).returncode
270
+ if i != 0: print(f'Error occured at running {cmd}')
271
+ else:
272
+ p = subprocess.Popen(cmd, shell=True)
273
+ time.sleep(timeout)
274
+ p.terminate()
275
+ print(f'Terminated in {timeout} seconds')