import random  # module-level import so random.randint() below resolves correctly

from accelerate.utils import set_seed
from torch.cuda.amp import autocast

from StableDiffuser import StableDiffuser
from finetuning import FineTunedModel
import torch
from tqdm import tqdm

from memory_efficiency import MemoryEfficiencyWrapper
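
# Summary of the training loop below: each iteration partially denoises a random latent while
# conditioned on the prompt to be erased, then compares three noise predictions at that timestep:
# the frozen model given the prompt, the frozen model given the empty prompt, and the fine-tuned
# modules given the prompt. The fine-tuned prediction is regressed toward
# neutral_latents - negative_guidance * (positive_latents - neutral_latents),
# i.e. it is steered away from the erased concept.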

def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
          use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False, seed=-1):
    nsteps = 50

    diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path).to('cuda')

    memory_efficiency_wrapper = MemoryEfficiencyWrapper(diffuser=diffuser, use_amp=use_amp, use_xformers=use_xformers,
                                                        use_gradient_checkpointing=use_gradient_checkpointing)
    with memory_efficiency_wrapper:
        diffuser.train()

        # Only the requested modules are trainable; anything matching freeze_modules stays frozen.
        finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)

        if use_adamw8bit:
            print("using AdamW 8Bit optimizer")
            import bitsandbytes as bnb
            optimizer = bnb.optim.AdamW8bit(finetuner.parameters(),
                                            lr=lr,
                                            betas=(0.9, 0.999),
                                            weight_decay=0.010,
                                            eps=1e-8)
        else:
            print("using Adam optimizer")
            optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
        criteria = torch.nn.MSELoss()

        pbar = tqdm(range(iterations))

        # Encode the neutral ('') and target prompt embeddings once, then drop the VAE,
        # text encoder and tokenizer to free VRAM; only the UNet is needed from here on.
        with torch.no_grad():
            neutral_text_embeddings = diffuser.get_text_embeddings([''], n_imgs=1)
            positive_text_embeddings = diffuser.get_text_embeddings([prompt], n_imgs=1)

        del diffuser.vae
        del diffuser.text_encoder
        del diffuser.tokenizer

        torch.cuda.empty_cache()

        print(f"using img_size of {img_size}")

        if seed == -1:
            seed = random.randint(0, 2 ** 30)
        set_seed(int(seed))
        for i in pbar:
            with torch.no_grad():
                diffuser.set_scheduler_timesteps(nsteps)
                optimizer.zero_grad()

                # Pick a random intermediate step and partially denoise a fresh latent up to it,
                # conditioned on the prompt being erased.
                iteration = torch.randint(1, nsteps - 1, (1,)).item()
                latents = diffuser.get_initial_latents(1, width=img_size, height=img_size, n_prompts=1)

                with finetuner:
                    latents_steps, _ = diffuser.diffusion(
                        latents,
                        positive_text_embeddings,
                        start_iteration=0,
                        end_iteration=iteration,
                        guidance_scale=3,
                        show_progress=False,
                        use_amp=use_amp
                    )

                # Map the step index from the 50-step DDIM schedule onto the full 1000-step schedule.
                diffuser.set_scheduler_timesteps(1000)
                iteration = int(iteration / nsteps * 1000)

                # Reference noise predictions from the frozen model, conditioned on the prompt and
                # on the empty (neutral) prompt.
                with autocast(enabled=use_amp):
                    positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
                    neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)

            # Noise prediction from the fine-tuned modules for the same prompt; this is the only
            # forward pass that tracks gradients.
            with finetuner:
                with autocast(enabled=use_amp):
                    negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)

            positive_latents.requires_grad = False
            neutral_latents.requires_grad = False

            # Push the fine-tuned prediction away from the prompt direction: the regression target is
            # neutral_latents - negative_guidance * (positive_latents - neutral_latents).
            # loss = criteria(e_n, e_0) works the best try 5000 epochs
            loss = criteria(negative_latents, neutral_latents - (negative_guidance * (positive_latents - neutral_latents)))

            # The wrapper presumably runs backward() and the optimizer step, with loss scaling when AMP is enabled.
            memory_efficiency_wrapper.step(optimizer, loss)
            optimizer.zero_grad()

    torch.save(finetuner.state_dict(), save_path)

    del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents, latents_steps, latents

    torch.cuda.empty_cache()

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--repo_id_or_path", required=True)
    parser.add_argument("--img_size", type=int, required=False, default=512)
    parser.add_argument('--prompt', required=True)
    parser.add_argument('--modules', required=True)
    parser.add_argument('--freeze_modules', nargs='+', required=True)
    parser.add_argument('--save_path', required=True)
    parser.add_argument('--iterations', type=int, required=True)
    parser.add_argument('--lr', type=float, required=True)
    parser.add_argument('--negative_guidance', type=float, required=True)
    parser.add_argument('--seed', type=int, required=False, default=-1,
                        help='Training seed for reproducible results, or -1 to pick a random seed')
    parser.add_argument('--use_adamw8bit', action='store_true')
    parser.add_argument('--use_xformers', action='store_true')
    parser.add_argument('--use_amp', action='store_true')
    parser.add_argument('--use_gradient_checkpointing', action='store_true')

    train(**vars(parser.parse_args()))
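
# Example invocation (illustrative only; the script name, model id, module patterns and
# hyperparameter values below are placeholders rather than values taken from this repository):
#
#   python train.py \
#       --repo_id_or_path runwayml/stable-diffusion-v1-5 \
#       --prompt "Van Gogh style" \
#       --modules "unet" \
#       --freeze_modules "unet.time_embedding" \
#       --save_path erased.pt \
#       --iterations 1000 --lr 1e-5 --negative_guidance 1.0 \
#       --use_adamw8bit --use_xformers --use_amp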