ReNO / main.py
fffiloni's picture
We can load the model and handle OOM errors on model switch.
38ce166 verified
raw
history blame
17.4 kB
import json
import logging
import os
import blobfile as bf
import torch
import gc
from datasets import load_dataset
from pytorch_lightning import seed_everything
from tqdm import tqdm
from arguments import parse_args
from models import get_model, get_multi_apply_fn
from rewards import get_reward_losses
from training import LatentNoiseTrainer, get_optimizer
def setup(args, loaded_model_setup=None):
seed_everything(args.seed)
bf.makedirs(f"{args.save_dir}/logs/{args.task}")
# Set up logging and name settings
logger = logging.getLogger()
logger.handlers.clear() # Clear existing handlers
settings = (
f"{args.model}{'_' + args.prompt if args.task == 't2i-compbench' else ''}"
f"{'_no-optim' if args.no_optim else ''}_{args.seed if args.task != 'geneval' else ''}"
f"_lr{args.lr}_gc{args.grad_clip}_iter{args.n_iters}"
f"_reg{args.reg_weight if args.enable_reg else '0'}"
f"{'_pickscore' + str(args.pickscore_weighting) if args.enable_pickscore else ''}"
f"{'_clip' + str(args.clip_weighting) if args.enable_clip else ''}"
f"{'_hps' + str(args.hps_weighting) if args.enable_hps else ''}"
f"{'_imagereward' + str(args.imagereward_weighting) if args.enable_imagereward else ''}"
f"{'_aesthetic' + str(args.aesthetic_weighting) if args.enable_aesthetic else ''}"
)
file_stream = open(f"{args.save_dir}/logs/{args.task}/{settings}.txt", "w")
handler = logging.StreamHandler(file_stream)
formatter = logging.Formatter("%(asctime)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel("INFO")
consoleHandler = logging.StreamHandler()
consoleHandler.setFormatter(formatter)
logger.addHandler(consoleHandler)
logging.info(args)
if args.device_id is not None:
logging.info(f"Using CUDA device {args.device_id}")
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id
device = torch.device("cuda")
if args.dtype == "float32":
dtype = torch.float32
elif args.dtype == "float16":
dtype = torch.float16
# If args.model is the same as the one in loaded_model_setup, reuse the trainer and pipe
if loaded_model_setup and args.model == loaded_model_setup[0].model:
# Reuse the trainer and pipe from the loaded model setup
print(f"Reusing model {args.model} from loaded setup.")
trainer = loaded_model_setup[1] # Trainer is at position 1 in loaded_model_setup
# Update trainer with the new arguments
trainer.n_iters = args.n_iters
trainer.n_inference_steps = args.n_inference_steps
trainer.seed = args.seed
trainer.save_all_images = args.save_all_images
trainer.no_optim = args.no_optim
trainer.regularize = args.enable_reg
trainer.regularization_weight = args.reg_weight
trainer.grad_clip = args.grad_clip
trainer.log_metrics = args.task == "single" or not args.no_optim
trainer.imageselect = args.imageselect
# Get latents (this step is still required)
if args.model == "flux":
shape = (1, 16 * 64, 64)
elif args.model != "pixart":
height = trainer.model.unet.config.sample_size * trainer.model.vae_scale_factor
width = trainer.model.unet.config.sample_size * trainer.model.vae_scale_factor
shape = (
1,
trainer.model.unet.in_channels,
height // trainer.model.vae_scale_factor,
width // trainer.model.vae_scale_factor,
)
else:
height = trainer.model.transformer.config.sample_size * trainer.model.vae_scale_factor
width = trainer.model.transformer.config.sample_size * trainer.model.vae_scale_factor
shape = (
1,
trainer.model.transformer.config.in_channels,
height // trainer.model.vae_scale_factor,
width // trainer.model.vae_scale_factor,
)
multi_apply_fn = loaded_model_setup[6]
enable_grad = not args.no_optim
return args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings
# Proceed with full model loading if args.model is different
print(f"Loading new model: {args.model}")
# Get reward losses
reward_losses = get_reward_losses(args, dtype, device, args.cache_dir)
# Get model and noise trainer
pipe = get_model(
args.model, dtype, device, args.cache_dir, args.memsave, args.cpu_offloading
)
# Attempt to move the model to GPU or keep it on CPU if offloading is enabled
try:
if not args.cpu_offloading:
pipe.to(device)
except RuntimeError as e:
if 'out of memory' in str(e):
print("CUDA OOM error. Attempting to handle OOM situation.")
# Attempt to clear memory and retry moving to GPU
torch.cuda.empty_cache() # Free up cached memory
gc.collect()
try:
# Retry loading after clearing cache
if not args.cpu_offloading:
pipe.to(device)
except RuntimeError as e:
print("Still facing OOM issues. Keeping model on CPU.")
args.cpu_offloading = True # Force CPU offloading
else:
raise e # Re-raise the exception if it's not OOM
torch.cuda.empty_cache() # Free up cached memory
gc.collect()
trainer = LatentNoiseTrainer(
reward_losses=reward_losses,
model=pipe,
n_iters=args.n_iters,
n_inference_steps=args.n_inference_steps,
seed=args.seed,
save_all_images=args.save_all_images,
device=device if not args.cpu_offloading else 'cpu', # Use CPU if offloading is enabled
no_optim=args.no_optim,
regularize=args.enable_reg,
regularization_weight=args.reg_weight,
grad_clip=args.grad_clip,
log_metrics=args.task == "single" or not args.no_optim,
imageselect=args.imageselect,
)
# Create latents
if args.model == "flux":
shape = (1, 16 * 64, 64)
elif args.model != "pixart":
height = pipe.unet.config.sample_size * pipe.vae_scale_factor
width = pipe.unet.config.sample_size * pipe.vae_scale_factor
shape = (
1,
pipe.unet.in_channels,
height // pipe.vae_scale_factor,
width // pipe.vae_scale_factor,
)
else:
height = pipe.transformer.config.sample_size * pipe.vae_scale_factor
width = pipe.transformer.config.sample_size * pipe.vae_scale_factor
shape = (
1,
pipe.transformer.config.in_channels,
height // pipe.vae_scale_factor,
width // pipe.vae_scale_factor,
)
enable_grad = not args.no_optim
# Final memory cleanup
torch.cuda.empty_cache() # Free up cached memory
gc.collect()
if args.enable_multi_apply:
multi_apply_fn = get_multi_apply_fn(
model_type=args.multi_step_model,
seed=args.seed,
pipe=pipe,
cache_dir=args.cache_dir,
device=device if not args.cpu_offloading else 'cpu',
dtype=dtype,
)
else:
multi_apply_fn = None
torch.cuda.empty_cache() # Free up cached memory
gc.collect()
return args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings
def execute_task(args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings, progress_callback=None):
    """Run the task selected by ``args.task`` and log the mean rewards.

    Supported tasks: ``single``, ``example-prompts``, ``t2i-compbench``,
    ``parti-prompts`` and ``geneval``. Outputs go below
    ``{args.save_dir}/{args.task}/{settings}``.

    Args:
        args: Parsed command-line namespace.
        trainer: Object whose ``train(latents, prompt, optimizer, save_dir,
            multi_apply_fn, ...)`` returns ``(init_image, best_image,
            init_rewards, best_rewards)``.
        device, dtype, shape: Used to sample each prompt's initial latents.
        enable_grad: ``requires_grad`` for the latents (geneval always uses True).
        multi_apply_fn: Forwarded to ``trainer.train``.
        settings: Run-name string used in output paths.
        progress_callback: Forwarded to ``trainer.train`` for ``single`` only.

    Raises:
        ValueError: If ``args.task`` is not a supported task.
    """
    task_dir = f"{args.save_dir}/{args.task}/{settings}"

    def new_latents(requires_grad=enable_grad):
        # Fresh Gaussian latents plus an optimizer over them (one pair per prompt).
        init_latents = torch.randn(shape, device=device, dtype=dtype)
        latents = torch.nn.Parameter(init_latents, requires_grad=requires_grad)
        optimizer = get_optimizer(args.optim, latents, args.lr, args.nesterov)
        return latents, optimizer

    if args.task == "single":
        total_init_rewards, total_best_rewards = _run_single(
            args, trainer, new_latents, multi_apply_fn, task_dir, progress_callback
        )
    elif args.task == "example-prompts":
        total_init_rewards, total_best_rewards = _run_example_prompts(
            args, trainer, new_latents, multi_apply_fn, task_dir
        )
    elif args.task == "t2i-compbench":
        total_init_rewards, total_best_rewards = _run_t2i_compbench(
            args, trainer, new_latents, multi_apply_fn, task_dir
        )
    elif args.task == "parti-prompts":
        total_init_rewards, total_best_rewards = _run_parti_prompts(
            args, trainer, new_latents, multi_apply_fn, task_dir
        )
    elif args.task == "geneval":
        total_init_rewards, total_best_rewards = _run_geneval(
            args, trainer, new_latents, multi_apply_fn, task_dir
        )
    else:
        raise ValueError(f"Unknown task {args.task}")

    # log total rewards
    logging.info(f"Mean initial rewards: {total_init_rewards}")
    logging.info(f"Mean best rewards: {total_best_rewards}")


def _accumulate_rewards(total_init, total_best, init_rewards, best_rewards):
    """Add one sample's rewards into the running totals (keys taken from ``best_rewards``)."""
    for key, best_value in best_rewards.items():
        total_best[key] = total_best.get(key, 0.0) + best_value
        total_init[key] = total_init.get(key, 0.0) + init_rewards[key]


def _mean_rewards(totals, count):
    """Return a new dict with every total divided by ``count`` (empty totals -> {})."""
    return {key: value / count for key, value in totals.items()}


def _log_rewards(init_rewards, best_rewards):
    """Log one sample's initial and best rewards."""
    logging.info(f"Initial rewards: {init_rewards}")
    logging.info(f"Best rewards: {best_rewards}")


def _run_single(args, trainer, new_latents, multi_apply_fn, task_dir, progress_callback):
    """Optimize ``args.prompt`` once and save the best image; return its rewards."""
    latents, optimizer = new_latents()
    save_dir = f"{task_dir}/{args.prompt[:150]}"
    os.makedirs(save_dir, exist_ok=True)
    init_image, best_image, total_init_rewards, total_best_rewards = trainer.train(
        latents, args.prompt, optimizer, save_dir, multi_apply_fn, progress_callback=progress_callback
    )
    best_image.save(f"{save_dir}/best_image.png")
    return total_init_rewards, total_best_rewards


def _run_example_prompts(args, trainer, new_latents, multi_apply_fn, task_dir):
    """Optimize each line of assets/example_prompts.txt; write images and results.txt."""
    with open("assets/example_prompts.txt", "r") as fo:
        prompts = fo.readlines()
    total_init, total_best = {}, {}
    for i, prompt in tqdm(enumerate(prompts), total=len(prompts)):
        latents, optimizer = new_latents()
        prompt = prompt.strip()
        save_dir = f"{task_dir}/{i:03d}_{prompt[:150]}.png"
        os.makedirs(save_dir, exist_ok=True)
        init_image, best_image, init_rewards, best_rewards = trainer.train(
            latents, prompt, optimizer, save_dir, multi_apply_fn
        )
        _accumulate_rewards(total_init, total_best, init_rewards, best_rewards)
        best_image.save(f"{save_dir}/best_image.png")
        init_image.save(f"{save_dir}/init_image.png")
        _log_rewards(init_rewards, best_rewards)
    mean_init = _mean_rewards(total_init, len(prompts))
    mean_best = _mean_rewards(total_best, len(prompts))
    # save results to directory (created here too in case the prompt file is empty)
    os.makedirs(task_dir, exist_ok=True)
    with open(f"{task_dir}/results.txt", "w") as f:
        f.write(
            f"Mean initial all rewards: {mean_init}\n"
            f"Mean best all rewards: {mean_best}\n"
        )
    return mean_init, mean_best


def _run_t2i_compbench(args, trainer, new_latents, multi_apply_fn, task_dir):
    """Optimize each prompt of the T2I-CompBench split ``args.prompt``; save best images."""
    prompt_list_file = f"../T2I-CompBench/examples/dataset/{args.prompt}.txt"
    with open(prompt_list_file, "r") as fo:
        prompts = fo.readlines()
    samples_dir = f"{task_dir}/samples"
    os.makedirs(samples_dir, exist_ok=True)
    total_init, total_best = {}, {}
    for i, prompt in tqdm(enumerate(prompts), total=len(prompts)):
        latents, optimizer = new_latents()
        prompt = prompt.strip()
        init_image, best_image, init_rewards, best_rewards = trainer.train(
            latents, prompt, optimizer, None, multi_apply_fn
        )
        _accumulate_rewards(total_init, total_best, init_rewards, best_rewards)
        best_image.save(f"{samples_dir}/{prompt}_{i:06d}.png")
        _log_rewards(init_rewards, best_rewards)
    return _mean_rewards(total_init, len(prompts)), _mean_rewards(total_best, len(prompts))


def _run_parti_prompts(args, trainer, new_latents, multi_apply_fn, task_dir):
    """Optimize every PartiPrompts prompt; write per-sample outputs and results.txt."""
    parti_dataset = load_dataset("nateraw/parti-prompts", split="train")
    n_samples = len(parti_dataset)
    total_reward_diff = 0.0
    total_best_reward = 0.0
    total_init_reward = 0.0
    total_improved_samples = 0
    total_init, total_best = {}, {}
    for index, sample in enumerate(parti_dataset):
        sample_dir = f"{task_dir}/{index}"
        os.makedirs(sample_dir, exist_ok=True)
        # Latents/optimizer must exist before the first train() call.
        latents, optimizer = new_latents()
        prompt = sample["Prompt"]
        # save_dir is passed explicitly (None, as in the other benchmark tasks)
        # so that multi_apply_fn lands in its own parameter slot.
        init_image, best_image, init_rewards, best_rewards = trainer.train(
            latents, prompt, optimizer, None, multi_apply_fn
        )
        best_image.save(f"{sample_dir}/best_image.png")
        with open(f"{sample_dir}/prompt.txt", "w") as f:
            f.write(
                f"{prompt} \n Initial Rewards: {init_rewards} \n Best Rewards: {best_rewards}"
            )
        _log_rewards(init_rewards, best_rewards)
        initial_reward = init_rewards[args.benchmark_reward]
        best_reward = best_rewards[args.benchmark_reward]
        total_reward_diff += best_reward - initial_reward
        total_best_reward += best_reward
        total_init_reward += initial_reward
        # NOTE(review): counting best < initial as "improved" assumes the
        # benchmark reward is loss-like (lower is better) — confirm.
        if best_reward < initial_reward:
            total_improved_samples += 1
        _accumulate_rewards(total_init, total_best, init_rewards, best_rewards)
    improvement_percentage = total_improved_samples / n_samples
    mean_best_reward = total_best_reward / n_samples
    mean_init_reward = total_init_reward / n_samples
    mean_reward_diff = total_reward_diff / n_samples
    logging.info(
        f"Improvement percentage: {improvement_percentage:.4f}, "
        f"mean initial reward: {mean_init_reward:.4f}, "
        f"mean best reward: {mean_best_reward:.4f}, "
        f"mean reward diff: {mean_reward_diff:.4f}"
    )
    mean_init = _mean_rewards(total_init, n_samples)
    mean_best = _mean_rewards(total_best, n_samples)
    # save results
    os.makedirs(task_dir, exist_ok=True)
    with open(f"{task_dir}/results.txt", "w") as f:
        f.write(
            f"Mean improvement: {improvement_percentage:.4f}, "
            f"mean initial reward: {mean_init_reward:.4f}, "
            f"mean best reward: {mean_best_reward:.4f}, "
            f"mean reward diff: {mean_reward_diff:.4f}\n"
            f"Mean initial all rewards: {mean_init}\n"
            f"Mean best all rewards: {mean_best}"
        )
    return mean_init, mean_best


def _run_geneval(args, trainer, new_latents, multi_apply_fn, task_dir):
    """Optimize each GenEval prompt; save metadata and the best image per prompt."""
    prompt_list_file = "../geneval/prompts/evaluation_metadata.jsonl"
    with open(prompt_list_file) as fp:
        metadatas = [json.loads(line) for line in fp]
    total_init, total_best = {}, {}
    for index, metadata in enumerate(metadatas):
        # GenEval always builds grad-enabled latents, independent of enable_grad.
        latents, optimizer = new_latents(requires_grad=True)
        prompt = metadata["prompt"]
        init_image, best_image, init_rewards, best_rewards = trainer.train(
            latents, prompt, optimizer, None, multi_apply_fn
        )
        _log_rewards(init_rewards, best_rewards)
        outpath = f"{task_dir}/{index:0>5}"
        os.makedirs(f"{outpath}/samples", exist_ok=True)
        with open(f"{outpath}/metadata.jsonl", "w") as fp:
            json.dump(metadata, fp)
        best_image.save(f"{outpath}/samples/{args.seed:05}.png")
        _accumulate_rewards(total_init, total_best, init_rewards, best_rewards)
    return _mean_rewards(total_init, len(metadatas)), _mean_rewards(total_best, len(metadatas))
def main():
    """Command-line entry point: parse arguments, build the model setup, run the task."""
    cli_args = parse_args()
    run_config = setup(cli_args, loaded_model_setup=None)
    # setup() returns exactly the eight leading positional arguments of execute_task().
    execute_task(*run_config)


if __name__ == "__main__":
    main()