"""Preset storage and Gradio controls for saving and loading option values.

Each preset is a TOML file stored at ``PRESET_DIR/<key>/<name>.toml``.
"""

import argparse
import inspect
import os
from pathlib import Path

import toml
from kohya_ss.library import train_util, config_util
import gradio as gr

from scripts.shared import ROOT_DIR
from scripts.utilities import gradio_to_args

PRESET_DIR = os.path.join(ROOT_DIR, "presets")
PRESET_PATH = os.path.join(ROOT_DIR, "presets.json")


def get_arg_templates(fn):
    """Collect the option names that ``fn`` registers on an argument parser.

    ``fn`` is called with a fresh ``argparse.ArgumentParser`` as its first
    argument and ``True`` for every remaining parameter of its signature.

    Returns:
        A ``(keys, category)`` tuple. ``keys`` lists the registered option
        strings with ``"--"`` removed (the parser's built-in help option is
        excluded); ``category`` is ``fn.__name__`` with ``"add_"`` removed.
    """
    parser = argparse.ArgumentParser()
    extra_params = len(inspect.signature(fn).parameters) - 1
    fn(parser, *([True] * extra_params))

    # argparse offers no public accessor for the registered option strings.
    option_strings = parser._option_string_actions
    keys = [opt.replace("--", "") for opt in option_strings]
    keys = [k for k in keys if k not in ("help", "-h")]
    return keys, fn.__name__.replace("add_", "")


arguments_functions = [
    train_util.add_dataset_arguments,
    train_util.add_optimizer_arguments,
    train_util.add_sd_models_arguments,
    train_util.add_sd_saving_arguments,
    train_util.add_training_arguments,
    config_util.add_config_arguments,
]

arg_templates = [get_arg_templates(x) for x in arguments_functions]


def _preset_names(key):
    """Return the sorted names of the ``.toml`` presets stored under ``key``.

    Only file names are inspected; the files are not parsed, so a malformed
    preset cannot break listing. A missing directory yields an empty list.
    """
    preset_dir = os.path.join(PRESET_DIR, key)
    if not os.path.isdir(preset_dir):
        return []
    names = []
    for filename in os.listdir(preset_dir):
        # splitext strips only the trailing extension, so names that merely
        # contain ".toml" somewhere else are left intact.
        stem, ext = os.path.splitext(filename)
        if ext == ".toml" and os.path.isfile(os.path.join(preset_dir, filename)):
            names.append(stem)
    return sorted(names)


def _category_for(option):
    """Return the first category in ``arg_templates`` listing ``option``, else None."""
    for template, category in arg_templates:
        if option in template:
            return category
    return None


def load_presets():
    """Load every preset found under ``PRESET_DIR``.

    Creates ``PRESET_DIR`` if it does not exist.

    Returns:
        A dict mapping each sub-directory name (the preset key) to a dict of
        ``{preset name: flattened preset values}``.
    """
    os.makedirs(PRESET_DIR, exist_ok=True)
    presets = {}
    for key in sorted(os.listdir(PRESET_DIR)):
        # Stray files in PRESET_DIR (e.g. ".gitkeep") are not preset groups.
        if not os.path.isdir(os.path.join(PRESET_DIR, key)):
            continue
        presets[key] = {name: load_preset(key, name) for name in _preset_names(key)}
    return presets


def load_preset(key, name):
    """Read ``PRESET_DIR/<key>/<name>.toml`` and flatten it by one level.

    Top-level tables are merged into a single dict; top-level scalars are kept
    under their own key.

    Returns:
        The flattened dict, or an empty dict when the file does not exist.
    """
    filepath = os.path.join(PRESET_DIR, key, name + ".toml")
    if not os.path.isfile(filepath):
        return {}
    # TOML documents are UTF-8; do not depend on the locale's default encoding.
    with open(filepath, mode="r", encoding="utf-8") as f:
        obj = toml.load(f)

    flatten = {}
    for k, v in obj.items():
        if isinstance(v, dict):
            flatten.update(v)
        else:
            flatten[k] = v
    return flatten


def save_preset(key, name, value):
    """Write ``value`` to ``PRESET_DIR/<key>/<name>.toml``.

    Entries whose key appears in one of ``arg_templates`` are grouped under
    that template's category table; all other entries are written at the top
    level. ``Path`` values are converted to ``str`` before dumping.
    """
    obj = {}
    for k, v in value.items():
        if isinstance(v, Path):
            v = str(v)
        category = _category_for(k)
        if category is None:
            obj[k] = v
        else:
            if category not in obj:
                obj[category] = {}
            obj[category][k] = v

    filepath = os.path.join(PRESET_DIR, key, name + ".toml")
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    with open(filepath, mode="w", encoding="utf-8") as f:
        toml.dump(obj, f)


def delete_preset(key, name):
    """Remove ``PRESET_DIR/<key>/<name>.toml``; a missing file is not an error."""
    filepath = os.path.join(PRESET_DIR, key, name + ".toml")
    try:
        os.remove(filepath)
    except FileNotFoundError:
        pass


def create_ui(key, tmpls, opts):
    """Build the preset load/save/delete controls for the preset group ``key``.

    Args:
        key: Sub-directory of ``PRESET_DIR`` holding this group's presets.
        tmpls: Templates forwarded to ``gradio_to_args``, or a callable
            returning them.
        opts: Mapping whose values are the Gradio components holding the
            option values, or a callable returning such a mapping.

    Returns:
        An ``init`` function that registers the click handlers. ``opts`` is
        resolved when ``init`` is called, not when ``create_ui`` is called.
    """

    def get_templates():
        return tmpls() if callable(tmpls) else tmpls

    def get_options():
        return opts() if callable(opts) else opts

    with gr.Box():
        with gr.Row():
            with gr.Column():
                load_preset_button = gr.Button("Load preset", variant="primary")
                delete_preset_button = gr.Button("Delete preset")
            with gr.Column():
                load_preset_name = gr.Dropdown(
                    _preset_names(key), show_label=False
                ).style(container=False)
                reload_presets_button = gr.Button("🔄️")
            with gr.Column() as c:
                c.scale = 0.5
                save_preset_name = gr.Textbox(
                    "", placeholder="Preset name", lines=1, show_label=False
                ).style(container=False)
                save_preset_button = gr.Button("Save preset", variant="primary")

    def update_dropdown():
        """Refresh the dropdown choices from the presets currently on disk."""
        return gr.Dropdown.update(choices=_preset_names(key))

    def _save_preset(args):
        name = args[save_preset_name]
        if not name:
            return update_dropdown()
        values = gradio_to_args(get_templates(), get_options(), args)
        save_preset(key, name, values)
        return update_dropdown()

    def _load_preset(args):
        name = args[load_preset_name]
        if not name:
            # This handler's outputs are the option components, not the
            # dropdown, so emit one no-op update per output.
            no_ops = [gr.update() for _ in get_options().values()]
            return no_ops[0] if len(no_ops) == 1 else no_ops

        values = gradio_to_args(get_templates(), get_options(), args)
        preset = load_preset(key, name)
        result = []
        for k in values:
            if k == load_preset_name:
                continue
            if k not in preset:
                result.append(None)
                continue
            v = preset[k]
            if isinstance(v, list):
                # List values are shown as a single space-separated string.
                v = " ".join(v)
            result.append(v)
        return result[0] if len(result) == 1 else result

    def _delete_preset(name):
        if not name:
            return update_dropdown()
        delete_preset(key, name)
        return update_dropdown()

    def init():
        """Register the click handlers of the preset controls."""
        option_components = list(get_options().values())
        # Passing a set as inputs makes Gradio hand the handler a single dict
        # keyed by component.
        save_preset_button.click(
            _save_preset,
            {save_preset_name, *option_components},
            [load_preset_name],
        )
        load_preset_button.click(
            _load_preset,
            {load_preset_name, *option_components},
            option_components,
        )
        delete_preset_button.click(_delete_preset, load_preset_name, [load_preset_name])
        reload_presets_button.click(
            update_dropdown, inputs=[], outputs=[load_preset_name]
        )

    return init