John6666 commited on
Commit
c983a5f
1 Parent(s): 60f961e

Upload 8 files

Browse files
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🐶
4
  colorFrom: yellow
5
  colorTo: red
6
  sdk: gradio
7
- sdk_version: 4.38.1
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: yellow
5
  colorTo: red
6
  sdk: gradio
7
+ sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -1,30 +1,40 @@
1
  import gradio as gr
2
  import os
3
- from convert_repo_to_safetensors_sd_gr import convert_repo_to_safetensors_multi_sd
4
  os.environ['HF_OUTPUT_REPO'] = 'John6666/safetensors_converting_test'
5
 
6
  css = """"""
7
 
8
- with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", css=css) as demo:
9
  gr.Markdown(
10
  f"""
11
  - [A CLI version of this tool is available here](https://huggingface.co/spaces/John6666/convert_repo_to_safetensors_sd/tree/main/local).
12
  """)
13
  with gr.Column():
14
  repo_id = gr.Textbox(label="Repo ID", placeholder="author/model", value="", lines=1)
15
- is_half = gr.Checkbox(label="Half precision", value=True)
16
  is_upload = gr.Checkbox(label="Upload safetensors to HF Repo", info="Fast download, but files will be public.", value=False)
17
- uploaded_urls = gr.CheckboxGroup(visible=False, choices=[], value=None)
 
 
 
 
 
 
 
 
 
18
  run_button = gr.Button(value="Convert")
19
  st_file = gr.Files(label="Output", interactive=False)
20
  st_md = gr.Markdown()
 
21
 
22
  gr.on(
23
  triggers=[repo_id.submit, run_button.click],
24
  fn=convert_repo_to_safetensors_multi_sd,
25
- inputs=[repo_id, st_file, is_upload, uploaded_urls, is_half],
26
  outputs=[st_file, uploaded_urls, st_md],
27
  )
 
28
 
29
  demo.queue()
30
  demo.launch()
 
1
  import gradio as gr
2
  import os
3
+ from convert_repo_to_safetensors_sd_gr import convert_repo_to_safetensors_multi_sd, clear_safetensors
4
  os.environ['HF_OUTPUT_REPO'] = 'John6666/safetensors_converting_test'
5
 
6
  css = """"""
7
 
8
+ with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", fill_width=True, css=css, delete_cache=(60, 3600)) as demo:
9
  gr.Markdown(
10
  f"""
11
  - [A CLI version of this tool is available here](https://huggingface.co/spaces/John6666/convert_repo_to_safetensors_sd/tree/main/local).
12
  """)
13
  with gr.Column():
14
  repo_id = gr.Textbox(label="Repo ID", placeholder="author/model", value="", lines=1)
 
15
  is_upload = gr.Checkbox(label="Upload safetensors to HF Repo", info="Fast download, but files will be public.", value=False)
16
+ with gr.Accordion("Advanced", open=False):
17
+ dtype = gr.Radio(label="Output data type", choices=["fp16", "fp32", "bf16", "default"], value="fp16")
18
+ with gr.Row():
19
+ hf_token = gr.Textbox(label="Your HF write token", placeholder="hf_...", value="", max_lines=1)
20
+ gr.Markdown("Your token is available at [hf.co/settings/tokens](https://huggingface.co/settings/tokens).")
21
+ with gr.Row():
22
+ newrepo_id = gr.Textbox(label="Upload repo ID", placeholder="yourid/newrepo", value="", max_lines=1)
23
+ newrepo_type = gr.Radio(label="Upload repo type", choices=["model", "dataset"], value="model")
24
+ is_private = gr.Checkbox(label="Create / Use private repo", value=True)
25
+ uploaded_urls = gr.CheckboxGroup(visible=False, choices=[], value=None) # hidden
26
  run_button = gr.Button(value="Convert")
27
  st_file = gr.Files(label="Output", interactive=False)
28
  st_md = gr.Markdown()
29
+ delete_button = gr.Button(value="Delete Safetensors")
30
 
31
  gr.on(
32
  triggers=[repo_id.submit, run_button.click],
33
  fn=convert_repo_to_safetensors_multi_sd,
34
+ inputs=[repo_id, hf_token, st_file, uploaded_urls, dtype, is_upload, newrepo_id, newrepo_type, is_private],
35
  outputs=[st_file, uploaded_urls, st_md],
36
  )
37
+ delete_button.click(clear_safetensors, None, [st_file], queue=False, show_api=False)
38
 
39
  demo.queue()
40
  demo.launch()
convert_repo_to_safetensors_sd.py CHANGED
@@ -281,7 +281,7 @@ def convert_text_enc_state_dict(text_enc_dict):
281
  return text_enc_dict
282
 
283
 
284
- def convert_diffusers_to_safetensors(model_path, checkpoint_path, half = True):
285
  # Path for safetensors
286
  unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
287
  vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
@@ -328,8 +328,10 @@ def convert_diffusers_to_safetensors(model_path, checkpoint_path, half = True):
328
 
329
  # Put together new checkpoint
330
  state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
331
- if half:
332
- state_dict = {k: v.half() for k, v in state_dict.items()}
 
 
333
 
334
  save_file(state_dict, checkpoint_path)
335
 
@@ -343,11 +345,11 @@ def download_repo(repo_id, dir_path):
343
  return
344
 
345
 
346
- def convert_repo_to_safetensors(repo_id, half = True):
347
  download_dir = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}"
348
  output_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}.safetensors"
349
  download_repo(repo_id, download_dir)
350
- convert_diffusers_to_safetensors(download_dir, output_filename, half)
351
  return output_filename
352
 
353
 
@@ -355,9 +357,9 @@ if __name__ == "__main__":
355
  parser = argparse.ArgumentParser()
356
 
357
  parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
358
- parser.add_argument("--half", default=True, help="Save weights in half precision.")
359
 
360
  args = parser.parse_args()
361
  assert args.repo_id is not None, "Must provide a Repo ID!"
362
 
363
- convert_repo_to_safetensors(args.repo_id, args.half)
 
281
  return text_enc_dict
282
 
283
 
284
+ def convert_diffusers_to_safetensors(model_path, checkpoint_path, dtype="fp16"):
285
  # Path for safetensors
286
  unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
287
  vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
 
328
 
329
  # Put together new checkpoint
330
  state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
331
+
332
+ if dtype == "fp16": state_dict = {k: v.half() for k, v in state_dict.items()}
333
+ elif dtype == "fp32": state_dict = {k: v.to(torch.float32) for k, v in state_dict.items()}
334
+ elif dtype == "bf16": state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
335
 
336
  save_file(state_dict, checkpoint_path)
337
 
 
345
  return
346
 
347
 
348
+ def convert_repo_to_safetensors(repo_id, dtype="fp16"):
349
  download_dir = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}"
350
  output_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}.safetensors"
351
  download_repo(repo_id, download_dir)
352
+ convert_diffusers_to_safetensors(download_dir, output_filename, dtype)
353
  return output_filename
354
 
355
 
 
357
  parser = argparse.ArgumentParser()
358
 
359
  parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
360
+ parser.add_argument("--dtype", default="fp16", type=str, choices=["fp16", "fp32", "bf16", "default"], help='Output data type. (Default: "fp16")')
361
 
362
  args = parser.parse_args()
363
  assert args.repo_id is not None, "Must provide a Repo ID!"
364
 
365
+ convert_repo_to_safetensors(args.repo_id, args.dtype)
convert_repo_to_safetensors_sd_gr.py CHANGED
@@ -10,6 +10,12 @@ import torch
10
  from safetensors.torch import load_file, save_file
11
  import gradio as gr
12
 
 
 
 
 
 
 
13
 
14
  # =================#
15
  # UNet Conversion #
@@ -282,8 +288,7 @@ def convert_text_enc_state_dict(text_enc_dict):
282
  return text_enc_dict
283
 
284
 
285
- def convert_diffusers_to_safetensors(model_path, checkpoint_path, half = True, progress=gr.Progress(track_tqdm=True)):
286
- progress(0, desc="Start converting...")
287
  # Path for safetensors
288
  unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
289
  vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
@@ -330,57 +335,68 @@ def convert_diffusers_to_safetensors(model_path, checkpoint_path, half = True, p
330
 
331
  # Put together new checkpoint
332
  state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
333
- if half:
334
- state_dict = {k: v.half() for k, v in state_dict.items()}
 
 
335
 
336
  save_file(state_dict, checkpoint_path)
337
 
338
- progress(1, desc="Converted.")
339
-
340
 
341
- def download_repo(repo_id, dir_path, progress=gr.Progress(track_tqdm=True)):
342
- from huggingface_hub import snapshot_download
 
343
  try:
344
- snapshot_download(repo_id=repo_id, local_dir=dir_path)
 
345
  except Exception as e:
346
- print(f"Error: Failed to download {repo_id}. ")
 
347
  return
348
 
349
 
350
- def upload_safetensors_to_repo(filename, progress=gr.Progress(track_tqdm=True)):
351
- from huggingface_hub import HfApi, hf_hub_url
352
- import os
353
- from pathlib import Path
354
  output_filename = Path(filename).name
355
- hf_token = os.environ.get("HF_TOKEN")
356
- repo_id = os.environ.get("HF_OUTPUT_REPO")
357
- api = HfApi()
358
  try:
 
359
  progress(0, desc="Start uploading...")
360
- api.upload_file(path_or_fileobj=filename, path_in_repo=output_filename, repo_id=repo_id, token=hf_token)
361
  progress(1, desc="Uploaded.")
362
- url = hf_hub_url(repo_id=repo_id, filename=output_filename)
363
  except Exception as e:
364
- print(f"Error: Failed to upload to {repo_id}. ")
 
365
  return None
366
  return url
367
 
368
 
369
- def convert_repo_to_safetensors(repo_id, half = True, progress=gr.Progress(track_tqdm=True)):
370
  download_dir = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}"
371
  output_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}.safetensors"
 
372
  download_repo(repo_id, download_dir)
373
- convert_diffusers_to_safetensors(download_dir, output_filename, half)
 
 
 
374
  return output_filename
375
 
376
 
377
- def convert_repo_to_safetensors_multi_sd(repo_id, files, is_upload, urls, half=True, progress=gr.Progress(track_tqdm=True)):
378
- file = convert_repo_to_safetensors(repo_id, half)
 
 
 
 
 
379
  if not urls: urls = []
380
  url = ""
381
  if is_upload:
382
- url = upload_safetensors_to_repo(file)
383
  if url: urls.append(url)
 
384
  md = ""
385
  for u in urls:
386
  md += f"[Download {str(u).split('/')[-1]}]({str(u)})<br>"
@@ -389,13 +405,21 @@ def convert_repo_to_safetensors_multi_sd(repo_id, files, is_upload, urls, half=T
389
  return gr.update(value=files), gr.update(value=urls, choices=urls), gr.update(value=md)
390
 
391
 
 
 
 
 
 
 
 
 
392
  if __name__ == "__main__":
393
  parser = argparse.ArgumentParser()
394
 
395
  parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
396
- parser.add_argument("--half", default=True, help="Save weights in half precision.")
397
 
398
  args = parser.parse_args()
399
  assert args.repo_id is not None, "Must provide a Repo ID!"
400
 
401
- convert_repo_to_safetensors(args.repo_id, args.half)
 
10
  from safetensors.torch import load_file, save_file
11
  import gradio as gr
12
 
13
+ from huggingface_hub import HfApi, HfFolder, hf_hub_url, snapshot_download
14
+ import os
15
+ from pathlib import Path
16
+ import shutil
17
+ import gc
18
+ from utils import get_token, set_token, is_repo_exists
19
 
20
  # =================#
21
  # UNet Conversion #
 
288
  return text_enc_dict
289
 
290
 
291
+ def convert_diffusers_to_safetensors(model_path, checkpoint_path, dtype="fp16", progress=gr.Progress(track_tqdm=True)):
 
292
  # Path for safetensors
293
  unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
294
  vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
 
335
 
336
  # Put together new checkpoint
337
  state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
338
+
339
+ if dtype == "fp16": state_dict = {k: v.half() for k, v in state_dict.items()}
340
+ elif dtype == "fp32": state_dict = {k: v.to(torch.float32) for k, v in state_dict.items()}
341
+ elif dtype == "bf16": state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
342
 
343
  save_file(state_dict, checkpoint_path)
344
 
 
 
345
 
346
+ # https://huggingface.co/docs/huggingface_hub/v0.25.1/en/package_reference/file_download#huggingface_hub.snapshot_download
347
+ def download_repo(repo_id, dir_path):
348
+ hf_token = get_token()
349
  try:
350
+ snapshot_download(repo_id=repo_id, local_dir=dir_path, token=hf_token, allow_patterns=["*.safetensors", "*.bin"],
351
+ ignore_patterns=["*.fp16.*", "/*.safetensors", "/*.bin"])
352
  except Exception as e:
353
+ print(f"Error: Failed to download {repo_id}. {e}")
354
+ gr.Warning(f"Error: Failed to download {repo_id}. {e}")
355
  return
356
 
357
 
358
+ def upload_safetensors_to_repo(filename, repo_id, repo_type, is_private, progress=gr.Progress(track_tqdm=True)):
 
 
 
359
  output_filename = Path(filename).name
360
+ hf_token = get_token()
361
+ api = HfApi(token=hf_token)
 
362
  try:
363
+ if not is_repo_exists(repo_id, repo_type): api.create_repo(repo_id=repo_id, repo_type=repo_type, token=hf_token, private=is_private)
364
  progress(0, desc="Start uploading...")
365
+ api.upload_file(path_or_fileobj=filename, path_in_repo=output_filename, repo_type=repo_type, revision="main", token=hf_token, repo_id=repo_id)
366
  progress(1, desc="Uploaded.")
367
+ url = hf_hub_url(repo_id=repo_id, repo_type=repo_type, filename=output_filename)
368
  except Exception as e:
369
+ print(f"Error: Failed to upload to {repo_id}. {e}")
370
+ gr.Warning(f"Error: Failed to upload to {repo_id}. {e}")
371
  return None
372
  return url
373
 
374
 
375
+ def convert_repo_to_safetensors(repo_id, dtype="fp16", progress=gr.Progress(track_tqdm=True)):
376
  download_dir = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}"
377
  output_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}.safetensors"
378
+ progress(0, desc="Start downloading...")
379
  download_repo(repo_id, download_dir)
380
+ progress(0, desc="Start converting...")
381
+ convert_diffusers_to_safetensors(download_dir, output_filename, dtype)
382
+ progress(1, desc="Converted.")
383
+ shutil.rmtree(download_dir)
384
  return output_filename
385
 
386
 
387
+ def convert_repo_to_safetensors_multi_sd(repo_id, hf_token, files, urls, dtype="fp16", is_upload=False,
388
+ newrepo_id="", repo_type="model", is_private=True, progress=gr.Progress(track_tqdm=True)):
389
+ if hf_token: set_token(hf_token)
390
+ else: set_token(os.environ.get("HF_TOKEN"))
391
+ if is_upload and newrepo_id and not hf_token: raise gr.Error("HF write token is required for this process.")
392
+ if not newrepo_id: newrepo_id = os.environ.get("HF_OUTPUT_REPO")
393
+ file = convert_repo_to_safetensors(repo_id, dtype)
394
  if not urls: urls = []
395
  url = ""
396
  if is_upload:
397
+ url = upload_safetensors_to_repo(file, newrepo_id, repo_type, is_private)
398
  if url: urls.append(url)
399
+ progress(1, desc="Processing...")
400
  md = ""
401
  for u in urls:
402
  md += f"[Download {str(u).split('/')[-1]}]({str(u)})<br>"
 
405
  return gr.update(value=files), gr.update(value=urls, choices=urls), gr.update(value=md)
406
 
407
 
408
+ def clear_safetensors():
409
+ for p in Path('.').glob('*.safetensors'):
410
+ p.unlink()
411
+ print("Deleted.")
412
+ gc.collect()
413
+ return gr.update(value=[])
414
+
415
+
416
  if __name__ == "__main__":
417
  parser = argparse.ArgumentParser()
418
 
419
  parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
420
+ parser.add_argument("--dtype", default="fp16", type=str, choices=["fp16", "fp32", "bf16", "default"], help='Output data type. (Default: "fp16")')
421
 
422
  args = parser.parse_args()
423
  assert args.repo_id is not None, "Must provide a Repo ID!"
424
 
425
+ convert_repo_to_safetensors(args.repo_id, args.dtype)
utils.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import HfApi, HfFolder, hf_hub_download
3
+ import os
4
+ from pathlib import Path
5
+ import shutil
6
+ import gc
7
+ import re
8
+ import urllib.parse
9
+
10
+
11
+ def get_token():
12
+ try:
13
+ token = HfFolder.get_token()
14
+ except Exception:
15
+ token = ""
16
+ return token
17
+
18
+
19
+ def set_token(token):
20
+ try:
21
+ HfFolder.save_token(token)
22
+ except Exception:
23
+ print(f"Error: Failed to save token.")
24
+
25
+
26
+ def get_user_agent():
27
+ return 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
28
+
29
+
30
+ def is_repo_exists(repo_id: str, repo_type: str="model"):
31
+ hf_token = get_token()
32
+ api = HfApi(token=hf_token)
33
+ try:
34
+ if api.repo_exists(repo_id=repo_id, repo_type=repo_type, token=hf_token): return True
35
+ else: return False
36
+ except Exception as e:
37
+ print(f"Error: Failed to connect {repo_id} ({repo_type}). {e}")
38
+ return True # for safe
39
+
40
+
41
+ MODEL_TYPE_CLASS = {
42
+ "diffusers:StableDiffusionPipeline": "SD 1.5",
43
+ "diffusers:StableDiffusionXLPipeline": "SDXL",
44
+ "diffusers:FluxPipeline": "FLUX",
45
+ }
46
+
47
+
48
+ def get_model_type(repo_id: str):
49
+ hf_token = get_token()
50
+ api = HfApi(token=hf_token)
51
+ lora_filename = "pytorch_lora_weights.safetensors"
52
+ diffusers_filename = "model_index.json"
53
+ default = "SDXL"
54
+ try:
55
+ if api.file_exists(repo_id=repo_id, filename=lora_filename, token=hf_token): return "LoRA"
56
+ if not api.file_exists(repo_id=repo_id, filename=diffusers_filename, token=hf_token): return "None"
57
+ model = api.model_info(repo_id=repo_id, token=hf_token)
58
+ tags = model.tags
59
+ for tag in tags:
60
+ if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
61
+ except Exception:
62
+ return default
63
+ return default
64
+
65
+
66
+ def list_sub(a, b):
67
+ return [e for e in a if e not in b]
68
+
69
+
70
+ def is_repo_name(s):
71
+ return re.fullmatch(r'^[^/,\s\"\']+/[^/,\s\"\']+$', s)
72
+
73
+
74
+ def split_hf_url(url: str):
75
+ try:
76
+ s = list(re.findall(r'^(?:https?://huggingface.co/)(?:(datasets)/)?(.+?/.+?)/\w+?/.+?/(?:(.+)/)?(.+?.safetensors)(?:\?download=true)?$', url)[0])
77
+ if len(s) < 4: return "", "", "", ""
78
+ repo_id = s[1]
79
+ repo_type = "dataset" if s[0] == "datasets" else "model"
80
+ subfolder = urllib.parse.unquote(s[2]) if s[2] else None
81
+ filename = urllib.parse.unquote(s[3])
82
+ return repo_id, filename, subfolder, repo_type
83
+ except Exception as e:
84
+ print(e)
85
+
86
+
87
+ def download_hf_file(directory, url, progress=gr.Progress(track_tqdm=True)):
88
+ hf_token = get_token()
89
+ repo_id, filename, subfolder, repo_type = split_hf_url(url)
90
+ try:
91
+ if subfolder is not None: hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type, local_dir=directory, token=hf_token)
92
+ else: hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type, local_dir=directory, token=hf_token)
93
+ except Exception as e:
94
+ print(f"Failed to download: {e}")
95
+
96
+
97
+ def download_thing(directory, url, civitai_api_key="", progress=gr.Progress(track_tqdm=True)): # requires aria2, gdown
98
+ hf_token = get_token()
99
+ url = url.strip()
100
+ if "drive.google.com" in url:
101
+ original_dir = os.getcwd()
102
+ os.chdir(directory)
103
+ os.system(f"gdown --fuzzy {url}")
104
+ os.chdir(original_dir)
105
+ elif "huggingface.co" in url:
106
+ url = url.replace("?download=true", "")
107
+ if "/blob/" in url:
108
+ url = url.replace("/blob/", "/resolve/")
109
+ #user_header = f'"Authorization: Bearer {hf_token}"'
110
+ if hf_token:
111
+ download_hf_file(directory, url)
112
+ #os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
113
+ else:
114
+ os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
115
+ elif "civitai.com" in url:
116
+ if "?" in url:
117
+ url = url.split("?")[0]
118
+ if civitai_api_key:
119
+ url = url + f"?token={civitai_api_key}"
120
+ os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
121
+ else:
122
+ print("You need an API key to download Civitai models.")
123
+ else:
124
+ os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
125
+
126
+
127
+ def get_local_model_list(dir_path):
128
+ model_list = []
129
+ valid_extensions = ('.safetensors')
130
+ for file in Path(dir_path).glob("**/*.*"):
131
+ if file.is_file() and file.suffix in valid_extensions:
132
+ file_path = str(file)
133
+ model_list.append(file_path)
134
+ return model_list
135
+
136
+
137
+ def get_download_file(temp_dir, url, civitai_key, progress=gr.Progress(track_tqdm=True)):
138
+ if not "http" in url and is_repo_name(url) and not Path(url).exists():
139
+ print(f"Use HF Repo: {url}")
140
+ new_file = url
141
+ elif not "http" in url and Path(url).exists():
142
+ print(f"Use local file: {url}")
143
+ new_file = url
144
+ elif Path(f"{temp_dir}/{url.split('/')[-1]}").exists():
145
+ print(f"File to download alreday exists: {url}")
146
+ new_file = f"{temp_dir}/{url.split('/')[-1]}"
147
+ else:
148
+ print(f"Start downloading: {url}")
149
+ before = get_local_model_list(temp_dir)
150
+ try:
151
+ download_thing(temp_dir, url.strip(), civitai_key)
152
+ except Exception:
153
+ print(f"Download failed: {url}")
154
+ return ""
155
+ after = get_local_model_list(temp_dir)
156
+ new_file = list_sub(after, before)[0] if list_sub(after, before) else ""
157
+ if not new_file:
158
+ print(f"Download failed: {url}")
159
+ return ""
160
+ print(f"Download completed: {url}")
161
+ return new_file