# Provenance: uploaded by multimodalart in commit 18d0601 ("Upload 52 files", verified).
import os
import shutil
import tarfile
from cog import BaseModel, Input, Path
from predict import SDXL_MODEL_CACHE, SDXL_URL, download_weights
from preprocess import preprocess
from trainer_pti import main
"""
Wrapper around actual trainer.
"""
# Scratch directory handed to the trainer as `output_dir`. `train` deletes and
# recreates it on every run, then packs its contents into a tar archive.
OUTPUT_DIR: str = "training_out"
class TrainingOutput(BaseModel):
    """Result object returned by `train` in this module."""

    # Path to the tar archive that `train` builds from the trainer's output directory.
    weights: Path
from typing import Tuple
def _parse_token_map(token_map: str) -> Tuple[dict, list]:
    """Expand a ``name:count[,name:count...]`` spec into placeholder tokens.

    Each name is assigned ``count`` consecutive placeholders of the form
    ``<sK>``; ``K`` keeps counting across entries, so placeholders never repeat.

    Args:
        token_map: Comma-separated ``name:count`` pairs, e.g. ``"TOK:2"``.

    Returns:
        ``(token_dict, all_tokens)`` where ``token_dict`` maps each name to its
        concatenated placeholders (``{"TOK": "<s0><s1>"}``) and ``all_tokens``
        lists every placeholder in order (``["<s0>", "<s1>"]``).

    Raises:
        ValueError: if an entry has no ``:`` or its count is not an integer.
    """
    token_dict = {}
    all_tokens = []
    for entry in token_map.split(","):
        # rpartition: a name that itself contains ':' keeps everything left of
        # the last colon instead of being split at the first one.
        name, sep, count = entry.rpartition(":")
        if not sep:
            raise ValueError(
                f"Malformed token_map entry {entry!r}; expected 'name:count'"
            )
        start = len(all_tokens)  # running placeholder index across entries
        placeholders = [f"<s{start + i}>" for i in range(int(count))]
        token_dict[name] = "".join(placeholders)
        all_tokens.extend(placeholders)
    return token_dict, all_tokens


def _archive_directory(directory, out_path: str) -> None:
    """Write every path under ``directory`` into an uncompressed tar at ``out_path``.

    Members are stored relative to ``directory``. ``rglob`` already yields every
    nested file and sub-directory, so each path is added with
    ``recursive=False``; letting ``TarFile.add`` recurse into directories would
    store their contents a second time.
    """
    root = Path(directory)
    with tarfile.open(out_path, "w") as tar:
        # sorted() gives a deterministic member order (parents precede children).
        for file_path in sorted(root.rglob("*")):
            print(file_path)
            arcname = file_path.relative_to(root)
            tar.add(file_path, arcname=arcname, recursive=False)


def train(
    input_images: Path = Input(
        description="A .zip or .tar file containing the image files that will be used for fine-tuning"
    ),
    seed: int = Input(
        description="Random seed for reproducible training. Leave empty to use a random seed",
        default=None,
    ),
    resolution: int = Input(
        description="Square pixel resolution which your images will be resized to for training",
        default=768,
    ),
    train_batch_size: int = Input(
        description="Batch size (per device) for training",
        default=4,
    ),
    num_train_epochs: int = Input(
        description="Number of epochs to loop through your training dataset",
        default=4000,
    ),
    max_train_steps: int = Input(
        description="Number of individual training steps. Takes precedence over num_train_epochs",
        default=1000,
    ),
    # gradient_accumulation_steps: int = Input(
    #     description="Number of training steps to accumulate before a backward pass. Effective batch size = gradient_accumulation_steps * batch_size",
    #     default=1,
    # ), # todo.
    is_lora: bool = Input(
        description="Whether to use LoRA training. If set to False, will use Full fine tuning",
        default=True,
    ),
    unet_learning_rate: float = Input(
        description="Learning rate for the U-Net. We recommend this value to be somewhere between `1e-6` to `1e-5`.",
        default=1e-6,
    ),
    ti_lr: float = Input(
        description="Scaling of learning rate for training textual inversion embeddings. Don't alter unless you know what you're doing.",
        default=3e-4,
    ),
    lora_lr: float = Input(
        description="Scaling of learning rate for training LoRA embeddings. Don't alter unless you know what you're doing.",
        default=1e-4,
    ),
    lora_rank: int = Input(
        description="Rank of LoRA embeddings. Don't alter unless you know what you're doing.",
        default=32,
    ),
    lr_scheduler: str = Input(
        description="Learning rate scheduler to use for training",
        default="constant",
        choices=[
            "constant",
            "linear",
        ],
    ),
    lr_warmup_steps: int = Input(
        description="Number of warmup steps for lr schedulers with warmups.",
        default=100,
    ),
    token_string: str = Input(
        description="A unique string that will be trained to refer to the concept in the input images. Can be anything, but TOK works well",
        default="TOK",
    ),
    # token_map: str = Input(
    #     description="String of token and their impact size specificing tokens used in the dataset. This will be in format of `token1:size1,token2:size2,...`.",
    #     default="TOK:2",
    # ),
    caption_prefix: str = Input(
        description="Text which will be used as prefix during automatic captioning. Must contain the `token_string`. For example, if caption text is 'a photo of TOK', automatic captioning will expand to 'a photo of TOK under a bridge', 'a photo of TOK holding a cup', etc.",
        default="a photo of TOK, ",
    ),
    mask_target_prompts: str = Input(
        description="Prompt that describes part of the image that you will find important. For example, if you are fine-tuning your pet, `photo of a dog` will be a good prompt. Prompt-based masking is used to focus the fine-tuning process on the important/salient parts of the image",
        default=None,
    ),
    crop_based_on_salience: bool = Input(
        description="If you want to crop the image to `target_size` based on the important parts of the image, set this to True. If you want to crop the image based on face detection, set this to False",
        default=True,
    ),
    use_face_detection_instead: bool = Input(
        description="If you want to use face detection instead of CLIPSeg for masking. For face applications, we recommend using this option.",
        default=False,
    ),
    clipseg_temperature: float = Input(
        description="How blurry you want the CLIPSeg mask to be. We recommend this value be something between `0.5` to `1.0`. If you want to have more sharp mask (but thus more errorful), you can decrease this value.",
        default=1.0,
    ),
    verbose: bool = Input(description="verbose output", default=True),
    checkpointing_steps: int = Input(
        description="Number of steps between saving checkpoints. Set to very very high number to disable checkpointing, because you don't need one.",
        default=999999,
    ),
    input_images_filetype: str = Input(
        description="Filetype of the input images. Can be either `zip` or `tar`. By default its `infer`, and it will be inferred from the ext of input file.",
        default="infer",
        choices=["zip", "tar", "infer"],
    ),
) -> TrainingOutput:
    """Fine-tune SDXL on the uploaded images and return the packed result.

    Pipeline:
      1. Map ``token_string`` to two placeholder tokens (``<s0><s1>``).
      2. Run ``preprocess`` on the uploaded archive. The returned directory is
         expected to contain ``captions.csv``, which is passed to the trainer.
      3. Download the SDXL weights into ``SDXL_MODEL_CACHE`` if they are absent.
      4. Delete and recreate ``OUTPUT_DIR``, then run ``trainer_pti.main`` with
         the user inputs plus fixed settings (gradient accumulation 1, bf16,
         TF32 allowed, grad-norm clip 1.0, device ``cuda:0``).
      5. Pack ``OUTPUT_DIR`` into ``trained_model.tar`` in the working directory.

    Returns:
        TrainingOutput whose ``weights`` is the path of ``trained_model.tar``.
    """
    # Hard-code token_map for now. Make it configurable once we support multiple
    # concepts or user-uploaded caption csv.
    token_map = token_string + ":2"
    token_dict, all_token_lists = _parse_token_map(token_map)

    # The caption_prefix Input description requires the token; nothing else
    # checks it, so surface a likely misconfiguration without aborting the run.
    if token_string not in caption_prefix:
        print(
            f"Warning: caption_prefix {caption_prefix!r} does not contain "
            f"token_string {token_string!r}"
        )

    input_dir = preprocess(
        input_images_filetype=input_images_filetype,
        input_zip_path=input_images,
        caption_text=caption_prefix,
        mask_target_prompts=mask_target_prompts,
        target_size=resolution,
        crop_based_on_salience=crop_based_on_salience,
        use_face_detection_instead=use_face_detection_instead,
        temp=clipseg_temperature,
        substitution_tokens=list(token_dict.keys()),
    )

    if not os.path.exists(SDXL_MODEL_CACHE):
        download_weights(SDXL_URL, SDXL_MODEL_CACHE)

    # Start from an empty output directory so stale files from a previous run
    # never end up in the archive.
    if os.path.exists(OUTPUT_DIR):
        shutil.rmtree(OUTPUT_DIR)
    os.makedirs(OUTPUT_DIR)

    main(
        pretrained_model_name_or_path=SDXL_MODEL_CACHE,
        instance_data_dir=os.path.join(input_dir, "captions.csv"),
        output_dir=OUTPUT_DIR,
        seed=seed,
        resolution=resolution,
        train_batch_size=train_batch_size,
        num_train_epochs=num_train_epochs,
        max_train_steps=max_train_steps,
        gradient_accumulation_steps=1,
        unet_learning_rate=unet_learning_rate,
        ti_lr=ti_lr,
        lora_lr=lora_lr,
        lr_scheduler=lr_scheduler,
        lr_warmup_steps=lr_warmup_steps,
        token_dict=token_dict,
        inserting_list_tokens=all_token_lists,
        verbose=verbose,
        checkpointing_steps=checkpointing_steps,
        scale_lr=False,
        max_grad_norm=1.0,
        allow_tf32=True,
        mixed_precision="bf16",
        device="cuda:0",
        lora_rank=lora_rank,
        is_lora=is_lora,
    )

    out_path = "trained_model.tar"
    _archive_directory(OUTPUT_DIR, out_path)
    return TrainingOutput(weights=Path(out_path))