import random

import torch
from accelerate.utils import set_seed
from torch.cuda.amp import autocast
from tqdm import tqdm

from StableDiffuser import StableDiffuser
from finetuning import FineTunedModel
from isolate_rng import isolate_rng
from memory_efficiency import MemoryEfficiencyWrapper


def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
          use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False, seed=-1,
          save_every=-1):
    nsteps = 50

    diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path).to('cuda')
    memory_efficiency_wrapper = MemoryEfficiencyWrapper(diffuser=diffuser, use_amp=use_amp, use_xformers=use_xformers,
                                                        use_gradient_checkpointing=use_gradient_checkpointing)
    with memory_efficiency_wrapper:
        diffuser.train()
        finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)

        if use_adamw8bit:
            print("using AdamW 8Bit optimizer")
            import bitsandbytes as bnb
            optimizer = bnb.optim.AdamW8bit(finetuner.parameters(),
                                            lr=lr,
                                            betas=(0.9, 0.999),
                                            weight_decay=0.010,
                                            eps=1e-8)
        else:
            print("using Adam optimizer")
            optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
        criteria = torch.nn.MSELoss()

        pbar = tqdm(range(iterations))

        # Text embeddings only need to be computed once; afterwards the VAE,
        # text encoder and tokenizer can be dropped to free VRAM.
        with torch.no_grad():
            neutral_text_embeddings = diffuser.get_text_embeddings([''], n_imgs=1)
            positive_text_embeddings = diffuser.get_text_embeddings([prompt], n_imgs=1)

        del diffuser.vae
        del diffuser.text_encoder
        del diffuser.tokenizer
        torch.cuda.empty_cache()

        print(f"using img_size of {img_size}")

        if seed == -1:
            seed = random.randint(0, 2 ** 30)
        set_seed(int(seed))

        prev_losses = []
        start_loss = None
        max_prev_loss_count = 10

        for i in pbar:
            with torch.no_grad():
                diffuser.set_scheduler_timesteps(nsteps)
                optimizer.zero_grad()

                # Pick a random intermediate timestep and diffuse up to it with
                # the finetuned weights applied.
                iteration = torch.randint(1, nsteps - 1, (1,)).item()
                latents = diffuser.get_initial_latents(1, width=img_size, height=img_size, n_prompts=1)

                with finetuner:
                    latents_steps, _ = diffuser.diffusion(
                        latents,
                        positive_text_embeddings,
                        start_iteration=0,
                        end_iteration=iteration,
                        guidance_scale=3,
                        show_progress=False,
                        use_amp=use_amp
                    )

                # Rescale the sampled step into the scheduler's full 1000-step range.
                diffuser.set_scheduler_timesteps(1000)
                iteration = int(iteration / nsteps * 1000)

                # Noise predictions from the frozen model for the target prompt
                # and for the empty (neutral) prompt.
                with autocast(enabled=use_amp):
                    positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
                    neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)

            # Noise prediction from the model being finetuned (gradients flow here).
            with finetuner:
                with autocast(enabled=use_amp):
                    negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)

            positive_latents.requires_grad = False
            neutral_latents.requires_grad = False

            # loss = criteria(e_n, e_0) works the best; try 5000 epochs.
            # Push the finetuned prediction away from the concept: the target is the
            # neutral prediction minus the guided direction toward the prompt.
            loss = criteria(negative_latents, neutral_latents - (negative_guidance * (positive_latents - neutral_latents)))
            memory_efficiency_wrapper.step(optimizer, loss)
            optimizer.zero_grad()

            # print moving average loss
            prev_losses.append(loss.detach().clone())
            if len(prev_losses) > max_prev_loss_count:
                prev_losses.pop(0)
            if start_loss is None:
                start_loss = prev_losses[-1]
            if len(prev_losses) >= max_prev_loss_count:
                moving_average_loss = sum(prev_losses) / len(prev_losses)
                print(f"step {i}: loss={loss.item()} "
                      f"(avg={moving_average_loss.item()}, start ∆={(moving_average_loss - start_loss).item()})")
            else:
                print(f"step {i}: loss={loss.item()}")

            if save_every > 0 and ((i % save_every) == (save_every - 1)):
                torch.save(finetuner.state_dict(), save_path + f"__step_{i}.pt")

    torch.save(finetuner.state_dict(), save_path)

    del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents, latents_steps, latents
    torch.cuda.empty_cache()


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--repo_id_or_path", required=True)
    parser.add_argument("--img_size", type=int, required=False, default=512)
    parser.add_argument('--prompt', required=True)
    parser.add_argument('--modules', required=True)
    parser.add_argument('--freeze_modules', nargs='+', required=True)
    parser.add_argument('--save_path', required=True)
    parser.add_argument('--iterations', type=int, required=True)
    parser.add_argument('--lr', type=float, required=True)
    parser.add_argument('--negative_guidance', type=float, required=True)
    parser.add_argument('--seed', type=int, required=False, default=-1,
                        help='Training seed for reproducible results, or -1 to pick a random seed')
    parser.add_argument('--save_every', type=int, required=False, default=-1,
                        help='Save an intermediate checkpoint every N steps, or -1 to save only at the end')
    parser.add_argument('--use_adamw8bit', action='store_true')
    parser.add_argument('--use_xformers', action='store_true')
    parser.add_argument('--use_amp', action='store_true')
    parser.add_argument('--use_gradient_checkpointing', action='store_true')
    train(**vars(parser.parse_args()))
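
# Example invocation (illustrative only: the script filename, model id, prompt, module
# patterns, output path, and hyperparameter values below are placeholders, not values
# prescribed by this repo):
#
#   python train.py \
#       --repo_id_or_path "CompVis/stable-diffusion-v1-4" \
#       --prompt "Van Gogh" \
#       --modules "unet" \
#       --freeze_modules "attn1" "norm" \
#       --save_path "output/van_gogh_erased.pt" \
#       --iterations 1000 \
#       --lr 1e-5 \
#       --negative_guidance 1.0 \
#       --use_adamw8bit --use_xformers --use_amp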