# Scraped from a Hugging Face Space file view (Space status: "Runtime error"; file size: 5,226 bytes).
# Blame commits shown on the page: 94be4c7, ab11bdd, 81ccbca, fc73e59, 0002379, b58675c, c8aa68b, 7c89716, fd9afda.
from random import random
from accelerate.utils import set_seed
from torch.cuda.amp import autocast
from StableDiffuser import StableDiffuser
from finetuning import FineTunedModel
import torch
from tqdm import tqdm
from memory_efficiency import MemoryEfficiencyWrapper
def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
          use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False, seed=-1):
    """Fine-tune selected modules of a StableDiffuser against one prompt and save the trained state.

    Each iteration does three things:
      1. Runs a partial DDIM denoising of random latents, conditioned on ``prompt``,
         inside the ``finetuner`` context.
      2. Predicts noise outside that context for both ``prompt`` and the empty prompt ``''``.
      3. Trains the in-context prediction for ``prompt`` with an MSE loss. The target is
         ``neutral - negative_guidance * (positive - neutral)``.

    Args:
        repo_id_or_path: Model identifier or local path, forwarded to ``StableDiffuser``.
        img_size: Width and height passed to ``get_initial_latents``.
        prompt: Text whose embeddings form the "positive" conditioning.
        modules: Forwarded to ``FineTunedModel`` as the modules to fine-tune.
        freeze_modules: Forwarded to ``FineTunedModel`` as ``frozen_modules``.
        iterations: Number of optimisation steps.
        negative_guidance: Scale applied to ``(positive - neutral)`` in the loss target.
        lr: Learning rate for the optimizer.
        save_path: File that ``finetuner.state_dict()`` is written to with ``torch.save``.
        use_adamw8bit: Use bitsandbytes ``AdamW8bit`` instead of ``torch.optim.Adam``.
        use_xformers: Forwarded to ``MemoryEfficiencyWrapper``.
        use_amp: Forwarded to ``MemoryEfficiencyWrapper`` and ``diffusion``, and also
            enables ``autocast`` around the noise predictions.
        use_gradient_checkpointing: Forwarded to ``MemoryEfficiencyWrapper``.
        seed: Seed passed to ``set_seed``. ``-1`` draws a random seed in ``[0, 2**30]``.

    Raises:
        ImportError: if ``use_adamw8bit`` is set and ``bitsandbytes`` is not installed.

    Note: the model is moved to the ``'cuda'`` device, so a CUDA GPU is required.
    """
    # The module-level ``from random import random`` binds the *function* random.random,
    # which has no ``randint``; import the function we actually need here.
    from random import randint

    nsteps = 50  # length of the DDIM schedule used to produce the partially denoised latents

    diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path).to('cuda')

    memory_efficiency_wrapper = MemoryEfficiencyWrapper(diffuser=diffuser, use_amp=use_amp, use_xformers=use_xformers,
                                                        use_gradient_checkpointing=use_gradient_checkpointing)
    with memory_efficiency_wrapper:
        diffuser.train()

        finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)

        if use_adamw8bit:
            print("using AdamW 8Bit optimizer")
            import bitsandbytes as bnb
            optimizer = bnb.optim.AdamW8bit(finetuner.parameters(),
                                            lr=lr,
                                            betas=(0.9, 0.999),
                                            weight_decay=0.010,
                                            eps=1e-8
                                            )
        else:
            print("using Adam optimizer")
            optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
        criteria = torch.nn.MSELoss()

        pbar = tqdm(range(iterations))

        # The two text embeddings are reused unchanged for every iteration: compute them once,
        # then delete the VAE / text encoder / tokenizer to free GPU memory for training.
        with torch.no_grad():
            neutral_text_embeddings = diffuser.get_text_embeddings([''], n_imgs=1)
            positive_text_embeddings = diffuser.get_text_embeddings([prompt], n_imgs=1)

        del diffuser.vae
        del diffuser.text_encoder
        del diffuser.tokenizer

        torch.cuda.empty_cache()

        print(f"using img_size of {img_size}")

        if seed == -1:
            seed = randint(0, 2 ** 30)
        set_seed(int(seed))

        # Pre-bind the per-iteration names so the cleanup ``del`` below does not raise
        # NameError when ``iterations`` is 0 and the loop body never runs.
        loss = negative_latents = neutral_latents = positive_latents = latents_steps = latents = None

        for _ in pbar:
            with torch.no_grad():
                diffuser.set_scheduler_timesteps(nsteps)

                optimizer.zero_grad()

                # Random intermediate step in [1, nsteps - 2]; torch.randint's upper bound is exclusive.
                iteration = torch.randint(1, nsteps - 1, (1,)).item()

                latents = diffuser.get_initial_latents(1, width=img_size, height=img_size, n_prompts=1)

                with finetuner:
                    latents_steps, _ = diffuser.diffusion(
                        latents,
                        positive_text_embeddings,
                        start_iteration=0,
                        end_iteration=iteration,
                        guidance_scale=3,
                        show_progress=False,
                        use_amp=use_amp
                    )

                # Switch to the 1000-step schedule and rescale the sampled step index onto it.
                diffuser.set_scheduler_timesteps(1000)

                iteration = int(iteration / nsteps * 1000)

                with autocast(enabled=use_amp):
                    positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
                    neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)

            # This is the only prediction made with gradients enabled, inside the finetuner context.
            with finetuner:
                with autocast(enabled=use_amp):
                    negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)

            positive_latents.requires_grad = False
            neutral_latents.requires_grad = False

            # loss = criteria(e_n, e_0) works the best try 5000 epochs
            loss = criteria(negative_latents, neutral_latents - (negative_guidance * (positive_latents - neutral_latents)))

            memory_efficiency_wrapper.step(optimizer, loss)
            optimizer.zero_grad()

    torch.save(finetuner.state_dict(), save_path)

    # Drop references to the large GPU objects before emptying the CUDA cache.
    del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents, latents_steps, latents

    torch.cuda.empty_cache()
if __name__ == '__main__':
    # Command-line entry point: parse the arguments and forward them to train() by keyword.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--repo_id_or_path", required=True)
    parser.add_argument("--img_size", type=int, required=False, default=512)
    parser.add_argument('--prompt', required=True)
    parser.add_argument('--modules', required=True)
    parser.add_argument('--freeze_modules', nargs='+', required=True)
    parser.add_argument('--save_path', required=True)
    parser.add_argument('--iterations', type=int, required=True)
    parser.add_argument('--lr', type=float, required=True)
    parser.add_argument('--negative_guidance', type=float, required=True)
    parser.add_argument('--seed', type=int, required=False, default=-1,
                        help='Training seed for reproducible results, or -1 to pick a random seed')
    # NOTE: these store_true flags default to False on the command line, unlike the
    # keyword defaults of train() (use_adamw8bit/use_xformers/use_amp default to True there).
    parser.add_argument('--use_adamw8bit', action='store_true')
    parser.add_argument('--use_xformers', action='store_true')
    parser.add_argument('--use_amp', action='store_true')
    parser.add_argument('--use_gradient_checkpointing', action='store_true')
    train(**vars(parser.parse_args()))