Spaces:
Running
Running
import gc
import os
import re
import shutil
import subprocess
import urllib.parse
from pathlib import Path

import gradio as gr
from huggingface_hub import HfApi, HfFolder, hf_hub_download
def get_token():
    """Return the locally stored Hugging Face token.

    Returns whatever HfFolder.get_token() yields, or "" if that call raises.
    NOTE(review): HfFolder.get_token() presumably returns None when no token
    is saved — callers should treat any falsy value as "no token".
    """
    try:
        return HfFolder.get_token()
    except Exception:
        return ""
def set_token(token):
    """Persist *token* with HfFolder.save_token().

    Any failure is reported on stdout instead of being raised.
    """
    try:
        HfFolder.save_token(token)
    except Exception:
        failure_message = "Error: Failed to save token."
        print(failure_message)
def get_user_agent():
    """Return a fixed desktop Firefox 127 (Windows) User-Agent header value."""
    firefox_ua = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
    return firefox_ua
def is_repo_exists(repo_id: str, repo_type: str="model"):
    """Return whether *repo_id* of kind *repo_type* exists on the Hub.

    Uses the locally stored token. If the Hub query raises, the error is
    printed and True is returned (the original marks this "for safe": an
    unreachable repo is treated as existing).
    """
    hf_token = get_token()
    client = HfApi(token=hf_token)
    try:
        found = client.repo_exists(repo_id=repo_id, repo_type=repo_type, token=hf_token)
    except Exception as e:
        print(f"Error: Failed to connect {repo_id} ({repo_type}). {e}")
        return True  # fail safe: assume it exists
    return bool(found)
# Hub repo tag -> short model-type label; consulted by get_model_type(),
# which returns the label of the first repo tag found in this mapping.
MODEL_TYPE_CLASS = {
    "diffusers:StableDiffusionPipeline": "SD 1.5",
    "diffusers:StableDiffusionXLPipeline": "SDXL",
    "diffusers:FluxPipeline": "FLUX",
}
def get_model_type(repo_id: str):
    """Classify a Hub repo into a coarse model-type label.

    Returns:
        "LoRA" if the repo contains ``pytorch_lora_weights.safetensors``;
        the string "None" if it has no ``model_index.json``;
        otherwise the MODEL_TYPE_CLASS label of the first repo tag that is a
        key of that mapping. Falls back to "SDXL" when no tag matches or
        when any Hub call raises.
    """
    fallback = "SDXL"
    hf_token = get_token()
    client = HfApi(token=hf_token)
    try:
        if client.file_exists(repo_id=repo_id, filename="pytorch_lora_weights.safetensors", token=hf_token):
            return "LoRA"
        if not client.file_exists(repo_id=repo_id, filename="model_index.json", token=hf_token):
            return "None"
        repo_tags = client.model_info(repo_id=repo_id, token=hf_token).tags
        # First tag (in repo order) with a known label wins.
        return next((MODEL_TYPE_CLASS[t] for t in repo_tags if t in MODEL_TYPE_CLASS), fallback)
    except Exception:
        return fallback
def list_uniq(l):
    """Return the distinct items of *l* as a list, in order of first appearance.

    dict preserves insertion order, so dict.fromkeys de-duplicates in a single
    O(n) pass. The previous ``sorted(set(l), key=l.index)`` was O(n^2) and
    only worked on sequences. Items must be hashable.
    """
    return list(dict.fromkeys(l))
def list_sub(a, b):
    """Return the items of *a* that are not in *b*.

    Order and duplicates of *a* are kept; membership uses ``in`` against *b*,
    so unhashable items are supported.
    """
    return list(filter(lambda item: item not in b, a))
# "owner/name": exactly one slash, no commas, whitespace or quotes.
_REPO_NAME_PATTERN = re.compile(r'^[^/,\s\"\']+/[^/,\s\"\']+$')


def is_repo_name(s):
    """Return a match object if *s* looks like a Hub repo id ("owner/name"), else None."""
    return _REPO_NAME_PATTERN.fullmatch(s)
# Hub file URLs, e.g.
#   https://huggingface.co/<owner>/<repo>/resolve/<rev>/<sub/dirs>/<file.ext>
#   https://huggingface.co/datasets/<owner>/<repo>/resolve/<rev>/<file.ext>?download=true
# Groups: 1 = "datasets" or None, 2 = repo id, 3 = subfolder or None, 4 = filename.
_HF_URL_PATTERN = re.compile(
    r'^(?:https?://huggingface\.co/)(?:(datasets)/)?(.+?/.+?)/\w+?/.+?/(?:(.+)/)?(.+?.\w+)(?:\?download=true)?$'
)


def split_hf_url(url: str):
    """Split a Hugging Face file URL into hf_hub_download() arguments.

    Returns:
        ``(repo_id, filename, subfolder, repo_type)``. *filename* and
        *subfolder* are percent-decoded; *subfolder* is None for files at
        the repo root; *repo_type* is "dataset" for ``/datasets/`` URLs and
        "model" otherwise. If *url* is not a recognised Hub file URL, a
        message is printed and ``("", "", "", "")`` is returned, so callers
        can always unpack four values (previously None was returned and the
        caller's tuple unpacking raised TypeError).
    """
    match = _HF_URL_PATTERN.match(url) if isinstance(url, str) else None
    if match is None:
        print(f"Not a recognised Hugging Face file URL: {url}")
        return "", "", "", ""
    kind, repo_id, subfolder, filename = match.groups()
    return (
        repo_id,
        urllib.parse.unquote(filename),
        urllib.parse.unquote(subfolder) if subfolder else None,
        "dataset" if kind == "datasets" else "model",
    )
def download_hf_file(directory, url, progress=gr.Progress(track_tqdm=True)):
    """Download the single file addressed by a Hugging Face URL into *directory*.

    The URL is parsed with split_hf_url(); an unparseable URL is reported and
    skipped. Download errors are printed rather than raised (best effort).
    *progress* is accepted for Gradio progress tracking but is not referenced
    in the body.
    """
    parts = split_hf_url(url)
    # split_hf_url signals an unparseable URL with None or an empty repo id;
    # unpacking None would raise TypeError outside the try below.
    if not parts or not parts[0]:
        print(f"Failed to download: could not parse Hugging Face URL: {url}")
        return
    repo_id, filename, subfolder, repo_type = parts
    hf_token = get_token()
    try:
        # subfolder=None is hf_hub_download's default (file at repo root),
        # so one call covers both cases.
        hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=subfolder,
            repo_type=repo_type,
            local_dir=directory,
            token=hf_token,
        )
    except Exception as e:
        print(f"Failed to download: {e}")
# aria2c options shared by every aria2c invocation in download_thing().
_ARIA2C_BASE_ARGS = (
    "--console-log-level=error",
    "--summary-interval=10",
    "-c",
    "-x", "16",
    "-k", "1M",
    "-s", "16",
)


def _run_command(args, cwd=None):
    """Run *args* (an argv list) without a shell and return its exit code.

    Best effort, like the os.system() calls it replaces: if the executable is
    missing or *cwd* is invalid, the OSError is printed and None is returned.
    """
    try:
        return subprocess.run(args, cwd=cwd, check=False).returncode
    except OSError as e:
        print(f"Failed to run {args[0]}: {e}")
        return None


def download_thing(directory, url, civitai_api_key="", progress=gr.Progress(track_tqdm=True)):  # requires aria2, gdown
    """Download *url* into *directory*, choosing a backend by host.

    - drive.google.com: ``gdown --fuzzy <url>`` run inside *directory*.
    - huggingface.co: ``?download=true`` is dropped and ``/blob/`` becomes
      ``/resolve/``; with a stored HF token the file is fetched through
      download_hf_file(), otherwise aria2c saves it under the URL's last
      path segment.
    - civitai.com: any query string is replaced by ``?token=<civitai_api_key>``
      and aria2c downloads it; without a key only a message is printed.
    - anything else: plain aria2c download into *directory*.

    External tools receive an argument list (no shell), so a user-supplied
    URL, directory or API key cannot inject shell syntax and may contain
    spaces or ``&``. Failures are printed, not raised.
    """
    hf_token = get_token()
    url = url.strip()
    directory = str(directory)
    if "drive.google.com" in url:
        # cwd= instead of os.chdir(): chdir changes the working directory of
        # the whole process (every thread), not just this download.
        _run_command(["gdown", "--fuzzy", url], cwd=directory)
    elif "huggingface.co" in url:
        url = url.replace("?download=true", "")
        if "/blob/" in url:
            url = url.replace("/blob/", "/resolve/")
        if hf_token:
            download_hf_file(directory, url)
        else:
            _run_command([
                "aria2c", "--optimize-concurrent-downloads", *_ARIA2C_BASE_ARGS,
                url, "-d", directory, "-o", url.split("/")[-1],
            ])
    elif "civitai.com" in url:
        if "?" in url:
            url = url.split("?")[0]
        if civitai_api_key:
            url = url + f"?token={civitai_api_key}"
            _run_command(["aria2c", *_ARIA2C_BASE_ARGS, "-d", directory, url])
        else:
            print("You need an API key to download Civitai models.")
    else:
        _run_command(["aria2c", *_ARIA2C_BASE_ARGS, "-d", directory, url])
def get_local_model_list(dir_path):
    """Return the paths (as str) of all ``.safetensors`` files under *dir_path*.

    Searches recursively with ``Path.glob("**/*.*")``; order follows glob.
    A missing *dir_path* yields an empty list.
    """
    # The trailing comma matters: ``('.safetensors')`` is a plain str, making
    # ``suffix in ...`` a substring test that also accepts ".s", ".safe", "".
    valid_extensions = (".safetensors",)
    return [
        str(path)
        for path in Path(dir_path).glob("**/*.*")
        if path.is_file() and path.suffix in valid_extensions
    ]
def get_download_file(temp_dir, url, civitai_key, progress=gr.Progress(track_tqdm=True)):
    """Resolve *url* to a usable model reference, downloading it if needed.

    *url* is stripped of surrounding whitespace first, so every check below
    sees the same string that would be downloaded.

    Returns:
        - *url* itself if it contains no "http" and is either an "owner/name"
          Hub repo id that is not a local path, or an existing local path;
        - ``f"{temp_dir}/<last URL segment>"`` if that file already exists;
        - otherwise the first new file reported by get_local_model_list()
          after download_thing() runs, or "" if the download raised or no new
          file appeared.
    """
    url = url.strip()
    is_remote = "http" in url
    cached_path = f"{temp_dir}/{url.split('/')[-1]}"
    if not is_remote and is_repo_name(url) and not Path(url).exists():
        print(f"Use HF Repo: {url}")
        return url
    if not is_remote and Path(url).exists():
        print(f"Use local file: {url}")
        return url
    if Path(cached_path).exists():
        print(f"File to download already exists: {url}")
        return cached_path
    print(f"Start downloading: {url}")
    # The downloaded file is identified by diffing the directory listing.
    before = get_local_model_list(temp_dir)
    try:
        download_thing(temp_dir, url, civitai_key)
    except Exception:
        print(f"Download failed: {url}")
        return ""
    new_files = list_sub(get_local_model_list(temp_dir), before)
    if not new_files:
        print(f"Download failed: {url}")
        return ""
    print(f"Download completed: {url}")
    return new_files[0]