"""Distributed training entry point built on ColossalAI's ``Booster``.

The script parses a training config, launches the distributed runtime, builds
the dataset, VAE, text encoder, diffusion model and its EMA copy, then runs
the training loop with periodic logging and checkpointing.
"""

from copy import deepcopy

import colossalai
import torch
import torch.distributed as dist
import wandb
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from tqdm import tqdm

from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import (
    get_data_parallel_group,
    set_data_parallel_group,
    set_sequence_parallel_group,
)
from opensora.acceleration.plugin import ZeroSeqParallelPlugin
from opensora.datasets import DatasetFromCSV, get_transforms_image, get_transforms_video, prepare_dataloader
from opensora.registry import MODELS, SCHEDULERS, build_module
from opensora.utils.ckpt_utils import create_logger, load, model_sharding, record_model_param_shape, save
from opensora.utils.config_utils import (
    create_experiment_workspace,
    create_tensorboard_writer,
    parse_configs,
    save_training_config,
)
from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, requires_grad, to_torch_dtype
from opensora.utils.train_utils import update_ema


def _init_logging(cfg, coordinator, exp_name, exp_dir):
    """Create the logger and, on the master rank only, the TensorBoard writer and wandb run.

    Returns:
        A ``(logger, writer)`` pair. ``writer`` is ``None`` on non-master ranks.
    """
    if not coordinator.is_master():
        return create_logger(None), None

    logger = create_logger(exp_dir)
    logger.info(f"Experiment directory created at {exp_dir}")
    writer = create_tensorboard_writer(exp_dir)
    if cfg.wandb:
        wandb.init(project="minisora", name=exp_name, config=cfg._cfg_dict)
    return logger, writer


def _build_plugin(cfg):
    """Build the ColossalAI plugin selected by ``cfg.plugin`` and register the parallel groups.

    Raises:
        ValueError: If ``cfg.plugin`` is neither ``"zero2"`` nor ``"zero2-seq"``.
    """
    if cfg.plugin == "zero2":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=cfg.dtype,
            initial_scale=2**16,
            max_norm=cfg.grad_clip,
        )
        set_data_parallel_group(dist.group.WORLD)
    elif cfg.plugin == "zero2-seq":
        plugin = ZeroSeqParallelPlugin(
            sp_size=cfg.sp_size,
            stage=2,
            precision=cfg.dtype,
            initial_scale=2**16,
            max_norm=cfg.grad_clip,
        )
        set_sequence_parallel_group(plugin.sp_group)
        set_data_parallel_group(plugin.dp_group)
    else:
        raise ValueError(f"Unknown plugin {cfg.plugin}")
    return plugin


def _train(cfg, coordinator, device, dtype, logger, writer, exp_dir):
    """Build data, models and optimizer, then run the (optionally resumed) training loop.

    ``writer`` is only used on the master rank and may be ``None`` elsewhere.
    """
    # 2.3. initialize ColossalAI booster
    plugin = _build_plugin(cfg)
    booster = Booster(plugin=plugin)

    # ======================================================
    # 3. build dataset and dataloader
    # ======================================================
    dataset = DatasetFromCSV(
        cfg.data_path,
        # TODO: change transforms
        transform=(
            get_transforms_video(cfg.image_size[0])
            if not cfg.use_image_transform
            else get_transforms_image(cfg.image_size[0])
        ),
        num_frames=cfg.num_frames,
        frame_interval=cfg.frame_interval,
        root=cfg.root,
    )

    # TODO: use plugin's prepare dataloader
    # a batch contains:
    # {
    #      "video": torch.Tensor,  # [B, C, T, H, W],
    #      "text": List[str],
    # }
    dataloader = prepare_dataloader(
        dataset,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        process_group=get_data_parallel_group(),
    )
    logger.info(f"Dataset contains {len(dataset):,} videos ({cfg.data_path})")

    # The dataloader is sharded over the data-parallel group, so the global batch
    # is the per-rank batch times the size of that group. Deriving it from the
    # group (instead of world_size // cfg.sp_size) stays correct for the plain
    # "zero2" plugin, whose data-parallel group is the whole world.
    total_batch_size = cfg.batch_size * dist.get_world_size(group=get_data_parallel_group())
    logger.info(f"Total batch size: {total_batch_size}")

    # ======================================================
    # 4. build model
    # ======================================================
    # 4.1. build model
    input_size = (cfg.num_frames, *cfg.image_size)
    vae = build_module(cfg.vae, MODELS)
    latent_size = vae.get_latent_size(input_size)
    text_encoder = build_module(cfg.text_encoder, MODELS, device=device)  # T5 must be fp32
    model = build_module(
        cfg.model,
        MODELS,
        input_size=latent_size,
        in_channels=vae.out_channels,
        caption_channels=text_encoder.output_dim,
        model_max_length=text_encoder.model_max_length,
        dtype=dtype,
    )
    model_numel, model_numel_trainable = get_model_numel(model)
    logger.info(
        f"Trainable model params: {format_numel_str(model_numel_trainable)}, Total model params: {format_numel_str(model_numel)}"
    )

    # 4.2. create ema (kept in fp32, never trained by the optimizer)
    ema = deepcopy(model).to(torch.float32).to(device)
    requires_grad(ema, False)
    ema_shape_dict = record_model_param_shape(ema)

    # 4.3. move to device
    vae = vae.to(device, dtype)
    model = model.to(device, dtype)

    # 4.4. build scheduler
    scheduler = build_module(cfg.scheduler, SCHEDULERS)

    # 4.5. setup optimizer
    optimizer = HybridAdam(
        filter(lambda p: p.requires_grad, model.parameters()), lr=cfg.lr, weight_decay=0, adamw_mode=True
    )
    lr_scheduler = None

    # 4.6. prepare for training
    if cfg.grad_checkpoint:
        set_grad_checkpoint(model)
    model.train()
    # decay=0 is used for the initial EMA update, before any training step has run.
    update_ema(ema, model, decay=0, sharded=False)
    ema.eval()

    # =======================================================
    # 5. boost model for distributed training with colossalai
    # =======================================================
    # The default dtype is switched only for the duration of boost() and then restored.
    torch.set_default_dtype(dtype)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, dataloader=dataloader
    )
    torch.set_default_dtype(torch.float)
    num_steps_per_epoch = len(dataloader)
    logger.info("Boost model for distributed training")

    # =======================================================
    # 6. training loop
    # =======================================================
    start_epoch = start_step = log_step = sampler_start_idx = 0
    running_loss = 0.0

    # 6.1. resume training
    if cfg.load is not None:
        logger.info("Loading checkpoint")
        start_epoch, start_step, sampler_start_idx = load(booster, model, ema, optimizer, lr_scheduler, cfg.load)
        logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}")
    logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch")

    dataloader.sampler.set_start_index(sampler_start_idx)
    model_sharding(ema)

    # 6.2. training loop
    for epoch in range(start_epoch, cfg.epochs):
        dataloader.sampler.set_epoch(epoch)
        dataloader_iter = iter(dataloader)
        logger.info(f"Beginning epoch {epoch}...")

        with tqdm(
            range(start_step, num_steps_per_epoch),
            desc=f"Epoch {epoch}",
            disable=not coordinator.is_master(),
            total=num_steps_per_epoch,
            initial=start_step,
        ) as pbar:
            for step in pbar:
                batch = next(dataloader_iter)
                x = batch["video"].to(device, dtype)  # [B, C, T, H, W]
                y = batch["text"]

                with torch.no_grad():
                    # Prepare visual inputs
                    x = vae.encode(x)  # [B, C, T, H/P, W/P]
                    # Prepare text inputs
                    model_args = text_encoder.encode(y)

                # Diffusion
                t = torch.randint(0, scheduler.num_timesteps, (x.shape[0],), device=device)
                loss_dict = scheduler.training_losses(model, x, t, model_args)

                # Backward & update
                loss = loss_dict["loss"].mean()
                booster.backward(loss=loss, optimizer=optimizer)
                optimizer.step()
                optimizer.zero_grad()

                # Update EMA
                update_ema(ema, model.module, optimizer=optimizer)

                # Log loss values: the return value of all_reduce_mean is not used,
                # so the scalar is read once from `loss` after the call.
                all_reduce_mean(loss)
                loss_value = loss.item()
                running_loss += loss_value
                global_step = epoch * num_steps_per_epoch + step
                log_step += 1

                # Log to tensorboard
                if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
                    avg_loss = running_loss / log_step
                    pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
                    running_loss = 0
                    log_step = 0
                    writer.add_scalar("loss", loss_value, global_step)
                    if cfg.wandb:
                        wandb.log(
                            {
                                "iter": global_step,
                                "num_samples": global_step * total_batch_size,
                                "epoch": epoch,
                                "loss": loss_value,
                                "avg_loss": avg_loss,
                            },
                            step=global_step,
                        )

                # Save checkpoint
                if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
                    save(
                        booster,
                        model,
                        ema,
                        optimizer,
                        lr_scheduler,
                        epoch,
                        step + 1,
                        global_step + 1,
                        cfg.batch_size,
                        coordinator,
                        exp_dir,
                        ema_shape_dict,
                    )
                    logger.info(
                        f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
                    )

        # the continue epochs are not resumed, so we need to reset the sampler start index and start step
        dataloader.sampler.set_start_index(0)
        start_step = 0


def main():
    """Parse the training config, launch the distributed runtime and train.

    Raises:
        RuntimeError: If CUDA is not available.
        ValueError: If ``cfg.dtype`` is not ``"fp16"`` or ``"bf16"``, or if
            ``cfg.plugin`` names an unknown plugin.
    """
    # ======================================================
    # 1. args & cfg
    # ======================================================
    cfg = parse_configs(training=True)
    print(cfg)
    exp_name, exp_dir = create_experiment_workspace(cfg)
    save_training_config(cfg._cfg_dict, exp_dir)

    # ======================================================
    # 2. runtime variables & colossalai launch
    # ======================================================
    # Explicit raises instead of `assert`: assertions are stripped under `python -O`,
    # which would silently skip these precondition checks.
    if not torch.cuda.is_available():
        raise RuntimeError("Training currently requires at least one GPU.")
    if cfg.dtype not in ("fp16", "bf16"):
        raise ValueError(f"Unknown mixed precision {cfg.dtype}")

    # 2.1. colossalai init distributed training
    colossalai.launch_from_torch({})
    coordinator = DistCoordinator()
    device = get_current_device()
    dtype = to_torch_dtype(cfg.dtype)

    # 2.2. init logger, tensorboard & wandb
    logger, writer = _init_logging(cfg, coordinator, exp_name, exp_dir)

    try:
        _train(cfg, coordinator, device, dtype, logger, writer, exp_dir)
    finally:
        # Close the TensorBoard writer so buffered events reach disk even if
        # training stops with an exception.
        # NOTE(review): assumes create_tensorboard_writer returns a
        # SummaryWriter-like object exposing close() — confirm.
        if writer is not None:
            writer.close()


if __name__ == "__main__":
    main()