# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import contextlib
import datetime
import logging
import sys
import os
import math
import time

import torch
import torch.distributed as dist

from maskrcnn_benchmark.utils.comm import get_world_size, all_gather, is_main_process, broadcast_data, get_rank
from maskrcnn_benchmark.utils.metric_logger import MetricLogger
from maskrcnn_benchmark.utils.ema import ModelEma
from maskrcnn_benchmark.utils.amp import autocast, GradScaler
from maskrcnn_benchmark.data.datasets.evaluation import evaluate
from .inference import inference

import pdb


def reduce_loss_dict(loss_dict):
    """
    Reduce the loss dictionary from all processes so that process with rank
    0 has the averaged results. Returns a dict with the same fields as
    loss_dict, after reduction.

    With a single process the input dict is returned unchanged. Otherwise the
    values are stacked (in sorted key order), summed onto rank 0 with
    ``dist.reduce`` and divided by the world size on rank 0 only; other ranks
    receive un-averaged tensors.
    """
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        loss_names = sorted(loss_dict.keys())
        all_losses = torch.stack([loss_dict[k] for k in loss_names], dim=0)
        dist.reduce(all_losses, dst=0)
        if dist.get_rank() == 0:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            all_losses /= world_size
        reduced_losses = dict(zip(loss_names, all_losses))
    return reduced_losses


def _evaluate_on_val(cfg, eval_model, val_data_loader, device):
    """Predict on ``val_data_loader`` with ``eval_model`` and score with ``evaluate``.

    Predictions from every rank are merged with ``all_gather``; only the main
    process runs ``evaluate``. Returns box-proposal AR@100 when
    ``cfg.DATASETS.CLASS_AGNOSTIC`` is set, otherwise bbox AP, on the main
    process, and 0.0 on all other ranks. The caller is responsible for putting
    ``eval_model`` into eval mode.
    """
    results_dict = {}
    cpu_device = torch.device("cpu")
    for batch in val_data_loader:
        # The loader may yield extra trailing items; only the first four are used.
        images, targets, image_ids, positive_map, *_ = batch
        with torch.no_grad():
            images = images.to(device)
            if positive_map is None:
                output = eval_model(images)
            else:
                captions = [t.get_field("caption") for t in targets if "caption" in t.fields()]
                # Same positional order as the training call
                # (images, targets, captions, positive_map); no targets at eval time.
                output = eval_model(images, None, captions, positive_map)
            output = [o.to(cpu_device) for o in output]
        results_dict.update(
            {img_id: result for img_id, result in zip(image_ids, output)}
        )
    # Collective call: every rank must reach this point.
    all_predictions = all_gather(results_dict)
    if not is_main_process():
        return 0.0

    merged = {}
    for p in all_predictions:
        merged.update(p)
    predictions = [merged[i] for i in sorted(merged.keys())]
    eval_result, _ = evaluate(val_data_loader.dataset, predictions, output_folder=None,
                              box_only=cfg.DATASETS.CLASS_AGNOSTIC)
    if cfg.DATASETS.CLASS_AGNOSTIC:
        return eval_result.results['box_proposal']['AR@100']
    return eval_result.results['bbox']['AP']


def do_train(
        cfg,
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        val_data_loader=None,
        meters=None,
        zero_shot=False
):
    """Run the main training loop.

    Each batch from ``data_loader`` is unpacked as
    ``(images, targets, idxs, positive_map, positive_map_eval, greenlight_map)``.
    Every iteration runs forward/backward (under ``autocast`` + ``GradScaler``
    when ``cfg.SOLVER.USE_AMP``), then steps the optimizer and the scheduler.
    Optionally it decays weight decay at scheduler milestones, maintains an EMA
    copy of the model, evaluates on ``val_data_loader`` every
    ``checkpoint_period`` iterations (and at the last one), stops early after
    ``cfg.SOLVER.AUTO_TERMINATE_PATIENCE`` evaluations without improvement,
    and saves checkpoints through ``checkpointer``.

    Args:
        cfg: config node; reads SOLVER.*, MODEL.LANGUAGE_BACKBONE.FREEZE,
            DATASETS.CLASS_AGNOSTIC and TEST.* fields.
        model: the network (possibly DDP-wrapped, i.e. exposing ``.module``).
        data_loader: training loader; ``len(data_loader)`` is the total
            number of iterations.
        optimizer, scheduler: stepped once per trained iteration.
        checkpointer: object whose ``save(name, **arguments)`` is called.
        device: device images and targets are moved to.
        checkpoint_period: iterations between evaluations/checkpoints;
            overridden when ``SOLVER.CHECKPOINT_PER_EPOCH != -1`` and
            ``SOLVER.MAX_EPOCH >= 1``.
        arguments: mutable dict; reads ``"iteration"`` as the start iteration
            and writes ``"iteration"``, ``"model_ema"`` and ``"eval_result"``.
        val_data_loader: optional validation loader; evaluation is skipped
            when falsy.
        meters: ``MetricLogger``-like object; a fresh one is created if None.
        zero_shot: unused; kept for interface compatibility.
    """
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    if meters is None:
        # The default used to crash on the first ``meters.update``.
        meters = MetricLogger(delimiter="  ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()
    model_ema = None
    if cfg.SOLVER.MODEL_EMA > 0:
        model_ema = ModelEma(model, decay=cfg.SOLVER.MODEL_EMA)
    start_training_time = time.time()
    end = time.time()

    if cfg.SOLVER.USE_AMP:
        scaler = GradScaler()

    global_rank = get_rank()

    if cfg.SOLVER.CHECKPOINT_PER_EPOCH != -1 and cfg.SOLVER.MAX_EPOCH >= 1:
        # Clamp to >= 1: a zero period would raise ZeroDivisionError in ``iteration % checkpoint_period``.
        checkpoint_period = max(1, len(data_loader) * cfg.SOLVER.CHECKPOINT_PER_EPOCH // cfg.SOLVER.MAX_EPOCH)

    if global_rank <= 0 and cfg.SOLVER.MAX_EPOCH >= 1:
        print("Iter per epoch ", len(data_loader) // cfg.SOLVER.MAX_EPOCH)

    if cfg.SOLVER.AUTO_TERMINATE_PATIENCE != -1:
        patience_counter = 0
        previous_best = 0.0

    # Adapt the weight decay: when resuming, skip milestones already passed.
    if cfg.SOLVER.WEIGHT_DECAY_SCHEDULE and hasattr(scheduler, 'milestones'):
        milestone_target = 0
        for i, milestone in enumerate(list(scheduler.milestones)):
            if scheduler.last_epoch >= milestone * cfg.SOLVER.WEIGHT_DECAY_SCHEDULE_RATIO:
                milestone_target = i + 1

    for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(data_loader, start_iter):
        nnegative = sum(len(target) < 1 for target in targets)
        nsample = len(targets)
        # NOTE(review): the skip decision is per rank; under DDP a rank that
        # skips while others train may desynchronise collectives — confirm.
        if nsample == nnegative or nnegative > nsample * cfg.SOLVER.MAX_NEG_PER_BATCH:
            logger.info('[WARNING] Sampled {} negative in {} in a batch, greater the allowed ratio {}, skip'.
                        format(nnegative, nsample, cfg.SOLVER.MAX_NEG_PER_BATCH))
            continue

        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        images = images.to(device)
        # Previously wrapped in a bare ``except: pass`` that always led to a
        # misleading ``len(None)`` TypeError; let the real error surface.
        targets = [target.to(device) for target in targets]
        captions = [t.get_field("caption") for t in targets if "caption" in t.fields()]

        # Freeze language backbone
        if cfg.MODEL.LANGUAGE_BACKBONE.FREEZE:
            if hasattr(model, "module"):
                model.module.language_backbone.eval()
            else:
                model.language_backbone.eval()

        # One forward path for both precisions so AMP and FP32 runs see the
        # same inputs (the FP32 branch used to drop ``greenlight_map``).
        amp_context = autocast() if cfg.SOLVER.USE_AMP else contextlib.nullcontext()
        with amp_context:
            if len(captions) > 0:
                loss_dict = model(images, targets, captions, positive_map, greenlight_map=greenlight_map)
            else:
                loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        if not torch.isfinite(losses).all():
            logger.error("NaN/Inf loss encountered, ignoring")
            # Mask both NaN and Inf (the old ``losses != losses`` mask missed Inf).
            losses[~torch.isfinite(losses)] = 0

        optimizer.zero_grad()
        if cfg.SOLVER.USE_AMP:
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            losses.backward()
            optimizer.step()
        scheduler.step()

        # Adapt the weight decay: only support multiStepLR
        if cfg.SOLVER.WEIGHT_DECAY_SCHEDULE and hasattr(scheduler, 'milestones'):
            if milestone_target < len(scheduler.milestones):
                next_milestone = list(scheduler.milestones)[milestone_target]
            else:
                next_milestone = float('inf')
            if scheduler.last_epoch >= next_milestone * cfg.SOLVER.WEIGHT_DECAY_SCHEDULE_RATIO:
                gamma = scheduler.gamma
                logger.info("Drop the weight decay by {}!".format(gamma))
                for param in optimizer.param_groups:
                    if 'weight_decay' in param:
                        param['weight_decay'] *= gamma
                # move the target forward
                milestone_target += 1

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)
        if model_ema is not None:
            model_ema.update(model)
            arguments["model_ema"] = model_ema.state_dict()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)
        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iter:
            if global_rank <= 0:
                print(
                    meters.delimiter.join(
                        [
                            "eta: {eta}",
                            "iter: {iter}",
                            "{meters}",
                            "lr: {lr:.6f}",
                            "wd: {wd:.6f}",
                            "max mem: {memory:.0f}",
                        ]
                    ).format(
                        eta=eta_string,
                        iter=iteration,
                        meters=str(meters),
                        lr=optimizer.param_groups[0]["lr"],
                        wd=optimizer.param_groups[0]["weight_decay"],
                        memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                    )
                )

        if val_data_loader and (iteration % checkpoint_period == 0 or iteration == max_iter):
            if is_main_process():
                print("Evaluating")
            eval_result = 0.0
            model.eval()
            if cfg.SOLVER.TEST_WITH_INFERENCE:
                with torch.no_grad():
                    _model = getattr(model, "module", model)
                    _result = inference(
                        model=_model,
                        data_loader=val_data_loader,
                        dataset_name="val",
                        device=device,
                        expected_results=cfg.TEST.EXPECTED_RESULTS,
                        expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
                        output_folder=None,
                        cfg=cfg,
                        verbose=False
                    )
                if is_main_process():
                    eval_result = _result[0].results['bbox']['AP']
            else:
                eval_result = _evaluate_on_val(cfg, model, val_data_loader, device)
            model.train()

            if model_ema is not None and cfg.SOLVER.USE_EMA_FOR_MONITOR:
                model_ema.ema.eval()
                # Non-main ranks get 0.0 here, matching the value they already held.
                eval_result = _evaluate_on_val(cfg, model_ema.ema, val_data_loader, device)
            arguments.update(eval_result=eval_result)

            # Only rank 0 computed a real metric; share it so that rank-dependent
            # decisions (LR autostep, early termination) agree on every process.
            # Without this, only rank 0 could hit the patience ``break`` -> deadlock.
            if cfg.SOLVER.USE_AUTOSTEP or cfg.SOLVER.AUTO_TERMINATE_PATIENCE != -1:
                eval_result = all_gather(eval_result)[0]

            if cfg.SOLVER.USE_AUTOSTEP:
                scheduler.step(eval_result)

            if cfg.SOLVER.AUTO_TERMINATE_PATIENCE != -1:
                if eval_result < previous_best:
                    patience_counter += 1
                else:
                    patience_counter = 0
                    previous_best = eval_result
                    checkpointer.save("model_best", **arguments)
                print("Previous Best", previous_best, "Patience Counter", patience_counter, "Eval Result", eval_result)
                if patience_counter >= cfg.SOLVER.AUTO_TERMINATE_PATIENCE:
                    if is_main_process():
                        print("\n\n\n\nAuto Termination at {}, current best {}\n\n\n".format(iteration, previous_best))
                    break

        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)
            break

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / max(max_iter, 1)
        )
    )