Kha37lid commited on
Commit
6719a18
1 Parent(s): 5b6a636

Upload kohya_gui_kaggle.py

Browse files
Files changed (1) hide show
  1. kohya_gui_kaggle.py +156 -0
kohya_gui_kaggle.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import argparse
4
+ from kohya_gui.class_gui_config import KohyaSSGUIConfig
5
+ from kohya_gui.dreambooth_gui import dreambooth_tab
6
+ from kohya_gui.finetune_gui import finetune_tab
7
+ from kohya_gui.textual_inversion_gui import ti_tab
8
+ from kohya_gui.utilities import utilities_tab
9
+ from kohya_gui.lora_gui import lora_tab
10
+ from kohya_gui.class_lora_tab import LoRATools
11
+
12
+ from kohya_gui.custom_logging import setup_logging
13
+ from kohya_gui.localization_ext import add_javascript
14
+
15
+ # Set up logging
16
+ log = setup_logging()
17
+
18
+
19
+ def UI(**kwargs):
20
+ add_javascript(kwargs.get("language"))
21
+ css = ""
22
+
23
+ headless = kwargs.get("headless", False)
24
+ log.info(f"headless: {headless}")
25
+
26
+ if os.path.exists("./assets/style.css"):
27
+ with open(os.path.join("./assets/style.css"), "r", encoding="utf8") as file:
28
+ log.debug("Load CSS...")
29
+ css += file.read() + "\n"
30
+
31
+ if os.path.exists("./.release"):
32
+ with open(os.path.join("./.release"), "r", encoding="utf8") as file:
33
+ release = file.read()
34
+
35
+ if os.path.exists("./README.md"):
36
+ with open(os.path.join("./README.md"), "r", encoding="utf8") as file:
37
+ README = file.read()
38
+
39
+ interface = gr.Blocks(
40
+ css=css, title=f"Kohya_ss GUI {release}", theme=gr.themes.Default()
41
+ )
42
+
43
+ config = KohyaSSGUIConfig(config_file_path=kwargs.get("config_file_path"))
44
+
45
+ with interface:
46
+ with gr.Tab("Dreambooth"):
47
+ (
48
+ train_data_dir_input,
49
+ reg_data_dir_input,
50
+ output_dir_input,
51
+ logging_dir_input,
52
+ ) = dreambooth_tab(headless=headless, config=config)
53
+ with gr.Tab("LoRA"):
54
+ lora_tab(headless=headless, config=config)
55
+ with gr.Tab("Textual Inversion"):
56
+ ti_tab(headless=headless, config=config)
57
+ with gr.Tab("Finetuning"):
58
+ finetune_tab(headless=headless, config=config)
59
+ with gr.Tab("Utilities"):
60
+ utilities_tab(
61
+ train_data_dir_input=train_data_dir_input,
62
+ reg_data_dir_input=reg_data_dir_input,
63
+ output_dir_input=output_dir_input,
64
+ logging_dir_input=logging_dir_input,
65
+ enable_copy_info_button=True,
66
+ headless=headless,
67
+ )
68
+ with gr.Tab("LoRA"):
69
+ _ = LoRATools(headless=headless)
70
+ with gr.Tab("About"):
71
+ gr.Markdown(f"kohya_ss GUI release {release}")
72
+ with gr.Tab("README"):
73
+ gr.Markdown(README)
74
+
75
+ htmlStr = f"""
76
+ <html>
77
+ <body>
78
+ <div class="ver-class">{release}</div>
79
+ </body>
80
+ </html>
81
+ """
82
+ gr.HTML(htmlStr)
83
+ # Show the interface
84
+ launch_kwargs = {}
85
+ username = kwargs.get("username")
86
+ password = kwargs.get("password")
87
+ server_port = kwargs.get("server_port", 0)
88
+ inbrowser = kwargs.get("inbrowser", False)
89
+ share = False
90
+ server_name = kwargs.get("listen")
91
+
92
+ launch_kwargs["server_name"] = server_name
93
+ if username and password:
94
+ launch_kwargs["auth"] = (username, password)
95
+ if server_port > 0:
96
+ launch_kwargs["server_port"] = server_port
97
+ if inbrowser:
98
+ launch_kwargs["inbrowser"] = inbrowser
99
+ if share:
100
+ launch_kwargs["share"] = False
101
+ launch_kwargs["debug"] = True
102
+ interface.launch(**launch_kwargs, share=False)
103
+
104
+
105
+ if __name__ == "__main__":
106
+ # torch.cuda.set_per_process_memory_fraction(0.48)
107
+ parser = argparse.ArgumentParser()
108
+ parser.add_argument(
109
+ "--config",
110
+ type=str,
111
+ default="./config.toml",
112
+ help="Path to the toml config file for interface defaults",
113
+ )
114
+ parser.add_argument(
115
+ "--listen",
116
+ type=str,
117
+ default="127.0.0.1",
118
+ help="IP to listen on for connections to Gradio",
119
+ )
120
+ parser.add_argument(
121
+ "--username", type=str, default="", help="Username for authentication"
122
+ )
123
+ parser.add_argument(
124
+ "--password", type=str, default="", help="Password for authentication"
125
+ )
126
+ parser.add_argument(
127
+ "--server_port",
128
+ type=int,
129
+ default=0,
130
+ help="Port to run the server listener on",
131
+ )
132
+ parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
133
+ parser.add_argument("--share", action="store_true", help="Share the gradio UI")
134
+ parser.add_argument(
135
+ "--headless", action="store_true", help="Is the server headless"
136
+ )
137
+ parser.add_argument(
138
+ "--language", type=str, default=None, help="Set custom language"
139
+ )
140
+
141
+ parser.add_argument("--use-ipex", action="store_true", help="Use IPEX environment")
142
+ parser.add_argument("--use-rocm", action="store_true", help="Use ROCm environment")
143
+
144
+ args = parser.parse_args()
145
+
146
+ UI(
147
+ config_file_path=args.config,
148
+ username=args.username,
149
+ password=args.password,
150
+ inbrowser=args.inbrowser,
151
+ server_port=args.server_port,
152
+ share=False,
153
+ listen=args.listen,
154
+ headless=args.headless,
155
+ language=args.language,
156
+ )