MonsterMMORPG commited on
Commit
eac949a
·
verified ·
1 Parent(s): b9b3327

Upload kohya_gui_kaggle.py

Browse files
Files changed (1) hide show
  1. kohya_gui_kaggle.py +22 -11
kohya_gui_kaggle.py CHANGED
@@ -1,11 +1,12 @@
1
  import gradio as gr
2
  import os
3
  import argparse
4
- from dreambooth_gui import dreambooth_tab
5
- from finetune_gui import finetune_tab
6
- from textual_inversion_gui import ti_tab
 
7
  from kohya_gui.utilities import utilities_tab
8
- from lora_gui import lora_tab
9
  from kohya_gui.class_lora_tab import LoRATools
10
 
11
  from kohya_gui.custom_logging import setup_logging
@@ -22,9 +23,9 @@ def UI(**kwargs):
22
  headless = kwargs.get("headless", False)
23
  log.info(f"headless: {headless}")
24
 
25
- if os.path.exists("./style.css"):
26
- with open(os.path.join("./style.css"), "r", encoding="utf8") as file:
27
- log.info("Load CSS...")
28
  css += file.read() + "\n"
29
 
30
  if os.path.exists("./.release"):
@@ -38,6 +39,8 @@ def UI(**kwargs):
38
  interface = gr.Blocks(
39
  css=css, title=f"Kohya_ss GUI {release}", theme=gr.themes.Default()
40
  )
 
 
41
 
42
  with interface:
43
  with gr.Tab("Dreambooth"):
@@ -46,13 +49,13 @@ def UI(**kwargs):
46
  reg_data_dir_input,
47
  output_dir_input,
48
  logging_dir_input,
49
- ) = dreambooth_tab(headless=headless)
50
  with gr.Tab("LoRA"):
51
- lora_tab(headless=headless)
52
  with gr.Tab("Textual Inversion"):
53
- ti_tab(headless=headless)
54
  with gr.Tab("Finetuning"):
55
- finetune_tab(headless=headless)
56
  with gr.Tab("Utilities"):
57
  utilities_tab(
58
  train_data_dir_input=train_data_dir_input,
@@ -102,6 +105,12 @@ def UI(**kwargs):
102
  if __name__ == "__main__":
103
  # torch.cuda.set_per_process_memory_fraction(0.48)
104
  parser = argparse.ArgumentParser()
 
 
 
 
 
 
105
  parser.add_argument(
106
  "--listen",
107
  type=str,
@@ -130,10 +139,12 @@ if __name__ == "__main__":
130
  )
131
 
132
  parser.add_argument("--use-ipex", action="store_true", help="Use IPEX environment")
 
133
 
134
  args = parser.parse_args()
135
 
136
  UI(
 
137
  username=args.username,
138
  password=args.password,
139
  inbrowser=args.inbrowser,
 
1
  import gradio as gr
2
  import os
3
  import argparse
4
+ from kohya_gui.class_gui_config import KohyaSSGUIConfig
5
+ from kohya_gui.dreambooth_gui import dreambooth_tab
6
+ from kohya_gui.finetune_gui import finetune_tab
7
+ from kohya_gui.textual_inversion_gui import ti_tab
8
  from kohya_gui.utilities import utilities_tab
9
+ from kohya_gui.lora_gui import lora_tab
10
  from kohya_gui.class_lora_tab import LoRATools
11
 
12
  from kohya_gui.custom_logging import setup_logging
 
23
  headless = kwargs.get("headless", False)
24
  log.info(f"headless: {headless}")
25
 
26
+ if os.path.exists("./assets/style.css"):
27
+ with open(os.path.join("./assets/style.css"), "r", encoding="utf8") as file:
28
+ log.debug("Load CSS...")
29
  css += file.read() + "\n"
30
 
31
  if os.path.exists("./.release"):
 
39
  interface = gr.Blocks(
40
  css=css, title=f"Kohya_ss GUI {release}", theme=gr.themes.Default()
41
  )
42
+
43
+ config = KohyaSSGUIConfig(config_file_path=kwargs.get("config_file_path"))
44
 
45
  with interface:
46
  with gr.Tab("Dreambooth"):
 
49
  reg_data_dir_input,
50
  output_dir_input,
51
  logging_dir_input,
52
+ ) = dreambooth_tab(headless=headless, config=config)
53
  with gr.Tab("LoRA"):
54
+ lora_tab(headless=headless, config=config)
55
  with gr.Tab("Textual Inversion"):
56
+ ti_tab(headless=headless, config=config)
57
  with gr.Tab("Finetuning"):
58
+ finetune_tab(headless=headless, config=config)
59
  with gr.Tab("Utilities"):
60
  utilities_tab(
61
  train_data_dir_input=train_data_dir_input,
 
105
  if __name__ == "__main__":
106
  # torch.cuda.set_per_process_memory_fraction(0.48)
107
  parser = argparse.ArgumentParser()
108
+ parser.add_argument(
109
+ "--config",
110
+ type=str,
111
+ default="./config.toml",
112
+ help="Path to the toml config file for interface defaults",
113
+ )
114
  parser.add_argument(
115
  "--listen",
116
  type=str,
 
139
  )
140
 
141
  parser.add_argument("--use-ipex", action="store_true", help="Use IPEX environment")
142
+ parser.add_argument("--use-rocm", action="store_true", help="Use ROCm environment")
143
 
144
  args = parser.parse_args()
145
 
146
  UI(
147
+ config_file_path=args.config,
148
  username=args.username,
149
  password=args.password,
150
  inbrowser=args.inbrowser,