Commit 1235e6e
Parent(s): 173552f
Update training 2

Files changed:
- app.py (+12 -4)
- train_dreambooth.py (+68 -11)
app.py CHANGED
@@ -30,7 +30,7 @@ maximum_concepts = 3
 
 #Pre download the files
 model_v1 = snapshot_download(repo_id="multimodalart/sd-fine-tunable")
-
+model_v2 = snapshot_download(repo_id="stabilityai/stable-diffusion-2")
 model_v2_512 = snapshot_download(repo_id="stabilityai/stable-diffusion-2-base")
 safety_checker = snapshot_download(repo_id="multimodalart/sd-sc")
 
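The commit adds the 768-pixel stable-diffusion-2 checkpoint to the set of repositories pre-downloaded at startup. snapshot_download comes from huggingface_hub and returns the local path of the cached repo; a minimal sketch (variable name hypothetical):

    from huggingface_hub import snapshot_download

    # Fetch the whole repo once (or reuse the local cache) and get its path
    local_path = snapshot_download(repo_id="stabilityai/stable-diffusion-2")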
@@ -171,6 +171,10 @@ def train(*inputs):
         Training_Steps=1400
 
     stptxt = int((Training_Steps*Train_text_encoder_for)/100)
+    #gradient_checkpointing = False if which_model == "v1-5" else True
+    gradient_checkpointing=False
+    resolution = 512 if which_model != "v2-768" else 768
+    cache_latents = True if which_model != "v1-5" else False
     if (type_of_thing == "object" or type_of_thing == "style" or (type_of_thing == "person" and not experimental_face_improvement)):
         args_general = argparse.Namespace(
             image_captions_filename = True,
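These four new lines derive per-model training settings inside train(): gradient checkpointing is forced off for now (the per-model rule is left commented out), the training resolution follows the base model, and latent caching is enabled for the v2 models. Assuming which_model takes the values "v1-5", "v2-512" and "v2-768" ("v1-5" and "v2-768" appear in the diff; "v2-512" is inferred), the selection resolves as in this sketch:

    # Hypothetical walk-through of the selection logic above
    for which_model in ("v1-5", "v2-512", "v2-768"):
        resolution = 512 if which_model != "v2-768" else 768
        cache_latents = True if which_model != "v1-5" else False
        print(which_model, resolution, cache_latents)
    # v1-5: 512, not cached; v2-512: 512, cached; v2-768: 768, cached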
@@ -183,7 +187,7 @@ def train(*inputs):
             output_dir="output_model",
             instance_prompt="",
             seed=42,
-            resolution=512,
+            resolution=resolution,
             mixed_precision="fp16",
             train_batch_size=1,
             gradient_accumulation_steps=1,
@@ -192,6 +196,8 @@ def train(*inputs):
             lr_scheduler="polynomial",
             lr_warmup_steps = 0,
             max_train_steps=Training_Steps,
+            gradient_checkpointing=gradient_checkpointing,
+            cache_latents=cache_latents,
         )
         print("Starting single training...")
         lock_file = open("intraining.lock", "w")
@@ -211,7 +217,7 @@ def train(*inputs):
             prior_loss_weight=1.0,
             instance_prompt="",
             seed=42,
-            resolution=512,
+            resolution=resolution,
             mixed_precision="fp16",
             train_batch_size=1,
             gradient_accumulation_steps=1,
@@ -220,7 +226,9 @@ def train(*inputs):
             lr_scheduler="polynomial",
             lr_warmup_steps = 0,
             max_train_steps=Training_Steps,
-            num_class_images=200,
+            num_class_images=200,
+            gradient_checkpointing=gradient_checkpointing,
+            cache_latents=cache_latents,
         )
         print("Starting multi-training...")
         lock_file = open("intraining.lock", "w")
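Both the single-concept and the multi-concept branch now forward the new settings through the argparse.Namespace handed to the training script, which reads them as if they were parsed command-line flags. A minimal sketch of that handoff (names from the diff, values illustrative):

    import argparse

    args_general = argparse.Namespace(
        resolution=512,                  # picked per model above
        gradient_checkpointing=False,
        cache_latents=True,              # consumed by train_dreambooth.py
    )
    assert args_general.cache_latents    # read later as args.cache_latents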
train_dreambooth.py CHANGED
@@ -235,6 +235,13 @@ def parse_args():
         help="Train only the unet",
     )
 
+    parser.add_argument(
+        "--cache_latents",
+        action="store_true",
+        default=False,
+        help="Train only the unet",
+    )
+
     parser.add_argument(
         "--Session_dir",
         type=str,
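The matching --cache_latents switch is added to parse_args(). Its help string ("Train only the unet") looks copy-pasted from the preceding argument; the flag actually toggles the latent-caching path below. As a store_true option it defaults to False and flips to True when passed, as this standalone sketch shows:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--cache_latents", action="store_true", default=False)
    print(parser.parse_args([]).cache_latents)                   # False
    print(parser.parse_args(["--cache_latents"]).cache_latents)  # True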
@@ -382,6 +389,16 @@ class PromptDataset(Dataset):
         example["index"] = index
         return example
 
+class LatentsDataset(Dataset):
+    def __init__(self, latents_cache, text_encoder_cache):
+        self.latents_cache = latents_cache
+        self.text_encoder_cache = text_encoder_cache
+
+    def __len__(self):
+        return len(self.latents_cache)
+
+    def __getitem__(self, index):
+        return self.latents_cache[index], self.text_encoder_cache[index]
 
 def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
     if token is None:
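LatentsDataset is a thin wrapper pairing two pre-computed lists; indexing returns one (latents, conditioning) tuple and touches no model. Hypothetical usage with dummy tensors (shapes illustrative, not from the diff):

    import torch

    latents_cache = [torch.randn(1, 4, 64, 64) for _ in range(3)]  # stand-ins for VAE outputs
    text_cache = [torch.randn(1, 77, 1024) for _ in range(3)]      # stand-ins for text-encoder outputs
    dataset = LatentsDataset(latents_cache, text_cache)
    latents, conditioning = dataset[0]  # one cached pair
    assert len(dataset) == 3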
@@ -631,6 +648,28 @@ def run_training(args_imported):
     if not args.train_text_encoder:
         text_encoder.to(accelerator.device, dtype=weight_dtype)
 
+
+    if args.cache_latents:
+        latents_cache = []
+        text_encoder_cache = []
+        for batch in tqdm(train_dataloader, desc="Caching latents"):
+            with torch.no_grad():
+                batch["pixel_values"] = batch["pixel_values"].to(accelerator.device, non_blocking=True, dtype=weight_dtype)
+                batch["input_ids"] = batch["input_ids"].to(accelerator.device, non_blocking=True)
+                latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+                if args.train_text_encoder:
+                    text_encoder_cache.append(batch["input_ids"])
+                else:
+                    text_encoder_cache.append(text_encoder(batch["input_ids"])[0])
+        train_dataset = LatentsDataset(latents_cache, text_encoder_cache)
+        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, collate_fn=lambda x: x, shuffle=True)
+
+        del vae
+        if not args.train_text_encoder:
+            del text_encoder
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
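With --cache_latents on, a single no-grad pass encodes every training image through the VAE once and, when the text encoder is frozen, pre-computes the text embeddings as well; the VAE (and optionally the text encoder) can then be deleted to free VRAM before training starts. Because the replacement DataLoader uses batch_size=1 with the identity collate_fn=lambda x: x, each batch is a one-element Python list holding a (latents, conditioning) tuple, which is why the training loop below indexes batch[0][0] and batch[0][1]. A runnable sketch of that batch shape (dummy data):

    import torch
    from torch.utils.data import DataLoader

    pairs = [(torch.randn(1, 4, 64, 64), torch.randn(1, 77, 1024)) for _ in range(4)]
    loader = DataLoader(pairs, batch_size=1, collate_fn=lambda x: x, shuffle=True)
    batch = next(iter(loader))        # a plain list with one (latents, conditioning) tuple
    latents, conditioning = batch[0]  # i.e. batch[0][0] and batch[0][1]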
@@ -669,8 +708,12 @@ def run_training(args_imported):
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(unet):
                     # Convert images to latent space
-                    latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                    latents = latents * 0.18215
+                    with torch.no_grad():
+                        if args.cache_latents:
+                            latents = batch[0][0]
+                        else:
+                            latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+                        latents = latents * 0.18215
 
                     # Sample noise that we'll add to the latents
                     noise = torch.randn_like(latents)
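One caveat worth flagging: the caching pass stores vae.encode(...).latent_dist, a distribution object, while the cached branch here assigns batch[0][0] directly to latents before scaling by 0.18215 (the usual Stable Diffusion latent scaling factor). If the cache really holds distributions rather than sampled tensors, the cached branch would still need a .sample(), roughly as in this hedged sketch:

    # Sketch only; verify against the repo before relying on it
    latents = batch[0][0].sample() if args.cache_latents else \
        vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
    latents = latents * 0.18215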
@@ -684,26 +727,40 @@ def run_training(args_imported):
                     noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                     # Get the text embedding for conditioning
-                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                    if(args.cache_latents):
+                        if args.train_text_encoder:
+                            encoder_hidden_states = text_encoder(batch[0][1])[0]
+                        else:
+                            encoder_hidden_states = batch[0][1]
+                    else:
+                        encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
                     # Predict the noise residual
-                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
+                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    # Get the target for loss depending on the prediction type
+                    if noise_scheduler.config.prediction_type == "epsilon":
+                        target = noise
+                    elif noise_scheduler.config.prediction_type == "v_prediction":
+                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                    else:
+                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
                     if args.with_prior_preservation:
-                        # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
-                        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-                        noise, noise_prior = torch.chunk(noise, 2, dim=0)
+                        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+                        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+                        target, target_prior = torch.chunk(target, 2, dim=0)
 
                         # Compute instance loss
-                        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
+                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
 
                         # Compute prior loss
-                        prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
+                        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
 
                         # Add the prior loss to the instance loss.
                         loss = loss + args.prior_loss_weight * prior_loss
                     else:
-                        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
+                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
 
                     accelerator.backward(loss)
                     if accelerator.sync_gradients:
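The most consequential change is that the loss target now follows the scheduler's prediction_type, which is what makes the stable-diffusion-2 768 checkpoint (a v-prediction model) trainable here; epsilon models such as v1-5 and the 512 base keep plain noise as the target and train exactly as before. noise_scheduler.get_velocity computes the velocity target from Salimans & Ho's progressive-distillation formulation; a self-contained sketch of the same quantity (assuming alphas_cumprod taken from the scheduler):

    import torch

    def velocity_target(latents, noise, alphas_cumprod, timesteps):
        # v = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x0,
        # mirroring noise_scheduler.get_velocity in diffusers
        abar = alphas_cumprod[timesteps].view(-1, 1, 1, 1).to(latents.dtype)
        return abar.sqrt() * noise - (1 - abar).sqrt() * latents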