Update app.py
Browse files
app.py
CHANGED
@@ -160,6 +160,7 @@ def recursive_update(d, u):
|
|
160 |
def start_training(
|
161 |
lora_name,
|
162 |
concept_sentence,
|
|
|
163 |
steps,
|
164 |
lr,
|
165 |
rank,
|
@@ -224,7 +225,12 @@ def start_training(
|
|
224 |
config["config"]["process"][0]["sample"]["prompts"].append(sample_3)
|
225 |
else:
|
226 |
config["config"]["process"][0]["train"]["disable_sampling"] = True
|
227 |
-
|
|
|
|
|
|
|
|
|
|
|
228 |
if(use_more_advanced_options):
|
229 |
more_advanced_options_dict = yaml.safe_load(more_advanced_options)
|
230 |
config["config"]["process"][0] = recursive_update(config["config"]["process"][0], more_advanced_options_dict)
|
@@ -291,11 +297,13 @@ def update_pricing(steps, oauth_token: Union[gr.OAuthToken, None]):
|
|
291 |
else:
|
292 |
return gr.update(visible=False), "", gr.update(visible=False), gr.update(visible=True)
|
293 |
|
|
|
|
|
|
|
294 |
config_yaml = '''
|
295 |
device: cuda:0
|
296 |
model:
|
297 |
is_flux: true
|
298 |
-
name_or_path: black-forest-labs/FLUX.1-dev
|
299 |
quantize: true
|
300 |
network:
|
301 |
linear: 16 #it will overcome the 'rank' parameter
|
@@ -342,6 +350,7 @@ h3{margin-top: 0}
|
|
342 |
.main_ui_logged_out{opacity: 0.3; pointer-events: none}
|
343 |
.tabitem{border: 0px}
|
344 |
.group_padding{padding: .55em}
|
|
|
345 |
"""
|
346 |
with gr.Blocks(theme=theme, css=css) as demo:
|
347 |
gr.Markdown(
|
@@ -352,18 +361,22 @@ with gr.Blocks(theme=theme, css=css) as demo:
|
|
352 |
gr.LoginButton("Sign in with Hugging Face to train your LoRA on Spaces", visible=is_spaces)
|
353 |
with gr.Tab("Train on Spaces" if is_spaces else "Train locally"):
|
354 |
with gr.Column() as main_ui:
|
355 |
-
with gr.
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
|
|
|
|
|
|
|
|
367 |
with gr.Group(visible=True) as image_upload:
|
368 |
with gr.Row():
|
369 |
images = gr.File(
|
@@ -503,12 +516,18 @@ with gr.Blocks(theme=theme, css=css) as demo:
|
|
503 |
inputs=[steps],
|
504 |
outputs=[cost_preview, cost_preview_info, payment_update, start]
|
505 |
)
|
506 |
-
|
|
|
|
|
|
|
|
|
|
|
507 |
start.click(fn=create_dataset, inputs=[images] + caption_list, outputs=dataset_folder).then(
|
508 |
fn=start_training,
|
509 |
inputs=[
|
510 |
lora_name,
|
511 |
concept_sentence,
|
|
|
512 |
steps,
|
513 |
lr,
|
514 |
rank,
|
|
|
160 |
def start_training(
|
161 |
lora_name,
|
162 |
concept_sentence,
|
163 |
+
which_model,
|
164 |
steps,
|
165 |
lr,
|
166 |
rank,
|
|
|
225 |
config["config"]["process"][0]["sample"]["prompts"].append(sample_3)
|
226 |
else:
|
227 |
config["config"]["process"][0]["train"]["disable_sampling"] = True
|
228 |
+
|
229 |
+
if(which_model == "[schnell] (4 step fast model)"):
|
230 |
+
config["config"]["process"][0]["model"]["name_or_path"] = "black-forest-labs/FLUX.1-schnell"
|
231 |
+
config["config"]["process"][0]["model"]["assistant_lora_path"] = "ostris/FLUX.1-schnell-training-adapter"
|
232 |
+
config["config"]["process"][0]["sample"]["sample_steps"] = 4
|
233 |
+
|
234 |
if(use_more_advanced_options):
|
235 |
more_advanced_options_dict = yaml.safe_load(more_advanced_options)
|
236 |
config["config"]["process"][0] = recursive_update(config["config"]["process"][0], more_advanced_options_dict)
|
|
|
297 |
else:
|
298 |
return gr.update(visible=False), "", gr.update(visible=False), gr.update(visible=True)
|
299 |
|
300 |
+
def swap_base_model(model):
|
301 |
+
return gr.update(visible=True) if model == "[dev] (high quality model, non-commercial license)" else gr.update(visible=False)
|
302 |
+
|
303 |
config_yaml = '''
|
304 |
device: cuda:0
|
305 |
model:
|
306 |
is_flux: true
|
|
|
307 |
quantize: true
|
308 |
network:
|
309 |
linear: 16 #it will overcome the 'rank' parameter
|
|
|
350 |
.main_ui_logged_out{opacity: 0.3; pointer-events: none}
|
351 |
.tabitem{border: 0px}
|
352 |
.group_padding{padding: .55em}
|
353 |
+
#space_model .wrap > label:last-child{opacity: 0.3; pointer-events:none}
|
354 |
"""
|
355 |
with gr.Blocks(theme=theme, css=css) as demo:
|
356 |
gr.Markdown(
|
|
|
361 |
gr.LoginButton("Sign in with Hugging Face to train your LoRA on Spaces", visible=is_spaces)
|
362 |
with gr.Tab("Train on Spaces" if is_spaces else "Train locally"):
|
363 |
with gr.Column() as main_ui:
|
364 |
+
with gr.Group():
|
365 |
+
with gr.Row():
|
366 |
+
lora_name = gr.Textbox(
|
367 |
+
label="The name of your LoRA",
|
368 |
+
info="This has to be a unique name",
|
369 |
+
placeholder="e.g.: Persian Miniature Painting style, Cat Toy",
|
370 |
+
)
|
371 |
+
concept_sentence = gr.Textbox(
|
372 |
+
label="Trigger word/sentence",
|
373 |
+
info="Trigger word or sentence to be used",
|
374 |
+
placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'",
|
375 |
+
interactive=True,
|
376 |
+
)
|
377 |
+
which_model = gr.Radio(["[schnell] (4 step fast model)", "[dev] (high quality model, non-commercial license - available when training locally)"], label="Which base model to train?", elem_id="space_model" if is_spaces else "local_model", value="[schnell] (4 step fast model)",)
|
378 |
+
model_warning = gr.Markdown("""> [dev] model license is non-commercial. By choosing to fine-tune [dev], you must agree with [its license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md) and make sure the LoRA you will train and the training process you would start does not violate it.
|
379 |
+
""", visible=False)
|
380 |
with gr.Group(visible=True) as image_upload:
|
381 |
with gr.Row():
|
382 |
images = gr.File(
|
|
|
516 |
inputs=[steps],
|
517 |
outputs=[cost_preview, cost_preview_info, payment_update, start]
|
518 |
)
|
519 |
+
|
520 |
+
which_model.change(
|
521 |
+
fn=swap_base_model,
|
522 |
+
inputs=which_model,
|
523 |
+
outputs=model_warning
|
524 |
+
)
|
525 |
start.click(fn=create_dataset, inputs=[images] + caption_list, outputs=dataset_folder).then(
|
526 |
fn=start_training,
|
527 |
inputs=[
|
528 |
lora_name,
|
529 |
concept_sentence,
|
530 |
+
which_model,
|
531 |
steps,
|
532 |
lr,
|
533 |
rank,
|