Spaces:
Runtime error
Runtime error
Upload 12 files
Browse files- app.py +8 -4
- env.py +3 -1
- flux.py +17 -16
- mod.py +1 -1
- modutils.py +93 -42
app.py
CHANGED
@@ -505,15 +505,19 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css, delete_cache
|
|
505 |
lora_md[i] = gr.Markdown(value="", visible=False)
|
506 |
lora_num[i] = gr.Number(i, visible=False)
|
507 |
with gr.Accordion("From URL", open=True, visible=True):
|
|
|
|
|
|
|
|
|
508 |
with gr.Row():
|
509 |
lora_search_civitai_query = gr.Textbox(label="Query", placeholder="flux", lines=1)
|
|
|
510 |
lora_search_civitai_submit = gr.Button("Search on Civitai")
|
511 |
-
lora_search_civitai_basemodel = gr.CheckboxGroup(label="Search LoRA for", choices=["Flux.1 D", "Flux.1 S"], value=["Flux.1 D", "Flux.1 S"])
|
512 |
with gr.Row():
|
513 |
lora_search_civitai_json = gr.JSON(value={}, visible=False)
|
514 |
lora_search_civitai_desc = gr.Markdown(value="", visible=False)
|
515 |
lora_search_civitai_result = gr.Dropdown(label="Search Results", choices=[("", "")], value="", allow_custom_value=True, visible=False)
|
516 |
-
lora_download_url = gr.Textbox(label="URL", placeholder="
|
517 |
with gr.Row():
|
518 |
lora_download = [None] * num_loras
|
519 |
for i in range(num_loras):
|
@@ -591,9 +595,9 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css, delete_cache
|
|
591 |
prompt_enhance.click(enhance_prompt, [prompt], [prompt], queue=False, show_api=False)
|
592 |
|
593 |
gr.on(
|
594 |
-
triggers=[lora_search_civitai_submit.click, lora_search_civitai_query.submit],
|
595 |
fn=search_civitai_lora,
|
596 |
-
inputs=[lora_search_civitai_query, lora_search_civitai_basemodel],
|
597 |
outputs=[lora_search_civitai_result, lora_search_civitai_desc, lora_search_civitai_submit, lora_search_civitai_query],
|
598 |
scroll_to_output=True,
|
599 |
queue=True,
|
|
|
505 |
lora_md[i] = gr.Markdown(value="", visible=False)
|
506 |
lora_num[i] = gr.Number(i, visible=False)
|
507 |
with gr.Accordion("From URL", open=True, visible=True):
|
508 |
+
with gr.Row():
|
509 |
+
lora_search_civitai_basemodel = gr.CheckboxGroup(label="Search LoRA for", choices=["Flux.1 D", "Flux.1 S"], value=["Flux.1 D", "Flux.1 S"])
|
510 |
+
lora_search_civitai_sort = gr.Radio(label="Sort", choices=["Highest Rated", "Most Downloaded", "Newest"], value="Highest Rated")
|
511 |
+
lora_search_civitai_period = gr.Radio(label="Period", choices=["AllTime", "Year", "Month", "Week", "Day"], value="AllTime")
|
512 |
with gr.Row():
|
513 |
lora_search_civitai_query = gr.Textbox(label="Query", placeholder="flux", lines=1)
|
514 |
+
lora_search_civitai_tag = gr.Textbox(label="Tag", lines=1)
|
515 |
lora_search_civitai_submit = gr.Button("Search on Civitai")
|
|
|
516 |
with gr.Row():
|
517 |
lora_search_civitai_json = gr.JSON(value={}, visible=False)
|
518 |
lora_search_civitai_desc = gr.Markdown(value="", visible=False)
|
519 |
lora_search_civitai_result = gr.Dropdown(label="Search Results", choices=[("", "")], value="", allow_custom_value=True, visible=False)
|
520 |
+
lora_download_url = gr.Textbox(label="LoRA URL", placeholder="https://civitai.com/api/download/models/28907", lines=1)
|
521 |
with gr.Row():
|
522 |
lora_download = [None] * num_loras
|
523 |
for i in range(num_loras):
|
|
|
595 |
prompt_enhance.click(enhance_prompt, [prompt], [prompt], queue=False, show_api=False)
|
596 |
|
597 |
gr.on(
|
598 |
+
triggers=[lora_search_civitai_submit.click, lora_search_civitai_query.submit, lora_search_civitai_tag.submit],
|
599 |
fn=search_civitai_lora,
|
600 |
+
inputs=[lora_search_civitai_query, lora_search_civitai_basemodel, lora_search_civitai_sort, lora_search_civitai_period, lora_search_civitai_tag],
|
601 |
outputs=[lora_search_civitai_result, lora_search_civitai_desc, lora_search_civitai_submit, lora_search_civitai_query],
|
602 |
scroll_to_output=True,
|
603 |
queue=True,
|
env.py
CHANGED
@@ -2,7 +2,7 @@ import os
|
|
2 |
|
3 |
|
4 |
CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
|
5 |
-
|
6 |
hf_read_token = os.environ.get('HF_READ_TOKEN') # only use for private repo
|
7 |
|
8 |
|
@@ -67,6 +67,7 @@ HF_MODEL_USER_LIKES = [] # sorted by number of likes
|
|
67 |
HF_MODEL_USER_EX = [] # sorted by a special rule
|
68 |
|
69 |
|
|
|
70 |
# - **Download Models**
|
71 |
download_model_list = [
|
72 |
]
|
@@ -79,6 +80,7 @@ download_vae_list = [
|
|
79 |
download_lora_list = [
|
80 |
]
|
81 |
|
|
|
82 |
|
83 |
directory_models = 'models'
|
84 |
os.makedirs(directory_models, exist_ok=True)
|
|
|
2 |
|
3 |
|
4 |
CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
|
5 |
+
HF_TOKEN = os.environ.get("HF_TOKEN")
|
6 |
hf_read_token = os.environ.get('HF_READ_TOKEN') # only use for private repo
|
7 |
|
8 |
|
|
|
67 |
HF_MODEL_USER_EX = [] # sorted by a special rule
|
68 |
|
69 |
|
70 |
+
|
71 |
# - **Download Models**
|
72 |
download_model_list = [
|
73 |
]
|
|
|
80 |
download_lora_list = [
|
81 |
]
|
82 |
|
83 |
+
DIFFUSERS_FORMAT_LORAS = []
|
84 |
|
85 |
directory_models = 'models'
|
86 |
os.makedirs(directory_models, exist_ok=True)
|
flux.py
CHANGED
@@ -11,14 +11,15 @@ warnings.filterwarnings(action="ignore", category=FutureWarning, module="diffuse
|
|
11 |
warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers")
|
12 |
warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
|
13 |
from pathlib import Path
|
14 |
-
from
|
|
|
15 |
CIVITAI_API_KEY, HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2, HF_LORA_ESSENTIAL_PRIVATE_REPO,
|
16 |
HF_VAE_PRIVATE_REPO, directory_models, directory_loras, directory_vaes,
|
17 |
download_model_list, download_lora_list, download_vae_list)
|
18 |
from modutils import (to_list, list_uniq, list_sub, get_lora_model_list, download_private_repo,
|
19 |
safe_float, escape_lora_basename, to_lora_key, to_lora_path, get_local_model_list,
|
20 |
get_private_lora_model_lists, get_valid_lora_name, get_valid_lora_path, get_valid_lora_wt,
|
21 |
-
get_lora_info, normalize_prompt_list, get_civitai_info, search_lora_on_civitai)
|
22 |
|
23 |
|
24 |
def download_things(directory, url, hf_token="", civitai_api_key=""):
|
@@ -38,7 +39,7 @@ def download_things(directory, url, hf_token="", civitai_api_key=""):
|
|
38 |
if hf_token:
|
39 |
os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
|
40 |
else:
|
41 |
-
os.system
|
42 |
elif "civitai.com" in url:
|
43 |
if "?" in url:
|
44 |
url = url.split("?")[0]
|
@@ -94,14 +95,18 @@ vae_model_list = get_model_list(directory_vaes)
|
|
94 |
vae_model_list.insert(0, "None")
|
95 |
|
96 |
|
|
|
|
|
|
|
|
|
|
|
97 |
def get_t2i_model_info(repo_id: str):
|
98 |
-
|
99 |
-
api = HfApi()
|
100 |
try:
|
101 |
-
if
|
102 |
-
model = api.model_info(repo_id=repo_id)
|
103 |
except Exception as e:
|
104 |
-
print(f"Error: Failed to get {repo_id}'s info.
|
105 |
print(e)
|
106 |
return ""
|
107 |
if model.private or model.gated: return ""
|
@@ -109,12 +114,8 @@ def get_t2i_model_info(repo_id: str):
|
|
109 |
info = []
|
110 |
url = f"https://huggingface.co/{repo_id}/"
|
111 |
if not 'diffusers' in tags: return ""
|
112 |
-
|
113 |
-
info.append(
|
114 |
-
elif 'diffusers:StableDiffusionXLPipeline' in tags:
|
115 |
-
info.append("SDXL")
|
116 |
-
elif 'diffusers:StableDiffusionPipeline' in tags:
|
117 |
-
info.append("SD1.5")
|
118 |
if model.card_data and model.card_data.tags:
|
119 |
info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
|
120 |
info.append(f"DLs: {model.downloads}")
|
@@ -246,9 +247,9 @@ def update_loras(prompt, lora, lora_wt):
|
|
246 |
gr.update(value=tag, label=label, visible=on), gr.update(value=md, visible=on)
|
247 |
|
248 |
|
249 |
-
def search_civitai_lora(query, base_model):
|
250 |
global civitai_lora_last_results
|
251 |
-
items = search_lora_on_civitai(query, base_model)
|
252 |
if not items: return gr.update(choices=[("", "")], value="", visible=False),\
|
253 |
gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)
|
254 |
civitai_lora_last_results = {}
|
|
|
11 |
warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers")
|
12 |
warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
|
13 |
from pathlib import Path
|
14 |
+
from huggingface_hub import HfApi
|
15 |
+
from env import (HF_TOKEN, hf_read_token, # to use only for private repos
|
16 |
CIVITAI_API_KEY, HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2, HF_LORA_ESSENTIAL_PRIVATE_REPO,
|
17 |
HF_VAE_PRIVATE_REPO, directory_models, directory_loras, directory_vaes,
|
18 |
download_model_list, download_lora_list, download_vae_list)
|
19 |
from modutils import (to_list, list_uniq, list_sub, get_lora_model_list, download_private_repo,
|
20 |
safe_float, escape_lora_basename, to_lora_key, to_lora_path, get_local_model_list,
|
21 |
get_private_lora_model_lists, get_valid_lora_name, get_valid_lora_path, get_valid_lora_wt,
|
22 |
+
get_lora_info, normalize_prompt_list, get_civitai_info, search_lora_on_civitai, MODEL_TYPE_DICT)
|
23 |
|
24 |
|
25 |
def download_things(directory, url, hf_token="", civitai_api_key=""):
|
|
|
39 |
if hf_token:
|
40 |
os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
|
41 |
else:
|
42 |
+
os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
|
43 |
elif "civitai.com" in url:
|
44 |
if "?" in url:
|
45 |
url = url.split("?")[0]
|
|
|
95 |
vae_model_list.insert(0, "None")
|
96 |
|
97 |
|
98 |
+
def is_repo_name(s):
|
99 |
+
import re
|
100 |
+
return re.fullmatch(r'^[^/]+?/[^/]+?$', s)
|
101 |
+
|
102 |
+
|
103 |
def get_t2i_model_info(repo_id: str):
|
104 |
+
api = HfApi(token=HF_TOKEN)
|
|
|
105 |
try:
|
106 |
+
if not is_repo_name(repo_id): return ""
|
107 |
+
model = api.model_info(repo_id=repo_id, timeout=5.0)
|
108 |
except Exception as e:
|
109 |
+
print(f"Error: Failed to get {repo_id}'s info.")
|
110 |
print(e)
|
111 |
return ""
|
112 |
if model.private or model.gated: return ""
|
|
|
114 |
info = []
|
115 |
url = f"https://huggingface.co/{repo_id}/"
|
116 |
if not 'diffusers' in tags: return ""
|
117 |
+
for k, v in MODEL_TYPE_DICT.items():
|
118 |
+
if k in tags: info.append(v)
|
|
|
|
|
|
|
|
|
119 |
if model.card_data and model.card_data.tags:
|
120 |
info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
|
121 |
info.append(f"DLs: {model.downloads}")
|
|
|
247 |
gr.update(value=tag, label=label, visible=on), gr.update(value=md, visible=on)
|
248 |
|
249 |
|
250 |
+
def search_civitai_lora(query, base_model, sort="Highest Rated", period="AllTime", tag=""):
|
251 |
global civitai_lora_last_results
|
252 |
+
items = search_lora_on_civitai(query, base_model, 100, sort, period, tag)
|
253 |
if not items: return gr.update(choices=[("", "")], value="", visible=False),\
|
254 |
gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)
|
255 |
civitai_lora_last_results = {}
|
mod.py
CHANGED
@@ -347,7 +347,7 @@ def enhance_prompt(input_prompt):
|
|
347 |
|
348 |
def save_image(image, savefile, modelname, prompt, height, width, steps, cfg, seed):
|
349 |
import uuid
|
350 |
-
from PIL import
|
351 |
import json
|
352 |
try:
|
353 |
if savefile is None: savefile = f"{modelname.split('/')[-1]}_{str(uuid.uuid4())}.png"
|
|
|
347 |
|
348 |
def save_image(image, savefile, modelname, prompt, height, width, steps, cfg, seed):
|
349 |
import uuid
|
350 |
+
from PIL import PngImagePlugin
|
351 |
import json
|
352 |
try:
|
353 |
if savefile is None: savefile = f"{modelname.split('/')[-1]}_{str(uuid.uuid4())}.png"
|
modutils.py
CHANGED
@@ -4,11 +4,19 @@ import gradio as gr
|
|
4 |
from huggingface_hub import HfApi
|
5 |
import os
|
6 |
from pathlib import Path
|
|
|
7 |
|
8 |
|
9 |
from env import (HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
|
10 |
-
HF_MODEL_USER_EX, HF_MODEL_USER_LIKES,
|
11 |
-
directory_loras, hf_read_token,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
|
14 |
def get_user_agent():
|
@@ -27,6 +35,11 @@ def list_sub(a, b):
|
|
27 |
return [e for e in a if e not in b]
|
28 |
|
29 |
|
|
|
|
|
|
|
|
|
|
|
30 |
from translatepy import Translator
|
31 |
translator = Translator()
|
32 |
def translate_to_en(input: str):
|
@@ -64,7 +77,7 @@ def download_things(directory, url, hf_token="", civitai_api_key=""):
|
|
64 |
if hf_token:
|
65 |
os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
|
66 |
else:
|
67 |
-
os.system
|
68 |
elif "civitai.com" in url:
|
69 |
if "?" in url:
|
70 |
url = url.split("?")[0]
|
@@ -100,6 +113,23 @@ def safe_float(input):
|
|
100 |
return output
|
101 |
|
102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
def save_gallery_images(images, progress=gr.Progress(track_tqdm=True)):
|
104 |
from datetime import datetime, timezone, timedelta
|
105 |
progress(0, desc="Updating gallery...")
|
@@ -209,11 +239,16 @@ def get_model_id_list():
|
|
209 |
model_ids.append(model.id) if not model.private else ""
|
210 |
anime_models = []
|
211 |
real_models = []
|
|
|
|
|
212 |
for model in models_ex:
|
213 |
-
if not model.private and not model.gated
|
214 |
-
|
|
|
215 |
model_ids.extend(anime_models)
|
216 |
model_ids.extend(real_models)
|
|
|
|
|
217 |
model_id_list = model_ids.copy()
|
218 |
return model_ids
|
219 |
|
@@ -222,10 +257,10 @@ model_id_list = get_model_id_list()
|
|
222 |
|
223 |
|
224 |
def get_t2i_model_info(repo_id: str):
|
225 |
-
api = HfApi()
|
226 |
try:
|
227 |
-
if
|
228 |
-
model = api.model_info(repo_id=repo_id)
|
229 |
except Exception as e:
|
230 |
print(f"Error: Failed to get {repo_id}'s info.")
|
231 |
print(e)
|
@@ -235,9 +270,8 @@ def get_t2i_model_info(repo_id: str):
|
|
235 |
info = []
|
236 |
url = f"https://huggingface.co/{repo_id}/"
|
237 |
if not 'diffusers' in tags: return ""
|
238 |
-
|
239 |
-
|
240 |
-
elif 'diffusers:StableDiffusionPipeline' in tags: info.append("SD1.5")
|
241 |
if model.card_data and model.card_data.tags:
|
242 |
info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
|
243 |
info.append(f"DLs: {model.downloads}")
|
@@ -262,12 +296,8 @@ def get_tupled_model_list(model_list):
|
|
262 |
tags = model.tags
|
263 |
info = []
|
264 |
if not 'diffusers' in tags: continue
|
265 |
-
|
266 |
-
info.append(
|
267 |
-
if 'diffusers:StableDiffusionXLPipeline' in tags:
|
268 |
-
info.append("SDXL")
|
269 |
-
elif 'diffusers:StableDiffusionPipeline' in tags:
|
270 |
-
info.append("SD1.5")
|
271 |
if model.card_data and model.card_data.tags:
|
272 |
info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
|
273 |
if "pony" in info:
|
@@ -351,7 +381,7 @@ def get_civitai_info(path):
|
|
351 |
|
352 |
|
353 |
def get_lora_model_list():
|
354 |
-
loras = list_uniq(get_private_lora_model_lists() + get_local_model_list(directory_loras))
|
355 |
loras.insert(0, "None")
|
356 |
loras.insert(0, "")
|
357 |
return loras
|
@@ -408,7 +438,7 @@ def download_lora(dl_urls: str):
|
|
408 |
for url in [url.strip() for url in dl_urls.split(',')]:
|
409 |
local_path = f"{directory_loras}/{url.split('/')[-1]}"
|
410 |
if not Path(local_path).exists():
|
411 |
-
download_things(directory_loras, url,
|
412 |
urls.append(url)
|
413 |
after = get_local_model_list(directory_loras)
|
414 |
new_files = list_sub(after, before)
|
@@ -460,7 +490,7 @@ def download_my_lora(dl_urls: str, lora1: str, lora2: str, lora3: str, lora4: st
|
|
460 |
gr.update(value=lora4, choices=choices), gr.update(value=lora5, choices=choices)
|
461 |
|
462 |
|
463 |
-
def get_valid_lora_name(query: str):
|
464 |
path = "None"
|
465 |
if not query or query == "None": return "None"
|
466 |
if to_lora_key(query) in loras_dict.keys(): return query
|
@@ -474,7 +504,7 @@ def get_valid_lora_name(query: str):
|
|
474 |
dl_file = download_lora(query)
|
475 |
if dl_file and Path(dl_file).exists(): return dl_file
|
476 |
else:
|
477 |
-
dl_file = find_similar_lora(query)
|
478 |
if dl_file and Path(dl_file).exists(): return dl_file
|
479 |
return "None"
|
480 |
|
@@ -498,14 +528,14 @@ def get_valid_lora_wt(prompt: str, lora_path: str, lora_wt: float):
|
|
498 |
return wt
|
499 |
|
500 |
|
501 |
-
def set_prompt_loras(prompt, prompt_syntax, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt):
|
502 |
import re
|
503 |
if not "Classic" in str(prompt_syntax): return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt
|
504 |
-
lora1 = get_valid_lora_name(lora1)
|
505 |
-
lora2 = get_valid_lora_name(lora2)
|
506 |
-
lora3 = get_valid_lora_name(lora3)
|
507 |
-
lora4 = get_valid_lora_name(lora4)
|
508 |
-
lora5 = get_valid_lora_name(lora5)
|
509 |
if not "<lora" in prompt: return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt
|
510 |
lora1_wt = get_valid_lora_wt(prompt, lora1, lora1_wt)
|
511 |
lora2_wt = get_valid_lora_wt(prompt, lora2, lora2_wt)
|
@@ -670,7 +700,7 @@ def get_my_lora(link_url):
|
|
670 |
before = get_local_model_list(directory_loras)
|
671 |
for url in [url.strip() for url in link_url.split(',')]:
|
672 |
if not Path(f"{directory_loras}/{url.split('/')[-1]}").exists():
|
673 |
-
download_things(directory_loras, url,
|
674 |
after = get_local_model_list(directory_loras)
|
675 |
new_files = list_sub(after, before)
|
676 |
for file in new_files:
|
@@ -727,8 +757,7 @@ def move_file_lora(filepaths):
|
|
727 |
|
728 |
|
729 |
def get_civitai_info(path):
|
730 |
-
global civitai_not_exists_list
|
731 |
-
global loras_url_to_path_dict
|
732 |
import requests
|
733 |
from requests.adapters import HTTPAdapter
|
734 |
from urllib3.util import Retry
|
@@ -768,16 +797,18 @@ def get_civitai_info(path):
|
|
768 |
return items
|
769 |
|
770 |
|
771 |
-
def search_lora_on_civitai(query: str, allow_model: list[str] = ["Pony", "SDXL 1.0"], limit: int = 100
|
|
|
772 |
import requests
|
773 |
from requests.adapters import HTTPAdapter
|
774 |
from urllib3.util import Retry
|
775 |
-
|
776 |
user_agent = get_user_agent()
|
777 |
headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
|
778 |
base_url = 'https://civitai.com/api/v1/models'
|
779 |
-
params = {'
|
780 |
-
|
|
|
781 |
session = requests.Session()
|
782 |
retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
|
783 |
session.mount("https://", HTTPAdapter(max_retries=retries))
|
@@ -806,9 +837,9 @@ def search_lora_on_civitai(query: str, allow_model: list[str] = ["Pony", "SDXL 1
|
|
806 |
return items
|
807 |
|
808 |
|
809 |
-
def search_civitai_lora(query, base_model):
|
810 |
global civitai_lora_last_results
|
811 |
-
items = search_lora_on_civitai(query, base_model)
|
812 |
if not items: return gr.update(choices=[("", "")], value="", visible=False),\
|
813 |
gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)
|
814 |
civitai_lora_last_results = {}
|
@@ -834,7 +865,27 @@ def select_civitai_lora(search_result):
|
|
834 |
return gr.update(value=search_result), gr.update(value=md, visible=True)
|
835 |
|
836 |
|
837 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
838 |
from rapidfuzz.process import extractOne
|
839 |
from rapidfuzz.utils import default_process
|
840 |
query = to_lora_key(q)
|
@@ -857,7 +908,7 @@ def find_similar_lora(q: str):
|
|
857 |
print(f"Finding <lora:{query}:...> on Civitai...")
|
858 |
civitai_query = Path(query).stem if Path(query).is_file() else query
|
859 |
civitai_query = civitai_query.replace("_", " ").replace("-", " ")
|
860 |
-
base_model =
|
861 |
items = search_lora_on_civitai(civitai_query, base_model, 1)
|
862 |
if items:
|
863 |
item = items[0]
|
@@ -1219,12 +1270,12 @@ def set_textual_inversion_prompt(textual_inversion_gui, prompt_gui, neg_prompt_g
|
|
1219 |
|
1220 |
def get_model_pipeline(repo_id: str):
|
1221 |
from huggingface_hub import HfApi
|
1222 |
-
api = HfApi()
|
1223 |
default = "StableDiffusionPipeline"
|
1224 |
try:
|
1225 |
-
if
|
1226 |
-
model = api.model_info(repo_id=repo_id)
|
1227 |
-
except Exception
|
1228 |
return default
|
1229 |
if model.private or model.gated: return default
|
1230 |
tags = model.tags
|
|
|
4 |
from huggingface_hub import HfApi
|
5 |
import os
|
6 |
from pathlib import Path
|
7 |
+
from PIL import Image
|
8 |
|
9 |
|
10 |
from env import (HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
|
11 |
+
HF_MODEL_USER_EX, HF_MODEL_USER_LIKES, DIFFUSERS_FORMAT_LORAS,
|
12 |
+
directory_loras, hf_read_token, HF_TOKEN, CIVITAI_API_KEY)
|
13 |
+
|
14 |
+
|
15 |
+
MODEL_TYPE_DICT = {
|
16 |
+
"diffusers:StableDiffusionPipeline": "SD 1.5",
|
17 |
+
"diffusers:StableDiffusionXLPipeline": "SDXL",
|
18 |
+
"diffusers:FluxPipeline": "FLUX",
|
19 |
+
}
|
20 |
|
21 |
|
22 |
def get_user_agent():
|
|
|
35 |
return [e for e in a if e not in b]
|
36 |
|
37 |
|
38 |
+
def is_repo_name(s):
|
39 |
+
import re
|
40 |
+
return re.fullmatch(r'^[^/]+?/[^/]+?$', s)
|
41 |
+
|
42 |
+
|
43 |
from translatepy import Translator
|
44 |
translator = Translator()
|
45 |
def translate_to_en(input: str):
|
|
|
77 |
if hf_token:
|
78 |
os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
|
79 |
else:
|
80 |
+
os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
|
81 |
elif "civitai.com" in url:
|
82 |
if "?" in url:
|
83 |
url = url.split("?")[0]
|
|
|
113 |
return output
|
114 |
|
115 |
|
116 |
+
def save_images(images: list[Image.Image], metadatas: list[str]):
|
117 |
+
from PIL import PngImagePlugin
|
118 |
+
import uuid
|
119 |
+
try:
|
120 |
+
output_images = []
|
121 |
+
for image, metadata in zip(images, metadatas):
|
122 |
+
info = PngImagePlugin.PngInfo()
|
123 |
+
info.add_text("metadata", metadata)
|
124 |
+
savefile = f"{str(uuid.uuid4())}.png"
|
125 |
+
image.save(savefile, "PNG", pnginfo=info)
|
126 |
+
output_images.append(str(Path(savefile).resolve()))
|
127 |
+
return output_images
|
128 |
+
except Exception as e:
|
129 |
+
print(f"Failed to save image file: {e}")
|
130 |
+
raise Exception(f"Failed to save image file:") from e
|
131 |
+
|
132 |
+
|
133 |
def save_gallery_images(images, progress=gr.Progress(track_tqdm=True)):
|
134 |
from datetime import datetime, timezone, timedelta
|
135 |
progress(0, desc="Updating gallery...")
|
|
|
239 |
model_ids.append(model.id) if not model.private else ""
|
240 |
anime_models = []
|
241 |
real_models = []
|
242 |
+
anime_models_flux = []
|
243 |
+
real_models_flux = []
|
244 |
for model in models_ex:
|
245 |
+
if not model.private and not model.gated:
|
246 |
+
if "diffusers:FluxPipeline" in model.tags: anime_models_flux.append(model.id) if "anime" in model.tags else real_models_flux.append(model.id)
|
247 |
+
else: anime_models.append(model.id) if "anime" in model.tags else real_models.append(model.id)
|
248 |
model_ids.extend(anime_models)
|
249 |
model_ids.extend(real_models)
|
250 |
+
model_ids.extend(anime_models_flux)
|
251 |
+
model_ids.extend(real_models_flux)
|
252 |
model_id_list = model_ids.copy()
|
253 |
return model_ids
|
254 |
|
|
|
257 |
|
258 |
|
259 |
def get_t2i_model_info(repo_id: str):
|
260 |
+
api = HfApi(token=HF_TOKEN)
|
261 |
try:
|
262 |
+
if not is_repo_name(repo_id): return ""
|
263 |
+
model = api.model_info(repo_id=repo_id, timeout=5.0)
|
264 |
except Exception as e:
|
265 |
print(f"Error: Failed to get {repo_id}'s info.")
|
266 |
print(e)
|
|
|
270 |
info = []
|
271 |
url = f"https://huggingface.co/{repo_id}/"
|
272 |
if not 'diffusers' in tags: return ""
|
273 |
+
for k, v in MODEL_TYPE_DICT.items():
|
274 |
+
if k in tags: info.append(v)
|
|
|
275 |
if model.card_data and model.card_data.tags:
|
276 |
info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
|
277 |
info.append(f"DLs: {model.downloads}")
|
|
|
296 |
tags = model.tags
|
297 |
info = []
|
298 |
if not 'diffusers' in tags: continue
|
299 |
+
for k, v in MODEL_TYPE_DICT.items():
|
300 |
+
if k in tags: info.append(v)
|
|
|
|
|
|
|
|
|
301 |
if model.card_data and model.card_data.tags:
|
302 |
info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
|
303 |
if "pony" in info:
|
|
|
381 |
|
382 |
|
383 |
def get_lora_model_list():
|
384 |
+
loras = list_uniq(get_private_lora_model_lists() + get_local_model_list(directory_loras) + DIFFUSERS_FORMAT_LORAS)
|
385 |
loras.insert(0, "None")
|
386 |
loras.insert(0, "")
|
387 |
return loras
|
|
|
438 |
for url in [url.strip() for url in dl_urls.split(',')]:
|
439 |
local_path = f"{directory_loras}/{url.split('/')[-1]}"
|
440 |
if not Path(local_path).exists():
|
441 |
+
download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
|
442 |
urls.append(url)
|
443 |
after = get_local_model_list(directory_loras)
|
444 |
new_files = list_sub(after, before)
|
|
|
490 |
gr.update(value=lora4, choices=choices), gr.update(value=lora5, choices=choices)
|
491 |
|
492 |
|
493 |
+
def get_valid_lora_name(query: str, model_name: str):
|
494 |
path = "None"
|
495 |
if not query or query == "None": return "None"
|
496 |
if to_lora_key(query) in loras_dict.keys(): return query
|
|
|
504 |
dl_file = download_lora(query)
|
505 |
if dl_file and Path(dl_file).exists(): return dl_file
|
506 |
else:
|
507 |
+
dl_file = find_similar_lora(query, model_name)
|
508 |
if dl_file and Path(dl_file).exists(): return dl_file
|
509 |
return "None"
|
510 |
|
|
|
528 |
return wt
|
529 |
|
530 |
|
531 |
+
def set_prompt_loras(prompt, prompt_syntax, model_name, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt):
|
532 |
import re
|
533 |
if not "Classic" in str(prompt_syntax): return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt
|
534 |
+
lora1 = get_valid_lora_name(lora1, model_name)
|
535 |
+
lora2 = get_valid_lora_name(lora2, model_name)
|
536 |
+
lora3 = get_valid_lora_name(lora3, model_name)
|
537 |
+
lora4 = get_valid_lora_name(lora4, model_name)
|
538 |
+
lora5 = get_valid_lora_name(lora5, model_name)
|
539 |
if not "<lora" in prompt: return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt
|
540 |
lora1_wt = get_valid_lora_wt(prompt, lora1, lora1_wt)
|
541 |
lora2_wt = get_valid_lora_wt(prompt, lora2, lora2_wt)
|
|
|
700 |
before = get_local_model_list(directory_loras)
|
701 |
for url in [url.strip() for url in link_url.split(',')]:
|
702 |
if not Path(f"{directory_loras}/{url.split('/')[-1]}").exists():
|
703 |
+
download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
|
704 |
after = get_local_model_list(directory_loras)
|
705 |
new_files = list_sub(after, before)
|
706 |
for file in new_files:
|
|
|
757 |
|
758 |
|
759 |
def get_civitai_info(path):
|
760 |
+
global civitai_not_exists_list, loras_url_to_path_dict
|
|
|
761 |
import requests
|
762 |
from requests.adapters import HTTPAdapter
|
763 |
from urllib3.util import Retry
|
|
|
797 |
return items
|
798 |
|
799 |
|
800 |
+
def search_lora_on_civitai(query: str, allow_model: list[str] = ["Pony", "SDXL 1.0"], limit: int = 100,
|
801 |
+
sort: str = "Highest Rated", period: str = "AllTime", tag: str = ""):
|
802 |
import requests
|
803 |
from requests.adapters import HTTPAdapter
|
804 |
from urllib3.util import Retry
|
805 |
+
|
806 |
user_agent = get_user_agent()
|
807 |
headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
|
808 |
base_url = 'https://civitai.com/api/v1/models'
|
809 |
+
params = {'types': ['LORA'], 'sort': sort, 'period': period, 'limit': limit, 'nsfw': 'true'}
|
810 |
+
if query: params["query"] = query
|
811 |
+
if tag: params["tag"] = tag
|
812 |
session = requests.Session()
|
813 |
retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
|
814 |
session.mount("https://", HTTPAdapter(max_retries=retries))
|
|
|
837 |
return items
|
838 |
|
839 |
|
840 |
+
def search_civitai_lora(query, base_model, sort="Highest Rated", period="AllTime", tag=""):
|
841 |
global civitai_lora_last_results
|
842 |
+
items = search_lora_on_civitai(query, base_model, 100, sort, period, tag)
|
843 |
if not items: return gr.update(choices=[("", "")], value="", visible=False),\
|
844 |
gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)
|
845 |
civitai_lora_last_results = {}
|
|
|
865 |
return gr.update(value=search_result), gr.update(value=md, visible=True)
|
866 |
|
867 |
|
868 |
+
LORA_BASE_MODEL_DICT = {
|
869 |
+
"diffusers:StableDiffusionPipeline": ["SD 1.5"],
|
870 |
+
"diffusers:StableDiffusionXLPipeline": ["Pony", "SDXL 1.0"],
|
871 |
+
"diffusers:FluxPipeline": ["Flux.1 D", "Flux.1 S"],
|
872 |
+
}
|
873 |
+
|
874 |
+
|
875 |
+
def get_lora_base_model(model_name: str):
|
876 |
+
api = HfApi(token=HF_TOKEN)
|
877 |
+
default = ["Pony", "SDXL 1.0"]
|
878 |
+
try:
|
879 |
+
model = api.model_info(repo_id=model_name, timeout=5.0)
|
880 |
+
tags = model.tags
|
881 |
+
for tag in tags:
|
882 |
+
if tag in LORA_BASE_MODEL_DICT.keys(): return LORA_BASE_MODEL_DICT.get(tag, default)
|
883 |
+
except Exception:
|
884 |
+
return default
|
885 |
+
return default
|
886 |
+
|
887 |
+
|
888 |
+
def find_similar_lora(q: str, model_name: str):
|
889 |
from rapidfuzz.process import extractOne
|
890 |
from rapidfuzz.utils import default_process
|
891 |
query = to_lora_key(q)
|
|
|
908 |
print(f"Finding <lora:{query}:...> on Civitai...")
|
909 |
civitai_query = Path(query).stem if Path(query).is_file() else query
|
910 |
civitai_query = civitai_query.replace("_", " ").replace("-", " ")
|
911 |
+
base_model = get_lora_base_model(model_name)
|
912 |
items = search_lora_on_civitai(civitai_query, base_model, 1)
|
913 |
if items:
|
914 |
item = items[0]
|
|
|
1270 |
|
1271 |
def get_model_pipeline(repo_id: str):
|
1272 |
from huggingface_hub import HfApi
|
1273 |
+
api = HfApi(token=HF_TOKEN)
|
1274 |
default = "StableDiffusionPipeline"
|
1275 |
try:
|
1276 |
+
if not is_repo_name(repo_id): return default
|
1277 |
+
model = api.model_info(repo_id=repo_id, timeout=5.0)
|
1278 |
+
except Exception:
|
1279 |
return default
|
1280 |
if model.private or model.gated: return default
|
1281 |
tags = model.tags
|