Upload kohya_gui_kaggle.py
Browse files- kohya_gui_kaggle.py +22 -11
kohya_gui_kaggle.py
CHANGED
@@ -1,11 +1,12 @@
|
|
1 |
import gradio as gr
|
2 |
import os
|
3 |
import argparse
|
4 |
-
from
|
5 |
-
from
|
6 |
-
from
|
|
|
7 |
from kohya_gui.utilities import utilities_tab
|
8 |
-
from lora_gui import lora_tab
|
9 |
from kohya_gui.class_lora_tab import LoRATools
|
10 |
|
11 |
from kohya_gui.custom_logging import setup_logging
|
@@ -22,9 +23,9 @@ def UI(**kwargs):
|
|
22 |
headless = kwargs.get("headless", False)
|
23 |
log.info(f"headless: {headless}")
|
24 |
|
25 |
-
if os.path.exists("./style.css"):
|
26 |
-
with open(os.path.join("./style.css"), "r", encoding="utf8") as file:
|
27 |
-
log.
|
28 |
css += file.read() + "\n"
|
29 |
|
30 |
if os.path.exists("./.release"):
|
@@ -38,6 +39,8 @@ def UI(**kwargs):
|
|
38 |
interface = gr.Blocks(
|
39 |
css=css, title=f"Kohya_ss GUI {release}", theme=gr.themes.Default()
|
40 |
)
|
|
|
|
|
41 |
|
42 |
with interface:
|
43 |
with gr.Tab("Dreambooth"):
|
@@ -46,13 +49,13 @@ def UI(**kwargs):
|
|
46 |
reg_data_dir_input,
|
47 |
output_dir_input,
|
48 |
logging_dir_input,
|
49 |
-
) = dreambooth_tab(headless=headless)
|
50 |
with gr.Tab("LoRA"):
|
51 |
-
lora_tab(headless=headless)
|
52 |
with gr.Tab("Textual Inversion"):
|
53 |
-
ti_tab(headless=headless)
|
54 |
with gr.Tab("Finetuning"):
|
55 |
-
finetune_tab(headless=headless)
|
56 |
with gr.Tab("Utilities"):
|
57 |
utilities_tab(
|
58 |
train_data_dir_input=train_data_dir_input,
|
@@ -102,6 +105,12 @@ def UI(**kwargs):
|
|
102 |
if __name__ == "__main__":
|
103 |
# torch.cuda.set_per_process_memory_fraction(0.48)
|
104 |
parser = argparse.ArgumentParser()
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
parser.add_argument(
|
106 |
"--listen",
|
107 |
type=str,
|
@@ -130,10 +139,12 @@ if __name__ == "__main__":
|
|
130 |
)
|
131 |
|
132 |
parser.add_argument("--use-ipex", action="store_true", help="Use IPEX environment")
|
|
|
133 |
|
134 |
args = parser.parse_args()
|
135 |
|
136 |
UI(
|
|
|
137 |
username=args.username,
|
138 |
password=args.password,
|
139 |
inbrowser=args.inbrowser,
|
|
|
1 |
import gradio as gr
|
2 |
import os
|
3 |
import argparse
|
4 |
+
from kohya_gui.class_gui_config import KohyaSSGUIConfig
|
5 |
+
from kohya_gui.dreambooth_gui import dreambooth_tab
|
6 |
+
from kohya_gui.finetune_gui import finetune_tab
|
7 |
+
from kohya_gui.textual_inversion_gui import ti_tab
|
8 |
from kohya_gui.utilities import utilities_tab
|
9 |
+
from kohya_gui.lora_gui import lora_tab
|
10 |
from kohya_gui.class_lora_tab import LoRATools
|
11 |
|
12 |
from kohya_gui.custom_logging import setup_logging
|
|
|
23 |
headless = kwargs.get("headless", False)
|
24 |
log.info(f"headless: {headless}")
|
25 |
|
26 |
+
if os.path.exists("./assets/style.css"):
|
27 |
+
with open(os.path.join("./assets/style.css"), "r", encoding="utf8") as file:
|
28 |
+
log.debug("Load CSS...")
|
29 |
css += file.read() + "\n"
|
30 |
|
31 |
if os.path.exists("./.release"):
|
|
|
39 |
interface = gr.Blocks(
|
40 |
css=css, title=f"Kohya_ss GUI {release}", theme=gr.themes.Default()
|
41 |
)
|
42 |
+
|
43 |
+
config = KohyaSSGUIConfig(config_file_path=kwargs.get("config_file_path"))
|
44 |
|
45 |
with interface:
|
46 |
with gr.Tab("Dreambooth"):
|
|
|
49 |
reg_data_dir_input,
|
50 |
output_dir_input,
|
51 |
logging_dir_input,
|
52 |
+
) = dreambooth_tab(headless=headless, config=config)
|
53 |
with gr.Tab("LoRA"):
|
54 |
+
lora_tab(headless=headless, config=config)
|
55 |
with gr.Tab("Textual Inversion"):
|
56 |
+
ti_tab(headless=headless, config=config)
|
57 |
with gr.Tab("Finetuning"):
|
58 |
+
finetune_tab(headless=headless, config=config)
|
59 |
with gr.Tab("Utilities"):
|
60 |
utilities_tab(
|
61 |
train_data_dir_input=train_data_dir_input,
|
|
|
105 |
if __name__ == "__main__":
|
106 |
# torch.cuda.set_per_process_memory_fraction(0.48)
|
107 |
parser = argparse.ArgumentParser()
|
108 |
+
parser.add_argument(
|
109 |
+
"--config",
|
110 |
+
type=str,
|
111 |
+
default="./config.toml",
|
112 |
+
help="Path to the toml config file for interface defaults",
|
113 |
+
)
|
114 |
parser.add_argument(
|
115 |
"--listen",
|
116 |
type=str,
|
|
|
139 |
)
|
140 |
|
141 |
parser.add_argument("--use-ipex", action="store_true", help="Use IPEX environment")
|
142 |
+
parser.add_argument("--use-rocm", action="store_true", help="Use ROCm environment")
|
143 |
|
144 |
args = parser.parse_args()
|
145 |
|
146 |
UI(
|
147 |
+
config_file_path=args.config,
|
148 |
username=args.username,
|
149 |
password=args.password,
|
150 |
inbrowser=args.inbrowser,
|