import os
import shutil
import tarfile
from typing import Tuple

from cog import BaseModel, Input, Path

from predict import SDXL_MODEL_CACHE, SDXL_URL, download_weights
from preprocess import preprocess
from trainer_pti import main

"""
Wrapper around actual trainer.
"""

OUTPUT_DIR = "training_out"


class TrainingOutput(BaseModel):
    weights: Path
def train(
    input_images: Path = Input(
        description="A .zip or .tar file containing the image files that will be used for fine-tuning"
    ),
    seed: int = Input(
        description="Random seed for reproducible training. Leave empty to use a random seed",
        default=None,
    ),
    resolution: int = Input(
        description="Square pixel resolution to which your images will be resized for training",
        default=768,
    ),
    train_batch_size: int = Input(
        description="Batch size (per device) for training",
        default=4,
    ),
    num_train_epochs: int = Input(
        description="Number of epochs to loop through your training dataset",
        default=4000,
    ),
    max_train_steps: int = Input(
        description="Number of individual training steps. Takes precedence over num_train_epochs",
        default=1000,
    ),
    # gradient_accumulation_steps: int = Input(
    #     description="Number of training steps to accumulate before a backward pass. Effective batch size = gradient_accumulation_steps * batch_size",
    #     default=1,
    # ),  # todo.
    is_lora: bool = Input(
        description="Whether to use LoRA training. If set to False, full fine-tuning will be used",
        default=True,
    ),
    unet_learning_rate: float = Input(
        description="Learning rate for the U-Net. We recommend a value somewhere between `1e-6` and `1e-5`.",
        default=1e-6,
    ),
    ti_lr: float = Input(
        description="Scaling of learning rate for training textual inversion embeddings. Don't alter unless you know what you're doing.",
        default=3e-4,
    ),
    lora_lr: float = Input(
        description="Scaling of learning rate for training LoRA embeddings. Don't alter unless you know what you're doing.",
        default=1e-4,
    ),
    lora_rank: int = Input(
        description="Rank of LoRA embeddings. Don't alter unless you know what you're doing.",
        default=32,
    ),
    lr_scheduler: str = Input(
        description="Learning rate scheduler to use for training",
        default="constant",
        choices=[
            "constant",
            "linear",
        ],
    ),
    lr_warmup_steps: int = Input(
        description="Number of warmup steps for lr schedulers with warmups.",
        default=100,
    ),
    token_string: str = Input(
        description="A unique string that will be trained to refer to the concept in the input images. Can be anything, but TOK works well",
        default="TOK",
    ),
    # token_map: str = Input(
    #     description="String of tokens and their impact sizes specifying tokens used in the dataset. This will be in the format `token1:size1,token2:size2,...`.",
    #     default="TOK:2",
    # ),
    caption_prefix: str = Input(
        description="Text which will be used as prefix during automatic captioning. Must contain the `token_string`. For example, if caption text is 'a photo of TOK', automatic captioning will expand to 'a photo of TOK under a bridge', 'a photo of TOK holding a cup', etc.",
        default="a photo of TOK, ",
    ),
    mask_target_prompts: str = Input(
        description="Prompt that describes the part of the image that you find important. For example, if you are fine-tuning on your pet, `photo of a dog` will be a good prompt. Prompt-based masking is used to focus the fine-tuning process on the important/salient parts of the image",
        default=None,
    ),
    crop_based_on_salience: bool = Input(
        description="If you want to crop the image to `target_size` based on the important parts of the image, set this to True. If you want to crop the image based on face detection, set this to False",
        default=True,
    ),
    use_face_detection_instead: bool = Input(
        description="Whether to use face detection instead of CLIPSeg for masking. For face applications, we recommend using this option.",
        default=False,
    ),
    clipseg_temperature: float = Input(
        description="How blurry you want the CLIPSeg mask to be. We recommend a value between `0.5` and `1.0`. If you want a sharper (but potentially more error-prone) mask, decrease this value.",
        default=1.0,
    ),
    verbose: bool = Input(description="Verbose output", default=True),
    checkpointing_steps: int = Input(
        description="Number of steps between saving checkpoints. Set to a very large number to effectively disable checkpointing; you typically don't need intermediate checkpoints.",
        default=999999,
    ),
    input_images_filetype: str = Input(
        description="Filetype of the input images. Can be either `zip` or `tar`. By default it is `infer`, and the filetype will be inferred from the extension of the input file.",
        default="infer",
        choices=["zip", "tar", "infer"],
    ),
) -> TrainingOutput:
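    """Preprocess the input images, fine-tune SDXL on them, and return the trained weights as a tar archive."""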
    # Hard-code token_map for now. Make it configurable once we support multiple concepts or user-uploaded caption csv.
    token_map = token_string + ":2"
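    # token_map has the form "token1:size1,token2:size2,...". With the default token_string
    # this is "TOK:2", i.e. the concept TOK is represented by two new embedding tokens.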
    # Process 'token_to_train' and 'input_data_tar_or_zip'
    inserting_list_tokens = token_map.split(",")
    token_dict = {}
    running_tok_cnt = 0
    all_token_lists = []
    for token in inserting_list_tokens:
        n_tok = int(token.split(":")[1])
        token_dict[token.split(":")[0]] = "".join(
            [f"<s{i + running_tok_cnt}>" for i in range(n_tok)]
        )
        all_token_lists.extend([f"<s{i + running_tok_cnt}>" for i in range(n_tok)])
        running_tok_cnt += n_tok
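    # Example with the defaults: token_dict == {"TOK": "<s0><s1>"} and
    # all_token_lists == ["<s0>", "<s1>"].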
    input_dir = preprocess(
        input_images_filetype=input_images_filetype,
        input_zip_path=input_images,
        caption_text=caption_prefix,
        mask_target_prompts=mask_target_prompts,
        target_size=resolution,
        crop_based_on_salience=crop_based_on_salience,
        use_face_detection_instead=use_face_detection_instead,
        temp=clipseg_temperature,
        substitution_tokens=list(token_dict.keys()),
    )
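    # preprocess is expected to unpack the archive, resize/crop/mask the images, auto-caption
    # them, and write the results (including captions.csv) into the returned directory.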
    if not os.path.exists(SDXL_MODEL_CACHE):
        download_weights(SDXL_URL, SDXL_MODEL_CACHE)

    if os.path.exists(OUTPUT_DIR):
        shutil.rmtree(OUTPUT_DIR)
    os.makedirs(OUTPUT_DIR)
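    # Run the actual trainer (trainer_pti.main). gradient_accumulation_steps is hard-coded
    # to 1 until the commented-out input above is exposed.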
    main(
        pretrained_model_name_or_path=SDXL_MODEL_CACHE,
        instance_data_dir=os.path.join(input_dir, "captions.csv"),
        output_dir=OUTPUT_DIR,
        seed=seed,
        resolution=resolution,
        train_batch_size=train_batch_size,
        num_train_epochs=num_train_epochs,
        max_train_steps=max_train_steps,
        gradient_accumulation_steps=1,
        unet_learning_rate=unet_learning_rate,
        ti_lr=ti_lr,
        lora_lr=lora_lr,
        lr_scheduler=lr_scheduler,
        lr_warmup_steps=lr_warmup_steps,
        token_dict=token_dict,
        inserting_list_tokens=all_token_lists,
        verbose=verbose,
        checkpointing_steps=checkpointing_steps,
        scale_lr=False,
        max_grad_norm=1.0,
        allow_tf32=True,
        mixed_precision="bf16",
        device="cuda:0",
        lora_rank=lora_rank,
        is_lora=is_lora,
    )
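    # Package everything the trainer wrote into a single tar archive so it can be
    # returned as the training output.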
    directory = Path(OUTPUT_DIR)
    out_path = "trained_model.tar"
    with tarfile.open(out_path, "w") as tar:
        for file_path in directory.rglob("*"):
            print(file_path)
            arcname = file_path.relative_to(directory)
            tar.add(file_path, arcname=arcname)
    return TrainingOutput(weights=Path(out_path))
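
# Example invocation via Cog (a sketch, assuming this function is wired up as the `train`
# entry point in cog.yaml for this repo):
#   cog train -i input_images=@images.zip -i token_string="TOK"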