Spaces:
Running
Running
""" | |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
This script is a simplified version of the training script in detectron2/tools. | |
""" | |
import os | |
import itertools | |
import weakref | |
from typing import Any, Dict, List, Set | |
import logging | |
from collections import OrderedDict | |
import torch | |
from fvcore.nn.precise_bn import get_bn_modules | |
import detectron2.utils.comm as comm | |
from detectron2.utils.logger import setup_logger | |
from detectron2.checkpoint import DetectionCheckpointer | |
from detectron2.config import get_cfg | |
from detectron2.data import build_detection_train_loader | |
from regionspot import build_custom_train_loader | |
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch, create_ddp_model, \ | |
AMPTrainer, SimpleTrainer, hooks | |
from detectron2.evaluation import COCOEvaluator, LVISEvaluator, verify_results | |
from detectron2.solver.build import maybe_add_gradient_clipping | |
from detectron2.modeling import build_model | |
from regionspot.data import objects365 | |
from regionspot.data import openimages | |
from regionspot.data import v3det | |
from regionspot import RegionSpotDatasetMapper, add_regionspot_config, RegionSpotWithTTA | |
from regionspot.util.model_ema import add_model_ema_configs, may_build_model_ema, may_get_ema_checkpointer, EMAHook, \ | |
apply_model_ema_and_restore, EMADetectionCheckpointer | |
class Trainer(DefaultTrainer): | |
""" Extension of the Trainer class adapted to RegionSpot. """ | |
def __init__(self, cfg): | |
""" | |
Args: | |
cfg (CfgNode): | |
""" | |
super(DefaultTrainer, self).__init__() # call grandfather's `__init__` while avoid father's `__init()` | |
logger = logging.getLogger("detectron2") | |
if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for d2 | |
setup_logger() | |
cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size()) | |
# Assume these objects must be constructed in this order. | |
model = self.build_model(cfg) | |
optimizer = self.build_optimizer(cfg, model) | |
data_loader = self.build_train_loader(cfg) | |
model = create_ddp_model(model, broadcast_buffers=False) | |
self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)( | |
model, data_loader, optimizer | |
) | |
self.scheduler = self.build_lr_scheduler(cfg, optimizer) | |
########## EMA ############ | |
kwargs = { | |
'trainer': weakref.proxy(self), | |
} | |
kwargs.update(may_get_ema_checkpointer(cfg, model)) | |
self.checkpointer = DetectionCheckpointer( | |
# Assume you want to save checkpoints together with logs/statistics | |
model, | |
cfg.OUTPUT_DIR, | |
**kwargs, | |
# trainer=weakref.proxy(self), | |
) | |
self.start_iter = 0 | |
self.max_iter = cfg.SOLVER.MAX_ITER | |
self.cfg = cfg | |
self.register_hooks(self.build_hooks()) | |
def build_model(cls, cfg): | |
""" | |
Returns: | |
torch.nn.Module: | |
It now calls :func:`detectron2.modeling.build_model`. | |
Overwrite it if you'd like a different model. | |
""" | |
model = build_model(cfg) | |
logger = logging.getLogger(__name__) | |
logger.info("Model:\n{}".format(model)) | |
# setup EMA | |
may_build_model_ema(cfg, model) | |
return model | |
def build_evaluator(cls, cfg, dataset_name, output_folder=None): | |
""" | |
Create evaluator(s) for a given dataset. | |
This uses the special metadata "evaluator_type" associated with each builtin dataset. | |
For your own dataset, you can simply create an evaluator manually in your | |
script and do not have to worry about the hacky if-else logic here. | |
""" | |
if output_folder is None: | |
output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") | |
if 'lvis' in dataset_name: | |
return LVISEvaluator(dataset_name, cfg, True, output_folder) | |
else: | |
return COCOEvaluator(dataset_name, cfg, True, output_folder) | |
def build_train_loader(cls, cfg): | |
mapper = RegionSpotDatasetMapper(cfg, is_train=True) | |
if cfg.DATALOADER.SAMPLER_TRAIN in ['TrainingSampler', 'RepeatFactorTrainingSampler']: | |
data_loader = build_detection_train_loader(cfg, mapper=mapper) | |
else: | |
data_loader = build_custom_train_loader(cfg, mapper=mapper) | |
return data_loader | |
def build_optimizer(cls, cfg, model): | |
params: List[Dict[str, Any]] = [] | |
memo: Set[torch.nn.parameter.Parameter] = set() | |
for key, value in model.named_parameters(recurse=True): | |
if not value.requires_grad: | |
continue | |
# Avoid duplicating parameters | |
if value in memo: | |
continue | |
memo.add(value) | |
lr = cfg.SOLVER.BASE_LR | |
weight_decay = cfg.SOLVER.WEIGHT_DECAY | |
if "backbone" in key: | |
lr = lr * cfg.SOLVER.BACKBONE_MULTIPLIER | |
params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] | |
def maybe_add_full_model_gradient_clipping(optim): # optim: the optimizer class | |
# detectron2 doesn't have full model gradient clipping now | |
clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE | |
enable = ( | |
cfg.SOLVER.CLIP_GRADIENTS.ENABLED | |
and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model" | |
and clip_norm_val > 0.0 | |
) | |
class FullModelGradientClippingOptimizer(optim): | |
def step(self, closure=None): | |
all_params = itertools.chain(*[x["params"] for x in self.param_groups]) | |
torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val) | |
super().step(closure=closure) | |
return FullModelGradientClippingOptimizer if enable else optim | |
optimizer_type = cfg.SOLVER.OPTIMIZER | |
if optimizer_type == "SGD": | |
optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)( | |
params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM | |
) | |
elif optimizer_type == "ADAMW": | |
optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)( | |
params, cfg.SOLVER.BASE_LR | |
) | |
else: | |
raise NotImplementedError(f"no optimizer type {optimizer_type}") | |
if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model": | |
optimizer = maybe_add_gradient_clipping(cfg, optimizer) | |
return optimizer | |
def ema_test(cls, cfg, model, evaluators=None): | |
# model with ema weights | |
logger = logging.getLogger("detectron2.trainer") | |
if cfg.MODEL_EMA.ENABLED: | |
logger.info("Run evaluation with EMA.") | |
with apply_model_ema_and_restore(model): | |
results = cls.test(cfg, model, evaluators=evaluators) | |
else: | |
results = cls.test(cfg, model, evaluators=evaluators) | |
return results | |
def test_with_TTA(cls, cfg, model): | |
logger = logging.getLogger("detectron2.trainer") | |
logger.info("Running inference with test-time augmentation ...") | |
model = RegionSpotWithTTA(cfg, model) | |
evaluators = [ | |
cls.build_evaluator( | |
cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA") | |
) | |
for name in cfg.DATASETS.TEST | |
] | |
if cfg.MODEL_EMA.ENABLED: | |
cls.ema_test(cfg, model, evaluators) | |
else: | |
res = cls.test(cfg, model, evaluators) | |
res = OrderedDict({k + "_TTA": v for k, v in res.items()}) | |
return res | |
def build_hooks(self): | |
""" | |
Build a list of default hooks, including timing, evaluation, | |
checkpointing, lr scheduling, precise BN, writing events. | |
Returns: | |
list[HookBase]: | |
""" | |
cfg = self.cfg.clone() | |
cfg.defrost() | |
cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN | |
ret = [ | |
hooks.IterationTimer(), | |
EMAHook(self.cfg, self.model) if cfg.MODEL_EMA.ENABLED else None, # EMA hook | |
hooks.LRScheduler(), | |
hooks.PreciseBN( | |
# Run at the same freq as (but before) evaluation. | |
cfg.TEST.EVAL_PERIOD, | |
self.model, | |
# Build a new data loader to not affect training | |
self.build_train_loader(cfg), | |
cfg.TEST.PRECISE_BN.NUM_ITER, | |
) | |
if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model) | |
else None, | |
] | |
# Do PreciseBN before checkpointer, because it updates the model and need to | |
# be saved by checkpointer. | |
# This is not always the best: if checkpointing has a different frequency, | |
# some checkpoints may have more precise statistics than others. | |
if comm.is_main_process(): | |
ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD)) | |
def test_and_save_results(): | |
self._last_eval_results = self.test(self.cfg, self.model) | |
return self._last_eval_results | |
# Do evaluation after checkpointer, because then if it fails, | |
# we can use the saved checkpoint to debug. | |
ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results)) | |
if comm.is_main_process(): | |
# Here the default print/log frequency of each writer is used. | |
# run writers in the end, so that evaluation metrics are written | |
ret.append(hooks.PeriodicWriter(self.build_writers(), period=20)) | |
return ret | |
def setup(args): | |
""" | |
Create configs and perform basic setups. | |
""" | |
cfg = get_cfg() | |
add_regionspot_config(cfg) | |
add_model_ema_configs(cfg) | |
cfg.merge_from_file(args.config_file) | |
cfg.merge_from_list(args.opts) | |
cfg.freeze() | |
default_setup(cfg, args) | |
return cfg | |
def main(args): | |
cfg = setup(args) | |
if args.eval_only: | |
model = Trainer.build_model(cfg) | |
kwargs = may_get_ema_checkpointer(cfg, model) | |
if cfg.MODEL_EMA.ENABLED: | |
EMADetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR, **kwargs).resume_or_load(cfg.MODEL.WEIGHTS, | |
resume=args.resume) | |
else: | |
DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR, **kwargs).resume_or_load(cfg.MODEL.WEIGHTS, | |
resume=args.resume) | |
res = Trainer.ema_test(cfg, model) | |
if cfg.TEST.AUG.ENABLED: | |
res.update(Trainer.test_with_TTA(cfg, model)) | |
if comm.is_main_process(): | |
verify_results(cfg, res) | |
return res | |
trainer = Trainer(cfg) | |
trainer.resume_or_load(resume=args.resume) | |
return trainer.train() | |
if __name__ == "__main__": | |
args = default_argument_parser().parse_args() | |
print("Command Line Args:", args) | |
launch( | |
main, | |
args.num_gpus, | |
num_machines=args.num_machines, | |
machine_rank=args.machine_rank, | |
dist_url=args.dist_url, | |
args=(args,), | |
) | |