"""Training loop and checkpoint helpers for diffusion models."""

import copy
import functools
import os
import re
from pathlib import Path
from pdb import set_trace as st

import blobfile as bf
import numpy as np
import torch as th
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import AdamW

from . import dist_util, logger
from .fp16_util import MixedPrecisionTrainer
from .nn import update_ema
from .resample import LossAwareSampler, UniformSampler

# For ImageNet experiments, this was a good default value.
# We found that the lg_loss_scale quickly climbed to
# 20-21 within the first ~1K steps of training.
INITIAL_LOG_LOSS_SCALE = 20.0

# Matches the run of ASCII digits that ends a checkpoint file stem.
_TRAILING_STEP_RE = re.compile(r"([0-9]+)$")


def _parse_ema_rates(ema_rate):
    """Normalize ``ema_rate`` into a list of floats.

    Accepts a single number, a comma-separated string such as
    ``"0.999,0.9999"``, or any iterable of numbers.
    """
    if isinstance(ema_rate, str):
        return [float(x) for x in ema_rate.split(",")]
    if isinstance(ema_rate, (int, float)):
        return [float(ema_rate)]
    return [float(x) for x in ema_rate]


class TrainLoop:
    """Run diffusion training with EMA tracking, LR annealing and checkpoints.

    The loop pulls ``(batch, cond)`` pairs from ``data``, splits each batch
    into micro-batches, accumulates gradients through a
    ``MixedPrecisionTrainer`` and steps an ``AdamW`` optimizer.
    """

    def __init__(
        self,
        *,
        model,
        diffusion,
        data,
        batch_size,
        microbatch,
        lr,
        ema_rate,
        log_interval,
        save_interval,
        resume_checkpoint,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        schedule_sampler=None,
        weight_decay=0.0,
        lr_anneal_steps=0,
        use_amp=False,
    ):
        """Set up the model, optimizer, EMA copies and (optionally) DDP.

        Args:
            model: module to train; wrapped in DDP when CUDA is available.
            diffusion: object exposing ``training_losses`` and
                ``num_timesteps``.
            data: iterator yielding ``(batch, cond)`` pairs.
            batch_size: per-process batch size.
            microbatch: micro-batch size; values <= 0 mean "use batch_size".
            lr: base learning rate.
            ema_rate: a number, a comma-separated string of rates, or an
                iterable of rates; one EMA copy is kept per rate.
            log_interval: dump logged values every this many steps.
            save_interval: write checkpoints every this many steps.
            resume_checkpoint: path of a model checkpoint to resume from
                (empty or None to start fresh).
            use_fp16: forwarded to ``MixedPrecisionTrainer``.
            fp16_scale_growth: forwarded to ``MixedPrecisionTrainer``.
            schedule_sampler: timestep sampler; defaults to
                ``UniformSampler(diffusion)``.
            weight_decay: AdamW weight decay.
            lr_anneal_steps: total steps for linear LR decay; 0 disables
                annealing and makes ``run_loop`` run until interrupted.
            use_amp: forwarded to ``MixedPrecisionTrainer``.
        """
        self.model = model
        self.diffusion = diffusion
        self.data = data
        self.batch_size = batch_size
        self.microbatch = microbatch if microbatch > 0 else batch_size
        self.lr = lr
        self.ema_rate = _parse_ema_rates(ema_rate)
        self.log_interval = log_interval
        self.save_interval = save_interval
        self.resume_checkpoint = resume_checkpoint
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth
        self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
        self.weight_decay = weight_decay
        self.lr_anneal_steps = lr_anneal_steps

        self.step = 0
        self.resume_step = 0
        self.global_batch = self.batch_size * dist.get_world_size()

        self.sync_cuda = th.cuda.is_available()

        self._load_and_sync_parameters()
        self.mp_trainer = MixedPrecisionTrainer(
            model=self.model,
            use_fp16=self.use_fp16,
            fp16_scale_growth=fp16_scale_growth,
            use_amp=use_amp,
        )

        self.opt = AdamW(self.mp_trainer.master_params,
                         lr=self.lr,
                         weight_decay=self.weight_decay)
        if self.resume_step:
            self._load_optimizer_state()
            # Model was resumed, either due to a restart or a checkpoint
            # being specified at the command line.
            self.ema_params = [
                self._load_ema_parameters(rate) for rate in self.ema_rate
            ]
        else:
            self.ema_params = [
                copy.deepcopy(self.mp_trainer.master_params)
                for _ in range(len(self.ema_rate))
            ]

        if th.cuda.is_available():
            self.use_ddp = True
            self.ddp_model = DDP(
                self.model,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            if dist.get_world_size() > 1:
                logger.warn("Distributed training requires CUDA. "
                            "Gradients will not be synchronized properly!")
            self.use_ddp = False
            self.ddp_model = self.model

    def _resolve_resume_checkpoint(self, model_name='ddpm'):
        """Return ``(path, step)`` of the checkpoint to resume from.

        ``find_resume_checkpoint`` maps the requested path onto the canonical
        ``model_<model_name>NNNNNNN.pt`` name.  If that file is missing but
        the path the user actually passed exists, the user's path is used.
        Returns ``('', 0)`` when no checkpoint was requested.
        """
        requested = self.resume_checkpoint or ''
        path, step = find_resume_checkpoint(requested, model_name)
        if requested and not bf.exists(path) and bf.exists(requested):
            path = requested
        return path, step

    def _load_and_sync_parameters(self):
        """Load model weights from the resume checkpoint, then sync ranks.

        Sets ``self.resume_step`` when a checkpoint is found.  A missing
        checkpoint is logged and training starts from the current weights.
        """
        resume_checkpoint, resume_step = self._resolve_resume_checkpoint()

        if resume_checkpoint:
            if not Path(resume_checkpoint).exists():
                logger.log(
                    f"failed to load model from checkpoint: {resume_checkpoint}, not exist"
                )
            else:
                self.resume_step = resume_step  # TODO, EMA part
                if dist.get_rank() == 0:
                    logger.log(
                        f"loading model from checkpoint: {resume_checkpoint}..."
                    )
                    self.model.load_state_dict(
                        dist_util.load_state_dict(
                            resume_checkpoint,
                            map_location=dist_util.dev(),
                        ))

        # Always run the collective so every rank takes the same path;
        # presumably this propagates rank 0's weights (see dist_util).
        dist_util.sync_params(self.model.parameters())

    def _load_ema_parameters(self,
                             rate,
                             model=None,
                             mp_trainer=None,
                             model_name='ddpm'):
        """Return the EMA parameters for ``rate``.

        Starts from a deep copy of the trainer's master params and, when an
        EMA checkpoint for ``rate`` / ``model_name`` exists next to the main
        checkpoint, overlays the entries whose key and size match ``model``.

        Args:
            rate: EMA rate identifying the checkpoint file.
            model: module whose state dict defines the accepted keys/sizes;
                defaults to ``self.model``.
            mp_trainer: trainer used for the master-param conversion;
                defaults to ``self.mp_trainer``.
            model_name: name component of the checkpoint file names.

        Returns:
            The EMA parameters; never None.
        """
        if mp_trainer is None:
            mp_trainer = self.mp_trainer
        if model is None:
            model = self.model

        ema_params = copy.deepcopy(mp_trainer.master_params)

        main_checkpoint, _ = find_resume_checkpoint(
            self.resume_checkpoint or '', model_name)
        ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step,
                                             rate, model_name)
        if ema_checkpoint:
            if dist_util.get_rank() == 0:
                if not Path(ema_checkpoint).exists():
                    # Keep the copied master params and still reach the sync
                    # below, so rank 0 neither returns None nor leaves the
                    # other ranks waiting in the collective.
                    logger.log(
                        f"failed to load EMA from checkpoint: {ema_checkpoint}, not exist"
                    )
                else:
                    logger.log(
                        f"loading EMA from checkpoint: {ema_checkpoint}...")

                    map_location = {
                        'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank()
                    }  # configure map_location properly

                    state_dict = dist_util.load_state_dict(
                        ema_checkpoint, map_location=map_location)

                    model_ema_state_dict = model.state_dict()

                    for k, v in state_dict.items():
                        if (k in model_ema_state_dict
                                and v.size() == model_ema_state_dict[k].size()):
                            model_ema_state_dict[k] = v
                        else:
                            logger.log('ignore key: ', k, ": ", v.size())

                    ema_params = mp_trainer.state_dict_to_master_params(
                        model_ema_state_dict)

                    del state_dict

            if dist_util.get_world_size() > 1:
                dist_util.sync_params(ema_params)

        return ema_params

    def _load_ema_parameters_freezeAE(self, rate, model, model_name='rec'):
        """Load EMA parameters for ``model`` from its EMA checkpoint.

        Unlike ``_load_ema_parameters`` there is no fallback copy of the
        master params: the method returns None when no EMA checkpoint is
        found.

        Args:
            rate: EMA rate identifying the checkpoint file.
            model: module whose state dict defines the accepted keys/sizes.
            model_name: name component of the checkpoint file names.
        """
        mp_trainer = self.mp_trainer
        ema_params = None

        main_checkpoint, _ = find_resume_checkpoint(
            self.resume_checkpoint or '', model_name)
        ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step,
                                             rate, model_name)
        if ema_checkpoint:
            if dist_util.get_rank() == 0:
                if not Path(ema_checkpoint).exists():
                    logger.log(
                        f"failed to load EMA from checkpoint: {ema_checkpoint}, not exist"
                    )
                    return None

                logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")

                map_location = {
                    'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank()
                }  # configure map_location properly

                state_dict = dist_util.load_state_dict(
                    ema_checkpoint, map_location=map_location)

                model_ema_state_dict = model.state_dict()

                for k, v in state_dict.items():
                    if (k in model_ema_state_dict
                            and v.size() == model_ema_state_dict[k].size()):
                        model_ema_state_dict[k] = v
                    else:
                        logger.log('ignore key: ', k, ": ", v.size())

                ema_params = mp_trainer.state_dict_to_master_params(
                    model_ema_state_dict)

                del state_dict

            # NOTE(review): on ranks other than 0 ``ema_params`` is still None
            # here, so multi-process use needs receive buffers on those ranks
            # -- confirm the intended usage before relying on this path.
            if dist_util.get_world_size() > 1:
                dist_util.sync_params(ema_params)

        return ema_params

    def _load_optimizer_state(self):
        """Restore the optimizer state saved alongside the resume checkpoint.

        ``save`` writes ``opt{step:07d}.pt``; the 6-digit name is also tried
        so that older checkpoints still load.  Nothing happens when neither
        file exists.
        """
        main_checkpoint, _ = self._resolve_resume_checkpoint()
        checkpoint_dir = bf.dirname(main_checkpoint)
        candidates = (
            f"opt{self.resume_step:07d}.pt",
            f"opt{self.resume_step:06d}.pt",
        )
        for name in candidates:
            opt_checkpoint = bf.join(checkpoint_dir, name)
            if bf.exists(opt_checkpoint):
                logger.log(
                    f"loading optimizer state from checkpoint: {opt_checkpoint}"
                )
                state_dict = dist_util.load_state_dict(
                    opt_checkpoint, map_location=dist_util.dev())
                self.opt.load_state_dict(state_dict)
                return

    def run_loop(self):
        """Train until ``lr_anneal_steps`` is reached (forever if it is 0).

        Logs every ``log_interval`` steps, checkpoints every
        ``save_interval`` steps, and saves once more on exit if the last
        step was not already saved.
        """
        while (not self.lr_anneal_steps
               or self.step + self.resume_step < self.lr_anneal_steps):
            batch, cond = next(self.data)
            self.run_step(batch, cond)
            if self.step % self.log_interval == 0:
                logger.dumpkvs()
            if self.step % self.save_interval == 0:
                self.save()
                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST",
                                  "") and self.step > 0:
                    return
            self.step += 1
        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.save_interval != 0:
            self.save()

    def run_step(self, batch, cond):
        """Run one optimization step and update EMA, LR and step logs."""
        self.forward_backward(batch, cond)
        took_step = self.mp_trainer.optimize(self.opt)
        if took_step:
            self._update_ema()
        self._anneal_lr()
        self.log_step()

    def forward_backward(self, batch, cond):
        """Accumulate gradients for ``batch`` one micro-batch at a time.

        Args:
            batch: tensor indexed along dim 0 by sample.
            cond: dict of tensors sliced along dim 0 like ``batch`` and
                passed to the model as keyword arguments.
        """
        self.mp_trainer.zero_grad()
        device = dist_util.dev()
        # torch.autocast wants the device *type* as a string ("cuda"/"cpu"),
        # not a device object or an indexed name such as "cuda:0".
        autocast_device_type = th.device(device).type

        for i in range(0, batch.shape[0], self.microbatch):
            with th.autocast(device_type=autocast_device_type,
                             dtype=th.float16,
                             enabled=self.mp_trainer.use_amp):
                micro = batch[i:i + self.microbatch].to(device)
                micro_cond = {
                    k: v[i:i + self.microbatch].to(device)
                    for k, v in cond.items()
                }
                last_batch = (i + self.microbatch) >= batch.shape[0]
                t, weights = self.schedule_sampler.sample(
                    micro.shape[0], device)

                compute_losses = functools.partial(
                    self.diffusion.training_losses,
                    self.ddp_model,
                    micro,
                    t,
                    model_kwargs=micro_cond,
                )

                # Only the last micro-batch triggers DDP gradient sync.
                if last_batch or not self.use_ddp:
                    losses = compute_losses()
                else:
                    with self.ddp_model.no_sync():
                        losses = compute_losses()

                if isinstance(self.schedule_sampler, LossAwareSampler):
                    self.schedule_sampler.update_with_local_losses(
                        t, losses["loss"].detach())

                loss = (losses["loss"] * weights).mean()
                log_loss_dict(self.diffusion, t,
                              {k: v * weights
                               for k, v in losses.items()})

            self.mp_trainer.backward(loss)

    def _update_ema(self):
        """Update every EMA copy from the current master params."""
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.mp_trainer.master_params, rate=rate)

    def _anneal_lr(self):
        """Linearly decay the LR towards 0 over ``lr_anneal_steps``."""
        if not self.lr_anneal_steps:
            return
        frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def log_step(self):
        """Log the global step and the number of samples seen so far."""
        logger.logkv("step", self.step + self.resume_step)
        logger.logkv("samples",
                     (self.step + self.resume_step + 1) * self.global_batch)

    def save(self):
        """Write model, EMA and optimizer checkpoints for the current step.

        Files are written by rank 0 only, into ``get_blob_logdir()``, as
        ``model{step:07d}.pt``, ``ema_{rate}_{step:07d}.pt`` and
        ``opt{step:07d}.pt``.  All ranks wait on a barrier before returning.
        """

        def save_checkpoint(rate, params):
            # A falsy rate (0) marks the raw model rather than an EMA copy.
            state_dict = self.mp_trainer.master_params_to_state_dict(params)
            if dist.get_rank() == 0:
                logger.log(f"saving model {rate}...")
                if not rate:
                    filename = f"model{(self.step+self.resume_step):07d}.pt"
                else:
                    filename = f"ema_{rate}_{(self.step+self.resume_step):07d}.pt"
                with bf.BlobFile(bf.join(get_blob_logdir(), filename),
                                 "wb") as f:
                    th.save(state_dict, f)

        save_checkpoint(0, self.mp_trainer.master_params)
        for rate, params in zip(self.ema_rate, self.ema_params):
            save_checkpoint(rate, params)

        if dist.get_rank() == 0:
            with bf.BlobFile(
                    bf.join(get_blob_logdir(),
                            f"opt{(self.step+self.resume_step):07d}.pt"),
                    "wb",
            ) as f:
                th.save(self.opt.state_dict(), f)

        dist.barrier()


def parse_resume_step_from_filename(filename):
    """
    Parse filenames of the form path/to/modelNNNNNNN.pt, where NNNNNNN is the
    checkpoint's number of steps (any number of trailing digits is accepted).

    Returns 0, after printing a message, when the file stem does not end in
    digits.
    """
    stem = Path(filename).stem
    match = _TRAILING_STEP_RE.search(stem)
    if match is None:
        print('fail to load model step', stem[-6:])
        return 0
    return int(match.group(1))


def get_blob_logdir():
    """Return the directory checkpoints are written to."""
    # You can change this to be a separate path to save checkpoints to
    # a blobstore or some external drive.
    return logger.get_dir()


def find_resume_checkpoint(resume_checkpoint='', model_name='ddpm'):
    """Map a checkpoint path onto its canonical name for ``model_name``.

    Args:
        resume_checkpoint: path of a checkpoint whose file stem ends in the
            step number; empty or None means "nothing to resume".
        model_name: name component of the canonical file name.

    Returns:
        ``(path, step)`` where ``path`` is
        ``<dir of resume_checkpoint>/model_<model_name><step:07d>.pt``,
        or ``('', 0)`` when no checkpoint was given.
    """
    # On your infrastructure, you may want to override this to automatically
    # discover the latest checkpoint on your blob storage, etc.
    if not resume_checkpoint:
        return '', 0

    step = parse_resume_step_from_filename(resume_checkpoint)
    # Use the real parent directory: splitting the whole path on "model"
    # breaks as soon as a directory name contains that substring.
    checkpoint_dir = Path(resume_checkpoint).parent
    resume_ckpt_path = str(checkpoint_dir / f'model_{model_name}{step:07d}.pt')
    return resume_ckpt_path, step


def find_ema_checkpoint(main_checkpoint, step, rate, model_name=''):
    """Return the EMA checkpoint path next to ``main_checkpoint``, or None.

    The file name is ``ema_{rate}_{step:07d}.pt``, or
    ``ema_{model_name}_{rate}_{step:07d}.pt`` when ``model_name`` is set.
    The outcome of the lookup is printed.
    """
    if main_checkpoint is None:
        return None
    if model_name == '':
        filename = f"ema_{rate}_{(step):07d}.pt"
    else:
        filename = f"ema_{model_name}_{rate}_{(step):07d}.pt"
    path = bf.join(bf.dirname(main_checkpoint), filename)
    if bf.exists(path):
        print('load ema model', path)
        return path
    print('fail to load ema model', path)
    return None


def log_loss_dict(diffusion, ts, losses):
    """Log the mean of each loss term, overall and per timestep quartile."""
    for key, values in losses.items():
        logger.logkv_mean(key, values.mean().item())
        # Log the quantiles (four quartiles, in particular).
        for sub_t, sub_loss in zip(ts.cpu().numpy(),
                                   values.detach().cpu().numpy()):
            quartile = int(4 * sub_t / diffusion.num_timesteps)
            logger.logkv_mean(f"{key}_q{quartile}", sub_loss)


def log_rec3d_loss_dict(loss_dict):
    """Log the mean of every entry in ``loss_dict``."""
    for key, values in loss_dict.items():
        logger.logkv_mean(key, values.mean().item())


def calc_average_loss(all_loss_dicts):
    """Average each loss over a sequence of per-batch loss dicts.

    Prints ``mean +- std`` per key and returns ``{f'{key}_val': mean}``.
    For the keys 'loss_lpis' and 'loss_ssim' the reported mean is
    ``1 - mean``.
    """
    all_scores = {}
    for loss_dict in all_loss_dicts:
        for k, v in loss_dict.items():
            all_scores.setdefault(k, []).append(v.item())

    mean_all_scores = {}
    for k, v in all_scores.items():
        mean = np.mean(v)
        std = np.std(v)
        # NOTE(review): 'loss_lpis' may be a typo for 'loss_lpips' -- confirm
        # against the keys the callers actually produce.
        if k in ['loss_lpis', 'loss_ssim']:
            mean = 1 - mean
        result_str = '{} average loss is {:.4f} +- {:.4f}'.format(k, mean, std)
        mean_all_scores[k] = mean
        print(result_str)

    val_scores_for_logging = {
        f'{k}_val': v
        for k, v in mean_all_scores.items()
    }
    return val_scores_for_logging