Spaces:
Runtime error
Runtime error
Commit
•
2406cac
1
Parent(s):
09e5977
Update app.py
Browse files
app.py
CHANGED
@@ -3,6 +3,7 @@ import subprocess
|
|
3 |
from typing import Union
|
4 |
from huggingface_hub import whoami
|
5 |
is_spaces = True if os.environ.get("SPACE_ID") else False
|
|
|
6 |
|
7 |
if is_spaces:
|
8 |
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
@@ -226,7 +227,7 @@ def start_training(
|
|
226 |
else:
|
227 |
config["config"]["process"][0]["train"]["disable_sampling"] = True
|
228 |
|
229 |
-
if(which_model == "[schnell]
|
230 |
config["config"]["process"][0]["model"]["name_or_path"] = "black-forest-labs/FLUX.1-schnell"
|
231 |
config["config"]["process"][0]["model"]["assistant_lora_path"] = "ostris/FLUX.1-schnell-training-adapter"
|
232 |
config["config"]["process"][0]["sample"]["sample_steps"] = 4
|
@@ -374,7 +375,13 @@ with gr.Blocks(theme=theme, css=css) as demo:
|
|
374 |
placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'",
|
375 |
interactive=True,
|
376 |
)
|
377 |
-
which_model = gr.Radio(
|
|
|
|
|
|
|
|
|
|
|
|
|
378 |
model_warning = gr.Markdown("""> [dev] model license is non-commercial. By choosing to fine-tune [dev], you must agree with [its license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md) and make sure the LoRA you will train and the training process you would start does not violate it.
|
379 |
""", visible=False)
|
380 |
with gr.Group(visible=True) as image_upload:
|
|
|
3 |
from typing import Union
|
4 |
from huggingface_hub import whoami
|
5 |
is_spaces = True if os.environ.get("SPACE_ID") else False
|
6 |
+
is_canonical = True if os.environ.get("SPACE_ID") == "autotrain-projects/train-flux-lora-ease" else False
|
7 |
|
8 |
if is_spaces:
|
9 |
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
|
|
227 |
else:
|
228 |
config["config"]["process"][0]["train"]["disable_sampling"] = True
|
229 |
|
230 |
+
if(which_model == "[schnell]"):
|
231 |
config["config"]["process"][0]["model"]["name_or_path"] = "black-forest-labs/FLUX.1-schnell"
|
232 |
config["config"]["process"][0]["model"]["assistant_lora_path"] = "ostris/FLUX.1-schnell-training-adapter"
|
233 |
config["config"]["process"][0]["sample"]["sample_steps"] = 4
|
|
|
375 |
placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'",
|
376 |
interactive=True,
|
377 |
)
|
378 |
+
which_model = gr.Radio(
|
379 |
+
[("[schnell] (4 step fast model)", "[schnell]"),
|
380 |
+
("[dev] (high quality model, non-commercial license - available if you duplicate this space or locally)" if is_canonical else "[dev] (high quality model, non-commercial license)", "[dev]")],
|
381 |
+
label="Which base model to train?",
|
382 |
+
elem_id="space_model" if is_canonical else "local_model",
|
383 |
+
value="[schnell]" if is_canonical else "[dev]"
|
384 |
+
)
|
385 |
model_warning = gr.Markdown("""> [dev] model license is non-commercial. By choosing to fine-tune [dev], you must agree with [its license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md) and make sure the LoRA you will train and the training process you would start does not violate it.
|
386 |
""", visible=False)
|
387 |
with gr.Group(visible=True) as image_upload:
|