# NOTE: page header captured with this file ("Spaces: Build error") -- not part of the program.
import os
import sys

# Environment flags are set before the gradio / huggingface_hub imports below.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ['GRADIO_ANALYTICS_ENABLED'] = '0'
# Make the working directory and the bundled sd-scripts checkout importable
# (the `library` and `train_network` imports below rely on this).
sys.path.insert(0, os.getcwd())
sys.path.append(os.path.join(os.path.dirname(__file__), 'sd-scripts'))
import subprocess
import gradio as gr
from PIL import Image
import torch
import uuid
import shutil
import json
import yaml
from slugify import slugify
from transformers import AutoProcessor, AutoModelForCausalLM
from gradio_logsview import LogsView, LogsViewRunner
from huggingface_hub import hf_hub_download, HfApi
from library import flux_train_utils, huggingface_util
from argparse import Namespace
import train_network
import toml
import re

# Upper bound on the number of training images accepted by load_captioning().
MAX_IMAGES = 150

# Base-model catalogue keyed by model name; entries are read for "file",
# "repo", "base" and optional license fields elsewhere in this file.
# Opened as UTF-8 explicitly, like the other text files this module writes.
with open('models.yaml', 'r', encoding='utf-8') as file:
    models = yaml.safe_load(file)
def readme(base_model, lora_name, instance_prompt, sample_prompts):
    """Build the README.md (model card) text for a trained LoRA.

    Args:
        base_model: Key into the module-level ``models`` mapping.
        lora_name: Display name of the LoRA; its slug selects the output folder.
        instance_prompt: Trigger word/sentence, or a falsy value when there is none.
        sample_prompts: Prompts that were used to render sample images.

    Returns:
        The README content as a string: YAML front matter followed by Markdown.
    """
    model_config = models[base_model]
    base_model_name = model_config["base"]

    # Copy whichever license fields the model entry declares into the front matter.
    license_items = [
        f"{field}: {model_config[field]}"
        for field in ("license", "license_name", "license_link")
        if field in model_config
    ]
    license_str = "\n".join(license_items)
    print(f"license_items={license_items}")
    print(f"license_str = {license_str}")

    # tags
    tags = ["text-to-image", "flux", "lora", "diffusers", "template:sd-lora", "fluxgym"]

    # widgets: pair each sample prompt with one rendered sample image
    widgets = []
    sample_image_paths = []
    output_name = slugify(lora_name)
    samples_dir = resolve_path_without_quotes(f"outputs/{output_name}/sample")
    try:
        filenames = os.listdir(samples_dir)
    except OSError:
        # No sample folder (sampling disabled or training produced none).
        filenames = []
    for filename in filenames:
        # Filename Schema: [name]_[steps]_[index]_[timestamp].png
        match = re.search(r"_(\d+)_(\d+)_(\d+)\.png$", filename)
        if match:
            steps, index = int(match.group(1)), int(match.group(2))
            sample_image_paths.append((steps, index, f"sample/{filename}"))

    # Keep the images from the highest step counts, then order them by prompt index.
    sample_image_paths.sort(key=lambda x: x[0], reverse=True)
    final_sample_image_paths = sample_image_paths[:len(sample_prompts)]
    final_sample_image_paths.sort(key=lambda x: x[1])
    # zip() stops at the shorter sequence, so having fewer images than prompts
    # simply yields fewer widgets instead of raising IndexError.
    for prompt, (_, _, image_path) in zip(sample_prompts, final_sample_image_paths):
        widgets.append(
            {
                "text": prompt,
                "output": {
                    "url": image_path
                },
            }
        )
    if not widgets:
        print(f"no samples")

    # Construct the README content. Optional front-matter entries are only
    # emitted when they have a value, so no dangling `widget:` key is written.
    front_matter = ["tags:", yaml.dump(tags, indent=4).strip()]
    if widgets:
        front_matter.append("widget:")
        front_matter.append(yaml.dump(widgets, indent=4).strip())
    front_matter.append(f"base_model: {base_model_name}")
    if instance_prompt:
        front_matter.append("instance_prompt: " + instance_prompt)
    if license_str:
        front_matter.append(license_str)
    front_matter_str = "\n".join(front_matter)

    readme_content = f"""---
{front_matter_str}
---
# {lora_name}
A Flux LoRA trained on a local computer with [Fluxgym](https://github.com/cocktailpeanut/fluxgym)
<Gallery />
## Trigger words
{"You should use `" + instance_prompt + "` to trigger the image generation." if instance_prompt else "No trigger words defined."}
## Download model and use it with ComfyUI, AUTOMATIC1111, SD.Next, Invoke AI, Forge, etc.
Weights for this model are available in Safetensors format.
"""
    return readme_content
def account_hf():
    """Return the saved Hugging Face login, or None when there is none.

    Reads the token from the ``HF_TOKEN`` file in the current working
    directory and checks it with ``HfApi.whoami()``.

    Returns:
        ``{"token": <token>, "account": <account name>}`` when the token file
        exists and the token is accepted, otherwise ``None``.
    """
    try:
        with open("HF_TOKEN", "r") as file:
            # Strip so a trailing newline in the file does not corrupt the token.
            token = file.read().strip()
    except OSError:
        # Missing or unreadable token file: nobody is logged in.
        return None
    try:
        account = HfApi(token=token).whoami()
        return {"token": token, "account": account['name']}
    except Exception:
        # Best effort: any failure while validating the token means "logged out".
        return None
# Wiring noted alongside this handler:
#   hf_logout.click(fn=logout_hf, outputs=[hf_token, hf_login, hf_logout, repo_owner])
def logout_hf():
    """Forget the saved Hugging Face token and reset the login widgets.

    Deletes the ``HF_TOKEN`` file, refreshes the module-level
    ``current_account`` and returns four Gradio updates in the order
    (hf_token, hf_login, hf_logout, repo_owner).
    """
    global current_account
    try:
        os.remove("HF_TOKEN")
    except FileNotFoundError:
        # Already logged out; there is nothing to delete.
        pass
    current_account = account_hf()
    print(f"current_account={current_account}")
    return gr.update(value=""), gr.update(visible=True), gr.update(visible=False), gr.update(value="", visible=False)
# Wiring noted alongside this handler:
#   hf_login.click(fn=login_hf, inputs=[hf_token], outputs=[hf_token, hf_login, hf_logout, repo_owner])
def login_hf(hf_token):
    """Validate a Hugging Face token, persist it and update the login widgets.

    Args:
        hf_token: Token text entered by the user.

    Returns:
        Four Gradio updates in the order (hf_token, hf_login, hf_logout,
        repo_owner). When the token is rejected or cannot be saved, all four
        are no-op updates so the UI stays as it was.
    """
    global current_account

    def unchanged():
        return gr.update(), gr.update(), gr.update(), gr.update()

    # Pasted tokens often carry stray whitespace/newlines.
    hf_token = (hf_token or "").strip()
    api = HfApi(token=hf_token)
    try:
        account = api.whoami()
    except Exception:
        print(f"incorrect hf_token")
        return unchanged()
    if account is None or "name" not in account:
        return unchanged()

    try:
        with open("HF_TOKEN", "w") as file:
            file.write(hf_token)
    except OSError as err:
        # A valid token that cannot be stored is not an "incorrect token".
        print(f"could not save HF_TOKEN: {err}")
        return unchanged()

    current_account = account_hf()
    if current_account is None:
        # Re-validation of the saved token failed; leave the UI untouched.
        return unchanged()
    return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(value=current_account["account"], visible=True)
def upload_hf(base_model, lora_rows, repo_owner, repo_name, repo_visibility, hf_token):
    """Upload the selected LoRA folder to the Hugging Face Hub.

    ``lora_rows`` is passed as the upload source; the target repository is
    ``<repo_owner>/<repo_name>``. Progress is reported through ``gr.Info``.
    """
    target_repo = f"{repo_owner}/{repo_name}"
    gr.Info(f"Uploading to Huggingface. Please Stand by...", duration=None)
    # huggingface_util.upload() takes its settings as an argparse-style namespace.
    upload_settings = {
        "huggingface_repo_id": target_repo,
        "huggingface_repo_type": "model",
        "huggingface_repo_visibility": repo_visibility,
        "huggingface_path_in_repo": "",
        "huggingface_token": hf_token,
        "async_upload": False,
    }
    args = Namespace(**upload_settings)
    print(f"upload_hf args={args}")
    huggingface_util.upload(args=args, src=lora_rows)
    gr.Info(f"[Upload Complete] https://huggingface.co/{target_repo}", duration=None)
def load_captioning(uploaded_files, concept_sentence):
    """Build the Gradio updates that populate the captioning rows.

    Uploaded ``.txt`` files are treated as captions and matched to images by
    file stem; every other file is treated as an image.

    Returns:
        A list of updates: one for the captioning area, then for each of the
        ``MAX_IMAGES`` slots a (row, image, caption) triple, then two trailing
        visibility updates.

    Raises:
        gr.Error: With fewer than 2 images or more than ``MAX_IMAGES``.
    """
    caption_file_by_stem = {}
    image_files = []
    for path in uploaded_files:
        if path.endswith('.txt'):
            stem = os.path.splitext(os.path.basename(path))[0]
            caption_file_by_stem[stem] = path
        else:
            image_files.append(path)

    image_count = len(image_files)
    if image_count <= 1:
        raise gr.Error(
            "Please upload at least 2 images to train your model (the ideal number with default settings is between 4-30)"
        )
    if image_count > MAX_IMAGES:
        raise gr.Error(f"For now, only {MAX_IMAGES} or less images are allowed for training")

    # First update: the captioning_area itself.
    updates = [gr.update(visible=True)]

    for slot in range(MAX_IMAGES):
        in_use = slot < image_count

        # Row visibility, then the image component (hidden when unused).
        updates.append(gr.update(visible=in_use))
        image_value = image_files[slot] if in_use else None
        updates.append(gr.update(value=image_value, visible=in_use))

        # Prefer a caption from a same-named .txt upload.
        caption_from_file = False
        if image_value:
            stem = os.path.splitext(os.path.basename(image_value))[0]
            if stem in caption_file_by_stem:
                with open(caption_file_by_stem[stem], 'r') as file:
                    caption_from_file = file.read()

        # Caption box: file caption, else the concept sentence, else empty.
        if in_use and caption_from_file:
            text_value = caption_from_file
        elif in_use and concept_sentence:
            text_value = concept_sentence
        else:
            text_value = None
        updates.append(gr.update(value=text_value, visible=in_use))

    # Final two updates reveal the sample caption area.
    updates.append(gr.update(visible=True))
    updates.append(gr.update(visible=True))
    return updates
def hide_captioning():
    """Return two Gradio updates that each hide their target component."""
    first_hidden = gr.update(visible=False)
    second_hidden = gr.update(visible=False)
    return first_hidden, second_hidden
def resize_image(image_path, output_path, size):
    """Resize an image so its shorter side equals ``size`` and save it.

    The other side is scaled by the same factor (truncated to an int) and the
    result is written to ``output_path`` using LANCZOS resampling.
    """
    with Image.open(image_path) as source:
        width, height = source.size
        if width < height:
            # Portrait: width is the short side.
            target = (size, int((size/width) * height))
        else:
            # Landscape or square: height is the short side.
            target = (int((size/height) * width), size)
        print(f"resize {image_path} : {target[0]}x{target[1]}")
        source.resize(target, Image.Resampling.LANCZOS).save(output_path)
def create_dataset(destination_folder, size, *inputs):
    """Copy the uploaded files into the dataset folder and write caption files.

    Args:
        destination_folder: Folder that receives the images and captions.
        size: Target length of each image's shorter side (see resize_image).
        *inputs: ``inputs[0]`` is the list of uploaded file paths; the
            remaining values are the caption texts.

    Returns:
        ``destination_folder``.
    """
    print("Creating dataset")
    images = inputs[0]
    captions = inputs[1:]
    os.makedirs(destination_folder, exist_ok=True)

    # Captions are consumed one per *image*. Uploaded .txt files must not
    # advance this counter, otherwise every image listed after a caption file
    # receives the wrong caption.
    # NOTE(review): assumes caption i belongs to the i-th non-.txt upload, as
    # load_captioning() lays the rows out -- confirm against the click wiring.
    image_index = 0
    for image in images:
        # copy the images to the datasets folder
        new_image_path = shutil.copy(image, destination_folder)

        # if it's a caption text file skip the next bit
        ext = os.path.splitext(new_image_path)[-1].lower()
        if ext == '.txt':
            continue

        # resize the images
        resize_image(new_image_path, new_image_path, size)

        # copy the captions
        original_caption = captions[image_index]
        image_index += 1

        image_file_name = os.path.basename(new_image_path)
        caption_file_name = os.path.splitext(image_file_name)[0] + ".txt"
        caption_path = resolve_path_without_quotes(os.path.join(destination_folder, caption_file_name))
        print(f"image_path={new_image_path}, caption_path = {caption_path}, original_caption={original_caption}")
        # if caption_path exists, do not write
        if os.path.exists(caption_path):
            print(f"{caption_path} already exists. use the existing .txt file")
        else:
            print(f"{caption_path} create a .txt caption file")
            with open(caption_path, 'w', encoding="utf-8") as file:
                # An empty caption box may arrive as None; write an empty file.
                file.write(original_caption or "")

    print(f"destination_folder {destination_folder}")
    return destination_folder
def run_captioning(images, concept_sentence, *captions):
    """Caption each uploaded image with Florence-2, yielding as it goes.

    Generator: after every image it yields the full caption list with that
    image's entry replaced. When ``concept_sentence`` is set it is prepended
    to each generated caption.

    Args:
        images: Uploaded file paths; ``.txt`` uploads are ignored.
        concept_sentence: Trigger word/sentence, or a falsy value.
        *captions: Current caption values, one per caption row.
    """
    print(f"run_captioning")
    print(f"concept sentence {concept_sentence}")
    print(f"captions {captions}")
    # Load internally to not consume resources for training
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"device={device}")
    torch_dtype = torch.float16
    model = AutoModelForCausalLM.from_pretrained(
        "multimodalart/Florence-2-large-no-flash-attn", torch_dtype=torch_dtype, trust_remote_code=True
    ).to(device)
    processor = AutoProcessor.from_pretrained("multimodalart/Florence-2-large-no-flash-attn", trust_remote_code=True)

    captions = list(captions)
    # Caption rows are laid out per image (see load_captioning), so caption
    # files must be dropped before indexing; they are not images either.
    image_paths = [path for path in images if isinstance(path, str) and not path.endswith('.txt')]
    prompt = "<DETAILED_CAPTION>"
    try:
        for i, image_path in enumerate(image_paths):
            print(captions[i])
            # Close the file handle as soon as the pixels are converted.
            with Image.open(image_path) as opened:
                image = opened.convert("RGB")

            inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
            print(f"inputs {inputs}")

            generated_ids = model.generate(
                input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3
            )
            print(f"generated_ids {generated_ids}")

            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
            print(f"generated_text: {generated_text}")
            parsed_answer = processor.post_process_generation(
                generated_text, task=prompt, image_size=(image.width, image.height)
            )
            print(f"parsed_answer = {parsed_answer}")
            caption_text = parsed_answer["<DETAILED_CAPTION>"].replace("The image shows ", "")
            print(f"caption_text = {caption_text}, concept_sentence={concept_sentence}")
            if concept_sentence:
                caption_text = f"{concept_sentence} {caption_text}"
            captions[i] = caption_text

            yield captions
    finally:
        # Release the captioning model even if generation fails or the
        # generator is closed early, so the memory is free for training.
        model.to("cpu")
        del model
        del processor
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
def recursive_update(d, u):
    """Merge mapping ``u`` into dict ``d`` in place, recursing into nested dicts.

    A non-empty dict value in ``u`` is merged into the matching dict in ``d``.
    If ``d`` holds no dict under that key (missing, or a scalar/list), the old
    value is replaced by a fresh dict instead of crashing. Every other value,
    including an empty dict, overwrites the entry in ``d``.

    Returns:
        ``d`` (the same object that was passed in).
    """
    for k, v in u.items():
        if isinstance(v, dict) and v:
            existing = d.get(k)
            if not isinstance(existing, dict):
                # Cannot merge into a non-dict; start from an empty dict.
                existing = {}
            d[k] = recursive_update(existing, v)
        else:
            d[k] = v
    return d
def download(base_model):
    """Fetch any model files that are not yet present under ``models/``.

    Checks, in order, the base model (unet), the VAE (``ae.sft``),
    ``clip_l.safetensors`` and ``t5xxl_fp16.safetensors``; each missing file
    is downloaded with ``hf_hub_download``.
    """
    entry = models[base_model]
    unet_filename = entry["file"]
    unet_repo = entry["repo"]

    # unet: the two stock flux models share one folder, others get a per-repo folder
    if base_model in ("flux-dev", "flux-schnell"):
        unet_dir = "models/unet"
    else:
        unet_dir = f"models/unet/{unet_repo}"
    if not os.path.exists(os.path.join(unet_dir, unet_filename)):
        os.makedirs(unet_dir, exist_ok=True)
        gr.Info(f"Downloading base model: {base_model}. Please wait. (You can check the terminal for the download progress)", duration=None)
        print(f"download {base_model}")
        hf_hub_download(repo_id=unet_repo, local_dir=unet_dir, filename=unet_filename)

    # vae
    vae_dir = "models/vae"
    if not os.path.exists(os.path.join(vae_dir, "ae.sft")):
        os.makedirs(vae_dir, exist_ok=True)
        gr.Info(f"Downloading vae")
        print(f"downloading ae.sft...")
        hf_hub_download(repo_id="cocktailpeanut/xulf-dev", local_dir=vae_dir, filename="ae.sft")

    # clip
    text_encoder_dir = "models/clip"
    if not os.path.exists(os.path.join(text_encoder_dir, "clip_l.safetensors")):
        os.makedirs(text_encoder_dir, exist_ok=True)
        gr.Info(f"Downloading clip...")
        print(f"download clip_l.safetensors")
        hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", local_dir=text_encoder_dir, filename="clip_l.safetensors")

    # t5xxl (stored next to clip_l)
    if not os.path.exists(os.path.join(text_encoder_dir, "t5xxl_fp16.safetensors")):
        print(f"download t5xxl_fp16.safetensors")
        gr.Info(f"Downloading t5xxl...")
        hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", local_dir=text_encoder_dir, filename="t5xxl_fp16.safetensors")
def resolve_path(p):
    """Return ``p`` resolved against this file's folder, wrapped in double quotes.

    The path is joined onto the directory containing this module and
    normalised; the surrounding quotes make it usable in a shell command.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    normalised = os.path.normpath(os.path.join(base_dir, p))
    return '"' + normalised + '"'
def resolve_path_without_quotes(p):
    """Return ``p`` resolved against this file's folder as a normalised path."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.normpath(os.path.join(module_dir, p))
def gen_sh(
    base_model,
    output_name,
    resolution,
    seed,
    workers,
    learning_rate,
    network_dim,
    max_train_epochs,
    save_every_n_epochs,
    timestep_sampling,
    guidance_scale,
    vram,
    sample_prompts,
    sample_every_n_steps,
    *advanced_components
):
    """Generate the ``accelerate launch`` training script text.

    One argument is emitted per line, joined with the platform's line
    continuation character (``^`` on Windows, ``\\`` elsewhere). Reads the
    module-level ``models``, ``advanced_component_ids`` and
    ``original_advanced_component_values``.

    Args:
        base_model: Key into ``models``; selects the pretrained model file.
        output_name: Slug used for the output folder and the LoRA file name.
        resolution: Accepted for signature compatibility; not used here.
        seed, workers, learning_rate, network_dim, max_train_epochs,
        save_every_n_epochs, timestep_sampling, guidance_scale: Copied
            verbatim into the matching command-line options.
        vram: "16G" or "12G" select adafactor presets; any other value
            selects adamw8bit.
        sample_prompts: Sample prompt text; sampling options are only
            emitted when it is non-empty and ``sample_every_n_steps`` > 0.
        sample_every_n_steps: Sampling interval in steps.
        *advanced_components: Current values of the advanced option widgets,
            in the same order as ``advanced_component_ids``.

    Returns:
        The script as a single string.
    """
    print(f"gen_sh: network_dim:{network_dim}, max_train_epochs={max_train_epochs}, save_every_n_epochs={save_every_n_epochs}, timestep_sampling={timestep_sampling}, guidance_scale={guidance_scale}, vram={vram}, sample_prompts={sample_prompts}, sample_every_n_steps={sample_every_n_steps}")

    output_dir = resolve_path(f"outputs/{output_name}")
    sample_prompts_path = resolve_path(f"outputs/{output_name}/sample_prompts.txt")

    line_break = "^" if sys.platform == "win32" else "\\"

    ############# Sample args ########################
    sample_args = []
    if len(sample_prompts) > 0 and sample_every_n_steps > 0:
        sample_args.append(f"""--sample_prompts={sample_prompts_path} --sample_every_n_steps="{sample_every_n_steps}\"""")

    ############# Optimizer args ########################
    adafactor_args = '--optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False"'
    if vram == "16G":
        # 16G VRAM
        optimizer_args = [
            "--optimizer_type adafactor",
            adafactor_args,
            "--lr_scheduler constant_with_warmup",
            "--max_grad_norm 0.0",
        ]
    elif vram == "12G":
        # 12G VRAM
        optimizer_args = [
            "--optimizer_type adafactor",
            adafactor_args,
            "--split_mode",
            '--network_args "train_blocks=single"',
            "--lr_scheduler constant_with_warmup",
            "--max_grad_norm 0.0",
        ]
    else:
        # 20G+ VRAM
        optimizer_args = ["--optimizer_type adamw8bit"]

    ############# Model paths ########################
    model_config = models[base_model]
    model_file = model_config["file"]
    repo = model_config["repo"]
    if base_model == "flux-dev" or base_model == "flux-schnell":
        model_folder = "models/unet"
    else:
        model_folder = f"models/unet/{repo}"
    pretrained_model_path = resolve_path(os.path.join(model_folder, model_file))
    clip_path = resolve_path("models/clip/clip_l.safetensors")
    t5_path = resolve_path("models/clip/t5xxl_fp16.safetensors")
    ae_path = resolve_path("models/vae/ae.sft")

    # The first --mixed_precision belongs to `accelerate launch`, the second
    # one (after the script path) to the training script itself.
    script_lines = [
        "accelerate launch",
        "--mixed_precision bf16",
        "--num_cpu_threads_per_process 1",
        "sd-scripts/flux_train_network.py",
        f"--pretrained_model_name_or_path {pretrained_model_path}",
        f"--clip_l {clip_path}",
        f"--t5xxl {t5_path}",
        f"--ae {ae_path}",
        "--cache_latents_to_disk",
        "--save_model_as safetensors",
        "--sdpa --persistent_data_loader_workers",
        f"--max_data_loader_n_workers {workers}",
        f"--seed {seed}",
        "--gradient_checkpointing",
        "--mixed_precision bf16",
        "--save_precision bf16",
        "--network_module networks.lora_flux",
        f"--network_dim {network_dim}",
        *optimizer_args,
        *sample_args,
        f"--learning_rate {learning_rate}",
        "--cache_text_encoder_outputs",
        "--cache_text_encoder_outputs_to_disk",
        "--fp8_base",
        "--highvram",
        f"--max_train_epochs {max_train_epochs}",
        f"--save_every_n_epochs {save_every_n_epochs}",
        f"--dataset_config {resolve_path(f'outputs/{output_name}/dataset.toml')}",
        f"--output_dir {output_dir}",
        f"--output_name {output_name}",
        f"--timestep_sampling {timestep_sampling}",
        "--discrete_flow_shift 3.1582",
        "--model_prediction_type raw",
        f"--guidance_scale {guidance_scale}",
        "--loss_type l2",
    ]

    ############# Advanced args ########################
    # Only options whose value differs from the initial widget value are
    # added: a ticked checkbox becomes a bare flag, anything else "flag value".
    print(f"original_advanced_component_values = {original_advanced_component_values}")
    for i, current_value in enumerate(advanced_components):
        if original_advanced_component_values[i] != current_value:
            if current_value is True:
                script_lines.append(advanced_component_ids[i])
            else:
                script_lines.append(f"{advanced_component_ids[i]} {current_value}")

    # Joining (rather than suffixing every line) keeps each argument on its
    # own line and leaves no dangling continuation after the last argument.
    return f" {line_break}\n  ".join(script_lines)
def _toml_string(value):
    """Render ``value`` as a TOML string, preferring the literal ``'...'`` form."""
    text = str(value)
    # TOML literal strings cannot contain a single quote or a line break;
    # fall back to a double-quoted basic string with escapes in that case.
    if "'" not in text and "\n" not in text and "\r" not in text:
        return f"'{text}'"
    return json.dumps(text, ensure_ascii=False)


def gen_toml(
    dataset_folder,
    resolution,
    class_tokens,
    num_repeats
):
    """Generate the dataset config (TOML text) for one image folder.

    Args:
        dataset_folder: Dataset folder, resolved against this file's folder.
        resolution: Written as the dataset ``resolution``.
        class_tokens: Trigger word/sentence written as ``class_tokens``.
        num_repeats: Written as the subset ``num_repeats``.

    Returns:
        The TOML document as a string.
    """
    image_dir = _toml_string(resolve_path_without_quotes(dataset_folder))
    config_text = f"""[general]
shuffle_caption = false
caption_extension = '.txt'
keep_tokens = 1
[[datasets]]
resolution = {resolution}
batch_size = 1
keep_tokens = 1
[[datasets.subsets]]
image_dir = {image_dir}
class_tokens = {_toml_string(class_tokens)}
num_repeats = {num_repeats}"""
    return config_text
def update_total_steps(max_train_epochs, num_repeats, images):
    """Compute the expected step count: epochs x images x repeats.

    Returns:
        A Gradio update carrying the total, or ``None`` when an input is
        still missing (for example no files uploaded yet).
    """
    try:
        num_images = len(images)
        total_steps = max_train_epochs * num_images * num_repeats
    except TypeError:
        # An input is None while the form is incomplete.
        print("update_total_steps: inputs incomplete, total steps not computed")
        return None
    print(f"max_train_epochs={max_train_epochs} num_images={num_images}, num_repeats={num_repeats}, total_steps={total_steps}")
    return gr.update(value = total_steps)
def set_repo(lora_rows):
    """Return a Gradio update whose value is the last path component of ``lora_rows``."""
    _, folder_name = os.path.split(lora_rows)
    return gr.update(value=folder_name)
def get_loras():
    """List the folders under ``outputs``, newest first by ctime.

    A folder named ``sample`` is excluded. Any error yields an empty list.
    """
    try:
        outputs_root = resolve_path_without_quotes(f"outputs")
        lora_folders = []
        for entry in os.listdir(outputs_root):
            full_path = os.path.join(outputs_root, entry)
            if os.path.isdir(full_path) and entry != "sample":
                lora_folders.append(full_path)
        return sorted(lora_folders, key=os.path.getctime, reverse=True)
    except Exception:
        return []
def get_samples(lora_name):
    """List the sample files of a LoRA, newest first by ctime.

    Returns:
        Full paths of everything in ``outputs/<slug>/sample``, or an empty
        list when that folder is missing or unreadable.
    """
    output_name = slugify(lora_name)
    samples_path = resolve_path_without_quotes(f"outputs/{output_name}/sample")
    try:
        files = [os.path.join(samples_path, file) for file in os.listdir(samples_path)]
        files.sort(key=os.path.getctime, reverse=True)
    except OSError:
        # No sample folder yet, or a file vanished while listing.
        return []
    return files
def start_training(
    base_model,
    lora_name,
    train_script,
    train_config,
    sample_prompts,
):
    """Write the training files, run the script and generate the README.

    Generator: yields the log updates produced by ``LogsViewRunner`` while
    the training command runs. Files are written to ``outputs/<slug>/``:
    the train script, ``dataset.toml``, ``sample_prompts.txt`` and, after
    training, ``README.md``.
    """
    # Make sure the working folders exist.
    for top_level in ("models", "outputs"):
        if not os.path.exists(top_level):
            os.makedirs(top_level, exist_ok=True)
    output_name = slugify(lora_name)
    output_dir = resolve_path_without_quotes(f"outputs/{output_name}")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    download(base_model)

    # Train script (.bat on Windows, .sh elsewhere)
    script_ext = "bat" if sys.platform == "win32" else "sh"
    sh_filename = f"train.{script_ext}"
    sh_filepath = resolve_path_without_quotes(f"outputs/{output_name}/{sh_filename}")
    with open(sh_filepath, 'w', encoding="utf-8") as script_file:
        script_file.write(train_script)
    gr.Info(f"Generated train script at {sh_filename}")

    # Dataset config
    dataset_path = resolve_path_without_quotes(f"outputs/{output_name}/dataset.toml")
    with open(dataset_path, 'w', encoding="utf-8") as config_file:
        config_file.write(train_config)
    gr.Info(f"Generated dataset.toml")

    # Sample prompts
    sample_prompts_path = resolve_path_without_quotes(f"outputs/{output_name}/sample_prompts.txt")
    with open(sample_prompts_path, 'w', encoding='utf-8') as prompts_file:
        prompts_file.write(sample_prompts)
    gr.Info(f"Generated sample_prompts.txt")

    # Train
    command = sh_filepath if sys.platform == "win32" else f"bash \"{sh_filepath}\""

    # NOTE(review): this environment is prepared but not passed to
    # run_command() below -- confirm whether that is intended.
    env = os.environ.copy()
    env['PYTHONIOENCODING'] = 'utf-8'
    env['LOG_LEVEL'] = 'DEBUG'
    runner = LogsViewRunner()
    cwd = os.path.dirname(os.path.abspath(__file__))
    gr.Info(f"Started training")
    yield from runner.run_command([command], cwd=cwd)
    yield runner.log(f"Runner: {runner}")

    # Generate Readme
    parsed_config = toml.loads(train_config)
    concept_sentence = parsed_config['datasets'][0]['subsets'][0]['class_tokens']
    print(f"concept_sentence={concept_sentence}")
    print(f"lora_name {lora_name}, concept_sentence={concept_sentence}, output_name={output_name}")
    # Re-read the prompts from disk, dropping blank lines and lines whose
    # first character is '#'.
    with open(sample_prompts_path, "r", encoding="utf-8") as f:
        prompt_lines = f.readlines()
    sample_prompts = []
    for line in prompt_lines:
        if len(line.strip()) > 0 and line[0] != "#":
            sample_prompts.append(line.strip())
    md = readme(base_model, lora_name, concept_sentence, sample_prompts)
    readme_path = resolve_path_without_quotes(f"outputs/{output_name}/README.md")
    with open(readme_path, "w", encoding="utf-8") as f:
        f.write(md)

    gr.Info(f"Training Complete. Check the outputs folder for the LoRA files.", duration=None)
def update(
    base_model,
    lora_name,
    resolution,
    seed,
    workers,
    class_tokens,
    learning_rate,
    network_dim,
    max_train_epochs,
    save_every_n_epochs,
    timestep_sampling,
    guidance_scale,
    vram,
    num_repeats,
    sample_prompts,
    sample_every_n_steps,
    *advanced_components,
):
    """Regenerate the train script and dataset config from the form values.

    Returns:
        A tuple of (update for the script text, update for the dataset TOML
        text, dataset folder path ``datasets/<slug>``).
    """
    output_name = slugify(lora_name)
    dataset_folder = f"datasets/{output_name}"

    script_text = gen_sh(
        base_model,
        output_name,
        resolution,
        seed,
        workers,
        learning_rate,
        network_dim,
        max_train_epochs,
        save_every_n_epochs,
        timestep_sampling,
        guidance_scale,
        vram,
        sample_prompts,
        sample_every_n_steps,
        *advanced_components,
    )
    config_text = gen_toml(dataset_folder, resolution, class_tokens, num_repeats)

    script_update = gr.update(value=script_text)
    config_update = gr.update(value=config_text)
    return script_update, config_update, dataset_folder
# Wiring noted alongside this handler:
#   demo.load(fn=loaded, js=js, outputs=[hf_token, hf_login, hf_logout, hf_account])
def loaded():
    """Refresh the saved Hugging Face login and return the matching widget updates.

    Returns four Gradio updates in the order (hf_token, hf_login, hf_logout,
    hf_account): the logged-in layout when a valid saved token exists, the
    logged-out layout otherwise.
    """
    global current_account
    current_account = account_hf()
    print(f"current_account={current_account}")
    if current_account is None:
        # Logged out: empty token box, show login, hide logout and account.
        return gr.update(value=""), gr.update(visible=True), gr.update(visible=False), gr.update(value="", visible=False)
    # Logged in: show the token and account name, swap login for logout.
    return gr.update(value=current_account["token"]), gr.update(visible=False), gr.update(visible=True), gr.update(value=current_account["account"], visible=True)
def update_sample(concept_sentence):
    """Return a Gradio update that sets the target's value to ``concept_sentence``."""
    sample_update = gr.update(value=concept_sentence)
    return sample_update
def refresh_publish_tab():
    """Return a "Trained LoRAs" dropdown populated with the current output folders."""
    return gr.Dropdown(label="Trained LoRAs", choices=get_loras())
def init_advanced():
    """Create one Gradio input per training option not covered by the basic form.

    The options come from the sd-scripts argument parser
    (``train_network.setup_parser()`` plus the flux training arguments).
    Options listed in ``basic_args`` are skipped because gen_sh() already
    emits them. Must be called inside a Gradio layout context, since it
    instantiates components.

    Returns:
        ``(advanced_components, advanced_component_ids)``: the components and
        their option strings (e.g. ``--some_flag``), in matching order sorted
        by argument name.
    """
    # Options that the generated script already sets explicitly.
    basic_args = {
        'pretrained_model_name_or_path',
        'clip_l',
        't5xxl',
        'ae',
        'cache_latents_to_disk',
        'save_model_as',
        'sdpa',
        'persistent_data_loader_workers',
        'max_data_loader_n_workers',
        'seed',
        'gradient_checkpointing',
        'mixed_precision',
        'save_precision',
        'network_module',
        'network_dim',
        'learning_rate',
        'cache_text_encoder_outputs',
        'cache_text_encoder_outputs_to_disk',
        'fp8_base',
        'highvram',
        'max_train_epochs',
        'save_every_n_epochs',
        'dataset_config',
        'output_dir',
        'output_name',
        'timestep_sampling',
        'discrete_flow_shift',
        'model_prediction_type',
        'guidance_scale',
        'loss_type',
        'optimizer_type',
        'optimizer_args',
        'lr_scheduler',
        'sample_prompts',
        'sample_every_n_steps',
        'max_grad_norm',
        'split_mode',
        'network_args'
    }

    parser = train_network.setup_parser()
    flux_train_utils.add_flux_train_arguments(parser)

    # Index the parser actions by destination name (a later duplicate wins),
    # skipping the built-in help action.
    actions_by_dest = {
        action.dest: action for action in parser._actions if action.dest != 'help'
    }

    advanced_component_ids = []
    advanced_components = []
    for dest in sorted(actions_by_dest):
        if dest in basic_args:
            continue
        action = actions_by_dest[dest]
        if not action.option_strings:
            # No "--flag" spelling to use as id/label (positional argument).
            continue
        with gr.Column(min_width=300):
            if action.type is None:
                # Options declared without a type are shown as a checkbox.
                component = gr.Checkbox()
            else:
                component = gr.Textbox(value="")
            component.interactive = True
            component.elem_id = action.option_strings[0]
            component.label = component.elem_id
            component.elem_classes = ["advanced"]
            if action.help is not None:
                component.info = action.help
        advanced_components.append(component)
        advanced_component_ids.append(component.elem_id)
    return advanced_components, advanced_component_ids
theme = gr.themes.Monochrome( | |
text_size=gr.themes.Size(lg="18px", md="15px", sm="13px", xl="22px", xs="12px", xxl="24px", xxs="9px"), | |
font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui", "sans-serif"], | |
) | |
css = """ | |
@keyframes rotate { | |
0% { | |
transform: rotate(0deg); | |
} | |
100% { | |
transform: rotate(360deg); | |
} | |
} | |
#advanced_options .advanced:nth-child(even) { background: rgba(0,0,100,0.04) !important; } | |
h1{font-family: georgia; font-style: italic; font-weight: bold; font-size: 30px; letter-spacing: -1px;} | |
h3{margin-top: 0} | |
.tabitem{border: 0px} | |
.group_padding{} | |
nav{position: fixed; top: 0; left: 0; right: 0; z-index: 1000; text-align: center; padding: 10px; box-sizing: border-box; display: flex; align-items: center; backdrop-filter: blur(10px); } | |
nav button { background: none; color: firebrick; font-weight: bold; border: 2px solid firebrick; padding: 5px 10px; border-radius: 5px; font-size: 14px; } | |
nav img { height: 40px; width: 40px; border-radius: 40px; } | |
nav img.rotate { animation: rotate 2s linear infinite; } | |
.flexible { flex-grow: 1; } | |
.tast-details { margin: 10px 0 !important; } | |
.toast-wrap { bottom: var(--size-4) !important; top: auto !important; border: none !important; backdrop-filter: blur(10px); } | |
.toast-title, .toast-text, .toast-icon, .toast-close { color: black !important; font-size: 14px; } | |
.toast-body { border: none !important; } | |
#terminal { box-shadow: none !important; margin-bottom: 25px; background: rgba(0,0,0,0.03); } | |
#terminal .generating { border: none !important; } | |
#terminal label { position: absolute !important; } | |
.tabs { margin-top: 50px; } | |
.hidden { display: none !important; } | |
.codemirror-wrapper .cm-line { font-size: 12px !important; } | |
label { font-weight: bold !important; } | |
#start_training.clicked { background: silver; color: black; } | |
""" | |
# Client-side script handed to ``demo.load(js=...)``.  It:
#   * polls the log view every 500 ms to reveal the autoscroll toggle, keep the
#     page scrolled to the bottom and spin the logo while autoscroll is on;
#   * clicks the hidden "#refresh" button one second after the last "input"
#     event anywhere in the document;
#   * marks the "#start_training" button as clicked once it is pressed.
js = """
function() {
    let autoscroll = document.querySelector("#autoscroll")
    if (window.iidxx) {
        window.clearInterval(window.iidxx);
    }
    window.iidxx = window.setInterval(function() {
        // The log view may not have rendered a line yet; skip this tick
        // instead of throwing on a null element every 500 ms.
        let line = document.querySelector(".codemirror-wrapper .cm-line")
        if (!line) {
            return
        }
        let text = line.innerText.trim()
        let img = document.querySelector("#logo")
        if (text.length > 0) {
            autoscroll.classList.remove("hidden")
            if (autoscroll.classList.contains("on")) {
                autoscroll.textContent = "Autoscroll ON"
                // scrollTo accepts either two coordinates or one options
                // object; a third argument is ignored, so use the options form.
                window.scrollTo({ top: document.body.scrollHeight, behavior: "smooth" });
                img.classList.add("rotate")
            } else {
                autoscroll.textContent = "Autoscroll OFF"
                img.classList.remove("rotate")
            }
        }
    }, 500);
    console.log("autoscroll", autoscroll)
    autoscroll.addEventListener("click", (e) => {
        autoscroll.classList.toggle("on")
    })
    function debounce(fn, delay) {
        let timeoutId;
        return function(...args) {
            clearTimeout(timeoutId);
            timeoutId = setTimeout(() => fn(...args), delay);
        };
    }
    function handleClick() {
        console.log("refresh")
        document.querySelector("#refresh").click();
    }
    const debouncedClick = debounce(handleClick, 1000);
    document.addEventListener("input", debouncedClick);
    document.querySelector("#start_training").addEventListener("click", (e) => {
        e.target.classList.add("clicked")
        e.target.innerHTML = "Training..."
    })
}
"""
# Look up the Hugging Face account once at import time and log it.
current_account = account_hf()
print(f"current_account={current_account}")
# Build the two-tab UI ("Gym" for training, "Publish" for uploading) and wire
# its events.
# NOTE(review): the nesting below was reconstructed from a flattened copy of
# this file -- confirm the layout against the running app.
with gr.Blocks(elem_id="app", theme=theme, css=css, fill_width=True) as demo:
    with gr.Tabs() as tabs:
        with gr.TabItem("Gym"):
            # Outputs of load_captioning, in order: the captioning group, then
            # (row, image, caption) for every slot, then the start button.
            output_components = []
            with gr.Row():
                gr.HTML("""<nav>
<img id='logo' src='/file=icon.png' width='80' height='80'>
<div class='flexible'></div>
<button id='autoscroll' class='on hidden'></button>
</nav>
""")
            with gr.Row(elem_id='container'):
                with gr.Column():
                    gr.Markdown(
                        """# Step 1. LoRA Info
<p style="margin-top:0">Configure your LoRA train settings.</p>
""", elem_classes="group_padding")
                    lora_name = gr.Textbox(
                        label="The name of your LoRA",
                        info="This has to be a unique name",
                        placeholder="e.g.: Persian Miniature Painting style, Cat Toy",
                    )
                    concept_sentence = gr.Textbox(
                        elem_id="--concept_sentence",
                        label="Trigger word/sentence",
                        info="Trigger word or sentence to be used",
                        placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'",
                        interactive=True,
                    )
                    model_names = list(models.keys())
                    print(f"model_names={model_names}")
                    base_model = gr.Dropdown(
                        label="Base model (edit the models.yaml file to add more to this list)",
                        choices=model_names,
                        value=model_names[0],
                    )
                    vram = gr.Radio(["20G", "16G", "12G"], value="20G", label="VRAM", interactive=True)
                    num_repeats = gr.Number(value=10, precision=0, label="Repeat trains per image", interactive=True)
                    max_train_epochs = gr.Number(label="Max Train Epochs", value=16, interactive=True)
                    total_steps = gr.Number(0, interactive=False, label="Expected training steps")
                    sample_prompts = gr.Textbox(
                        "",
                        lines=5,
                        label="Sample Image Prompts (Separate with new lines)",
                        interactive=True,
                    )
                    sample_every_n_steps = gr.Number(
                        0, precision=0, label="Sample Image Every N Steps", interactive=True
                    )
                    resolution = gr.Number(value=512, precision=0, label="Resize dataset images")
                with gr.Column():
                    gr.Markdown(
                        """# Step 2. Dataset
<p style="margin-top:0">Make sure the captions include the trigger word.</p>
""", elem_classes="group_padding")
                    with gr.Group():
                        images = gr.File(
                            file_types=["image", ".txt"],
                            label="Upload your images",
                            #info="If you want, you can also manually upload caption files that match the image names (example: img0.png => img0.txt)",
                            file_count="multiple",
                            interactive=True,
                            visible=True,
                            scale=1,
                        )
                    with gr.Group(visible=False) as captioning_area:
                        do_captioning = gr.Button("Add AI captions with Florence-2")
                        output_components.append(captioning_area)
                        caption_list = []
                        # Pre-create MAX_IMAGES hidden (image, caption) rows.
                        for i in range(1, MAX_IMAGES + 1):
                            with gr.Row(visible=False) as caption_row:
                                caption_image = gr.Image(
                                    type="filepath",
                                    width=111,
                                    height=111,
                                    min_width=111,
                                    interactive=False,
                                    scale=2,
                                    show_label=False,
                                    show_share_button=False,
                                    show_download_button=False,
                                )
                                caption_box = gr.Textbox(
                                    label=f"Caption {i}", scale=15, interactive=True
                                )
                            # The original stored these through locals(), which
                            # at module scope is the module namespace; keep
                            # publishing the same names in case other code
                            # looks them up.
                            globals().update({
                                f"captioning_row_{i}": caption_row,
                                f"image_{i}": caption_image,
                                f"caption_{i}": caption_box,
                            })
                            output_components.extend([caption_row, caption_image, caption_box])
                            caption_list.append(caption_box)
                with gr.Column():
                    gr.Markdown(
                        """# Step 3. Train
<p style="margin-top:0">Press start to start training.</p>
""", elem_classes="group_padding")
                    # Hidden button; the page script clicks "#refresh" after input events.
                    refresh = gr.Button("Refresh", elem_id="refresh", visible=False)
                    start = gr.Button("Start training", visible=False, elem_id="start_training")
                    output_components.append(start)
                    train_script = gr.Textbox(label="Train script", max_lines=100, interactive=True)
                    train_config = gr.Textbox(label="Train config", max_lines=100, interactive=True)
            with gr.Accordion("Advanced options", elem_id='advanced_options', open=False):
                with gr.Row():
                    with gr.Column(min_width=300):
                        seed = gr.Number(label="--seed", info="Seed", value=42, interactive=True)
                    with gr.Column(min_width=300):
                        workers = gr.Number(
                            label="--max_data_loader_n_workers",
                            info="Number of Workers",
                            value=2,
                            interactive=True,
                        )
                    with gr.Column(min_width=300):
                        learning_rate = gr.Textbox(
                            label="--learning_rate", info="Learning Rate", value="8e-4", interactive=True
                        )
                    with gr.Column(min_width=300):
                        save_every_n_epochs = gr.Number(
                            label="--save_every_n_epochs", info="Save every N epochs", value=4, interactive=True
                        )
                    with gr.Column(min_width=300):
                        guidance_scale = gr.Number(
                            label="--guidance_scale", info="Guidance Scale", value=1.0, interactive=True
                        )
                    with gr.Column(min_width=300):
                        timestep_sampling = gr.Textbox(
                            label="--timestep_sampling", info="Timestep Sampling", value="shift", interactive=True
                        )
                    with gr.Column(min_width=300):
                        network_dim = gr.Number(
                            label="--network_dim",
                            info="LoRA Rank",
                            value=4,
                            minimum=4,
                            maximum=128,
                            step=4,
                            interactive=True,
                        )
                advanced_components, advanced_component_ids = init_advanced()
            with gr.Row():
                terminal = LogsView(label="Train log", elem_id="terminal")
            with gr.Row():
                gallery = gr.Gallery(get_samples, inputs=[lora_name], label="Samples", every=10, columns=6)
        with gr.TabItem("Publish") as publish_tab:
            hf_token = gr.Textbox(label="Huggingface Token")
            hf_login = gr.Button("Login")
            hf_logout = gr.Button("Logout")
            with gr.Row() as row:
                gr.Markdown("**LoRA**")
                gr.Markdown("**Upload**")
            loras = get_loras()
            with gr.Row():
                lora_rows = refresh_publish_tab()
                with gr.Column():
                    with gr.Row():
                        repo_owner = gr.Textbox(label="Account", interactive=False)
                        repo_name = gr.Textbox(label="Repository Name")
                        repo_visibility = gr.Textbox(
                            label="Repository Visibility ('public' or 'private')", value="public"
                        )
                    upload_button = gr.Button("Upload to HuggingFace")
                    upload_button.click(
                        fn=upload_hf,
                        inputs=[
                            base_model,
                            lora_rows,
                            repo_owner,
                            repo_name,
                            repo_visibility,
                            hf_token,
                        ],
                    )
            hf_login.click(fn=login_hf, inputs=[hf_token], outputs=[hf_token, hf_login, hf_logout, repo_owner])
            hf_logout.click(fn=logout_hf, outputs=[hf_token, hf_login, hf_logout, repo_owner])
            publish_tab.select(refresh_publish_tab, outputs=lora_rows)
            lora_rows.select(fn=set_repo, inputs=[lora_rows], outputs=[repo_name])

    dataset_folder = gr.State()
    # Inputs of update(), which runs when the hidden refresh button is clicked.
    listeners = [
        base_model,
        lora_name,
        resolution,
        seed,
        workers,
        concept_sentence,
        learning_rate,
        network_dim,
        max_train_epochs,
        save_every_n_epochs,
        timestep_sampling,
        guidance_scale,
        vram,
        num_repeats,
        sample_prompts,
        sample_every_n_steps,
        *advanced_components,
    ]
    advanced_component_ids = [x.elem_id for x in advanced_components]
    original_advanced_component_values = [comp.value for comp in advanced_components]

    # Uploading or deleting files re-runs load_captioning; clearing hides the
    # captioning group and the start button.
    for file_event in (images.upload, images.delete):
        file_event(
            load_captioning,
            inputs=[images, concept_sentence],
            outputs=output_components,
        )
    images.clear(
        hide_captioning,
        outputs=[captioning_area, start],
    )

    # Recompute the expected step count whenever one of its inputs changes.
    # Registration order matches the original hand-written calls.
    for steps_event in (
        max_train_epochs.change,
        num_repeats.change,
        images.upload,
        images.delete,
        images.clear,
    ):
        steps_event(
            fn=update_total_steps,
            inputs=[max_train_epochs, num_repeats, images],
            outputs=[total_steps],
        )

    concept_sentence.change(fn=update_sample, inputs=[concept_sentence], outputs=sample_prompts)
    # Start: create the dataset first, then run training with output to the log view.
    start.click(
        fn=create_dataset,
        inputs=[dataset_folder, resolution, images] + caption_list,
        outputs=dataset_folder,
    ).then(
        fn=start_training,
        inputs=[
            base_model,
            lora_name,
            train_script,
            train_config,
            sample_prompts,
        ],
        outputs=terminal,
    )
    do_captioning.click(fn=run_captioning, inputs=[images, concept_sentence] + caption_list, outputs=caption_list)
    demo.load(fn=loaded, js=js, outputs=[hf_token, hf_login, hf_logout, repo_owner])
    refresh.click(update, inputs=listeners, outputs=[train_script, train_config, dataset_folder])
if __name__ == "__main__":
    # Pass this script's own directory as an allowed path -- presumably so
    # that '/file=icon.png' in the page header can be served.
    cwd = os.path.dirname(os.path.abspath(__file__))
    demo.launch(debug=True, show_error=True, allowed_paths=[cwd])