import gradio as gr
from pathlib import Path
from modules import scripts, script_callbacks, shared, sd_hijack

import yaml

# Webui root path
FILE_DIR = Path().absolute()
|
# The extensions folder of the webui, used to find other extensions' wildcard folders
EXT_PATH = FILE_DIR.joinpath('extensions')

# The tags folder of this extension
TAGS_PATH = Path(scripts.basedir()).joinpath('tags')

# Wildcard, embedding and hypernetwork locations
WILDCARD_PATH = FILE_DIR.joinpath('scripts/wildcards')
EMB_PATH = Path(shared.cmd_opts.embeddings_dir)
HYP_PATH = Path(shared.cmd_opts.hypernetwork_dir)
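# EMB_PATH and HYP_PATH come from the webui's command line options (shared.cmd_opts),
# so custom embedding and hypernetwork folder locations should be picked up as well.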
|
try:
    LORA_PATH = Path(shared.cmd_opts.lora_dir)
except AttributeError:
    # Older webui versions don't expose a Lora directory option
    LORA_PATH = None
|
def find_ext_wildcard_paths():
    """Returns the paths of all extension wildcard folders"""
    found = list(EXT_PATH.glob('*/wildcards/'))
    return found
|
# The wildcard folders of other installed extensions
WILDCARD_EXT_PATHS = find_ext_wildcard_paths()
|
# Temporary file locations
STATIC_TEMP_PATH = FILE_DIR.joinpath('tmp')
TEMP_PATH = TAGS_PATH.joinpath('temp')
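# Note: the lists written to TEMP_PATH below (wc.txt, wce.txt, wcet.txt, emb.txt, hyp.txt,
# lora.txt) are presumably read by the extension's javascript side to build the completion lists.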
|
def get_wildcards():
    """Returns a list of all wildcards. Works on nested folders."""
    wildcard_files = list(WILDCARD_PATH.rglob("*.txt"))
    resolved = [w.relative_to(WILDCARD_PATH).as_posix()
                for w in wildcard_files if w.name != "put wildcards here.txt"]
    return resolved
|
def get_ext_wildcards():
    """Returns a list of all extension wildcards. Works on nested folders."""
    wildcard_files = []

    for path in WILDCARD_EXT_PATHS:
        # Each block starts with the wildcard folder, followed by its files and a "-----" separator
        wildcard_files.append(path.relative_to(FILE_DIR).as_posix())
        wildcard_files.extend(p.relative_to(path).as_posix() for p in path.rglob("*.txt") if p.name != "put wildcards here.txt")
        wildcard_files.append("-----")

    return wildcard_files
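# Illustrative sketch of the resulting list (folder and file names are made up):
#   extensions/example-extension/wildcards
#   colors.txt
#   nested/styles.txt
#   -----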
|
def get_ext_wildcard_tags():
    """Returns a list of all tags found under a Tags: key in extension YAML files."""
    wildcard_tags = {}  # {index: comma-joined tag string}
    yaml_files = []
    for path in WILDCARD_EXT_PATHS:
        yaml_files.extend(p for p in path.rglob("*.yml"))
        yaml_files.extend(p for p in path.rglob("*.yaml"))
    count = 0
    for path in yaml_files:
        try:
            with open(path, encoding="utf8") as file:
                data = yaml.safe_load(file)
                for item in data:
                    if data[item] and 'Tags' in data[item]:
                        wildcard_tags[count] = ','.join(data[item]['Tags'])
                        count += 1
                    else:
                        print('Issue with tags found in ' + path.name + ' at item ' + item)
        except yaml.YAMLError as exc:
            print(exc)

    # Sort by the joined tag string and emit one "index,tag1,tag2,..." line per entry
    sorted_tags = sorted(wildcard_tags.items(), key=lambda item: item[1], reverse=True)
    output = []
    for index, tags in sorted_tags:
        output.append(f"{index},{tags}")
    return output
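# Illustrative sketch of the YAML layout this expects (keys and tags are made up):
#   MyWildcard:
#       Tags: [ fantasy, landscape ]
# Entries without a Tags key are reported and skipped.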
|
def get_embeddings(sd_model):
    """Writes a list of all embeddings with their version to a temporary file"""

    # Width of the text encoder embeddings for v1 and v2 models
    V1_SHAPE = 768
    V2_SHAPE = 1024
    emb_v1 = []
    emb_v2 = []
    results = []

    try:
        # Embeddings loaded for the current model and those skipped due to a shape mismatch
        emb_type_a = sd_hijack.model_hijack.embedding_db.word_embeddings
        emb_type_b = sd_hijack.model_hijack.embedding_db.skipped_embeddings

        # Get the shape of the first embedding in each group
        emb_a_shape = -1
        emb_b_shape = -1
        if (len(emb_type_a) > 0):
            emb_a_shape = next(iter(emb_type_a.items()))[1].shape
        if (len(emb_type_b) > 0):
            emb_b_shape = next(iter(emb_type_b.items()))[1].shape

        # Add the embeddings to the respective version list
        if (emb_a_shape == V1_SHAPE):
            emb_v1 = list(emb_type_a.keys())
        elif (emb_a_shape == V2_SHAPE):
            emb_v2 = list(emb_type_a.keys())

        if (emb_b_shape == V1_SHAPE):
            emb_v1 = list(emb_type_b.keys())
        elif (emb_b_shape == V2_SHAPE):
            emb_v2 = list(emb_type_b.keys())

        results = sorted([e + ",v1" for e in emb_v1] + [e + ",v2" for e in emb_v2], key=lambda x: x.lower())
    except AttributeError:
        print("tag_autocomplete_helper: Old webui version or unrecognized model shape, using fallback for embedding completion.")
        # Fallback: list the embedding files themselves, without version information
        all_embeds = [str(e.relative_to(EMB_PATH)) for e in EMB_PATH.rglob("*") if e.suffix in {".bin", ".pt", ".png", ".webp", ".jxl", ".avif"}]
        # Skip empty files
        all_embeds = [e for e in all_embeds if EMB_PATH.joinpath(e).stat().st_size > 0]
        # Strip the file extensions
        all_embeds = [e[:e.rfind('.')] for e in all_embeds]
        results = [e + "," for e in all_embeds]

    write_to_temp_file('emb.txt', results)
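# Resulting emb.txt lines are "name,v1" or "name,v2"; the fallback writes "name," with the
# version left blank, which the completion side presumably treats as unknown.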
|
def get_hypernetworks():
    """Returns a list of all hypernetworks"""

    # Get a list of all hypernetworks in the folder and strip the file extensions
    all_hypernetworks = [str(h.name) for h in HYP_PATH.rglob("*") if h.suffix in {".pt"}]
    return sorted([h[:h.rfind('.')] for h in all_hypernetworks], key=lambda x: x.lower())
|
def get_lora():
    """Returns a list of all Loras"""

    # Get a list of all Loras in the folder and strip the file extensions
    all_lora = [str(l.name) for l in LORA_PATH.rglob("*") if l.suffix in {".safetensors", ".ckpt", ".pt"}]
    return sorted([l[:l.rfind('.')] for l in all_lora], key=lambda x: x.lower())
|
def write_tag_base_path():
    """Writes the tag base path to a fixed location temporary file"""
    with open(STATIC_TEMP_PATH.joinpath('tagAutocompletePath.txt'), 'w', encoding="utf-8") as f:
        f.write(TAGS_PATH.relative_to(FILE_DIR).as_posix())
|
def write_to_temp_file(name, data):
    """Writes the given data to a temporary file"""
    with open(TEMP_PATH.joinpath(name), 'w', encoding="utf-8") as f:
        f.write('\n'.join(data))
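# For example (hypothetical data), write_to_temp_file('hyp.txt', ["hypernetA", "hypernetB"])
# would create <tags folder>/temp/hyp.txt containing one name per line.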
|
csv_files = []
csv_files_withnone = []


def update_tag_files():
    """Updates the global lists of potential tag files"""
    global csv_files, csv_files_withnone
    files = [str(t.relative_to(TAGS_PATH)) for t in TAGS_PATH.glob("*.csv")]
    csv_files = files
    csv_files_withnone = ["None"] + files
|
# Ensure the static temp path exists, then write the tag base path and refresh the tag file list
if not STATIC_TEMP_PATH.exists():
    STATIC_TEMP_PATH.mkdir(exist_ok=True)

write_tag_base_path()
update_tag_files()
|
if not TEMP_PATH.exists():
    TEMP_PATH.mkdir(parents=True, exist_ok=True)
|
# Write empty placeholder files up front so they can always be loaded,
# even if nothing is found below
write_to_temp_file('wc.txt', [])
write_to_temp_file('wce.txt', [])
write_to_temp_file('wcet.txt', [])
write_to_temp_file('hyp.txt', [])
write_to_temp_file('lora.txt', [])
# Only create emb.txt if it is missing, since it is rewritten on every model load anyway
if not TEMP_PATH.joinpath("emb.txt").exists():
    write_to_temp_file('emb.txt', [])
|
# Write the local wildcards to wc.txt if the folder exists
if WILDCARD_PATH.exists():
    wildcards = [WILDCARD_PATH.relative_to(FILE_DIR).as_posix()] + get_wildcards()
    if wildcards:
        write_to_temp_file('wc.txt', wildcards)
|
# Write extension wildcards to wce.txt and their YAML tags to wcet.txt if found
if WILDCARD_EXT_PATHS is not None:
    wildcards_ext = get_ext_wildcards()
    if wildcards_ext:
        write_to_temp_file('wce.txt', wildcards_ext)

    wildcards_yaml_ext = get_ext_wildcard_tags()
    if wildcards_yaml_ext:
        write_to_temp_file('wcet.txt', wildcards_yaml_ext)
|
# Embeddings depend on the loaded model, so they are (re)written in the model loaded callback
if EMB_PATH.exists():
    script_callbacks.on_model_loaded(get_embeddings)
|
if HYP_PATH.exists():
    hypernets = get_hypernetworks()
    if hypernets:
        write_to_temp_file('hyp.txt', hypernets)
|
if LORA_PATH is not None and LORA_PATH.exists():
    lora = get_lora()
    if lora:
        write_to_temp_file('lora.txt', lora)
|
def on_ui_settings():
    TAC_SECTION = ("tac", "Tag Autocomplete")

    shared.opts.add_option("tac_tagFile", shared.OptionInfo("danbooru.csv", "Tag filename", gr.Dropdown, lambda: {"choices": csv_files_withnone}, refresh=update_tag_files, section=TAC_SECTION))
|
    shared.opts.add_option("tac_active", shared.OptionInfo(True, "Enable Tag Autocompletion", section=TAC_SECTION))
    shared.opts.add_option("tac_activeIn.txt2img", shared.OptionInfo(True, "Active in txt2img (Requires restart)", section=TAC_SECTION))
    shared.opts.add_option("tac_activeIn.img2img", shared.OptionInfo(True, "Active in img2img (Requires restart)", section=TAC_SECTION))
    shared.opts.add_option("tac_activeIn.negativePrompts", shared.OptionInfo(True, "Active in negative prompts (Requires restart)", section=TAC_SECTION))
    shared.opts.add_option("tac_activeIn.thirdParty", shared.OptionInfo(True, "Active in third party textboxes [Dataset Tag Editor] (Requires restart)", section=TAC_SECTION))
    shared.opts.add_option("tac_activeIn.modelList", shared.OptionInfo("", "List of model names (with file extension) or their hashes to use as black/whitelist, separated by commas.", section=TAC_SECTION))
    shared.opts.add_option("tac_activeIn.modelListMode", shared.OptionInfo("Blacklist", "Mode to use for model list", gr.Dropdown, lambda: {"choices": ["Blacklist","Whitelist"]}, section=TAC_SECTION))
|
    shared.opts.add_option("tac_slidingPopup", shared.OptionInfo(True, "Move completion popup together with text cursor", section=TAC_SECTION))
    shared.opts.add_option("tac_maxResults", shared.OptionInfo(5, "Maximum results", section=TAC_SECTION))
    shared.opts.add_option("tac_showAllResults", shared.OptionInfo(False, "Show all results", section=TAC_SECTION))
    shared.opts.add_option("tac_resultStepLength", shared.OptionInfo(100, "How many results to load at once", section=TAC_SECTION))
    shared.opts.add_option("tac_delayTime", shared.OptionInfo(100, "Time in ms to wait before triggering completion again (Requires restart)", section=TAC_SECTION))
    shared.opts.add_option("tac_useWildcards", shared.OptionInfo(True, "Search for wildcards", section=TAC_SECTION))
    shared.opts.add_option("tac_useEmbeddings", shared.OptionInfo(True, "Search for embeddings", section=TAC_SECTION))
    shared.opts.add_option("tac_useHypernetworks", shared.OptionInfo(True, "Search for hypernetworks", section=TAC_SECTION))
    shared.opts.add_option("tac_useLoras", shared.OptionInfo(True, "Search for Loras", section=TAC_SECTION))
|
    shared.opts.add_option("tac_showWikiLinks", shared.OptionInfo(False, "Show '?' next to tags, linking to their Danbooru or e621 wiki page (Warning: This is an external site and very likely contains NSFW examples!)", section=TAC_SECTION))
|
    shared.opts.add_option("tac_replaceUnderscores", shared.OptionInfo(True, "Replace underscores with spaces on insertion", section=TAC_SECTION))
    shared.opts.add_option("tac_escapeParentheses", shared.OptionInfo(True, "Escape parentheses on insertion", section=TAC_SECTION))
    shared.opts.add_option("tac_appendComma", shared.OptionInfo(True, "Append comma on tag autocompletion", section=TAC_SECTION))
|
    shared.opts.add_option("tac_alias.searchByAlias", shared.OptionInfo(True, "Search by alias", section=TAC_SECTION))
    shared.opts.add_option("tac_alias.onlyShowAlias", shared.OptionInfo(False, "Only show alias", section=TAC_SECTION))
|
    shared.opts.add_option("tac_translation.translationFile", shared.OptionInfo("None", "Translation filename", gr.Dropdown, lambda: {"choices": csv_files_withnone}, refresh=update_tag_files, section=TAC_SECTION))
    shared.opts.add_option("tac_translation.oldFormat", shared.OptionInfo(False, "Translation file uses old 3-column translation format instead of the new 2-column one", section=TAC_SECTION))
    shared.opts.add_option("tac_translation.searchByTranslation", shared.OptionInfo(True, "Search by translation", section=TAC_SECTION))
|
    shared.opts.add_option("tac_extra.extraFile", shared.OptionInfo("extra-quality-tags.csv", "Extra filename (for small sets of custom tags)", gr.Dropdown, lambda: {"choices": csv_files_withnone}, refresh=update_tag_files, section=TAC_SECTION))
    shared.opts.add_option("tac_extra.addMode", shared.OptionInfo("Insert before", "Mode to add the extra tags to the main tag list", gr.Dropdown, lambda: {"choices": ["Insert before","Insert after"]}, section=TAC_SECTION))
|
script_callbacks.on_ui_settings(on_ui_settings) |
|
|