Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
# Lightning Trainer should be considered beta at this point | |
# We have confirmed that training and validation run correctly and produce correct results | |
# Depending on how you launch the trainer, there are issues with processes terminating correctly | |
# This module is still dependent on D2 logging, but could be transferred to use Lightning logging | |
import logging | |
import os | |
import time | |
import weakref | |
from collections import OrderedDict | |
from typing import Any, Dict, List | |
import detectron2.utils.comm as comm | |
from detectron2.checkpoint import DetectionCheckpointer | |
from detectron2.config import get_cfg | |
from detectron2.data import build_detection_test_loader, build_detection_train_loader | |
from detectron2.engine import ( | |
DefaultTrainer, | |
SimpleTrainer, | |
default_argument_parser, | |
default_setup, | |
default_writers, | |
hooks, | |
) | |
from detectron2.evaluation import print_csv_format | |
from detectron2.evaluation.testing import flatten_results_dict | |
from detectron2.modeling import build_model | |
from detectron2.solver import build_lr_scheduler, build_optimizer | |
from detectron2.utils.events import EventStorage | |
from detectron2.utils.logger import setup_logger | |
import pytorch_lightning as pl # type: ignore | |
from pytorch_lightning import LightningDataModule, LightningModule | |
from train_net import build_evaluator | |
# Give the root logger a stream handler at INFO level (no-op if it already has
# handlers), so INFO records are visible before detectron2 configures logging.
logging.basicConfig(level="INFO")
# Logger shared by everything in this script.
logger = logging.getLogger(name="detectron2")
class TrainingModule(LightningModule):
    """LightningModule that trains and evaluates a detectron2 model.

    detectron2's own bookkeeping (``EventStorage``, ``hooks.IterationTimer`` and the
    ``default_writers`` metric writers) is driven by hand from the Lightning hooks
    below, so d2-style logging and timing keep working under ``pl.Trainer``.
    """

    def __init__(self, cfg):
        super().__init__()
        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
            setup_logger()
        self.cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())
        # Created lazily in training_step (see the comment there); None until then.
        self.storage: EventStorage = None
        self.model = build_model(self.cfg)
        # Iteration to start counting from; overwritten by on_load_checkpoint.
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Store the current d2 iteration under ``checkpoint["iteration"]``."""
        # Before the first training_step there is no storage yet; the iteration
        # is then still start_iter.
        if self.storage is None:
            checkpoint["iteration"] = self.start_iter
        else:
            checkpoint["iteration"] = self.storage.iter

    def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]) -> None:
        """Restore the d2 iteration saved by ``on_save_checkpoint``."""
        self.start_iter = checkpointed_state["iteration"]
        # A checkpoint may be loaded (e.g. when resuming) before the first
        # training_step has created self.storage; training_step then opens the
        # storage at self.start_iter instead.
        if self.storage is not None:
            self.storage.iter = self.start_iter

    def setup(self, stage: str):
        """Create the checkpointer, load ``cfg.MODEL.WEIGHTS`` if set, and start timing."""
        # Always create the checkpointer: training_epoch_end uses it to save
        # "model_final" even when no initial weights are configured.
        self.checkpointer = DetectionCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            self.model,
            self.cfg.OUTPUT_DIR,
        )
        if self.cfg.MODEL.WEIGHTS:
            logger.info(f"Load model weights from checkpoint: {self.cfg.MODEL.WEIGHTS}.")
            # Only load weights, use lightning checkpointing if you want to resume
            self.checkpointer.load(self.cfg.MODEL.WEIGHTS)
        self.iteration_timer = hooks.IterationTimer()
        self.iteration_timer.before_train()
        # Timestamp used by training_step to measure data-loading time.
        self.data_start = time.perf_counter()
        self.writers = None

    def training_step(self, batch, batch_idx):
        """Run one d2 training iteration and return the summed loss for Lightning."""
        data_time = time.perf_counter() - self.data_start
        # Need to manually enter/exit since trainer may launch processes
        # This ideally belongs in setup, but setup seems to run before processes are spawned
        if self.storage is None:
            # Start at start_iter (restored by on_load_checkpoint when resuming),
            # matching how detectron2's SimpleTrainer opens EventStorage(start_iter).
            self.storage = EventStorage(self.start_iter)
            self.storage.__enter__()
            # Let the timer hook treat this module as its trainer.
            self.iteration_timer.trainer = weakref.proxy(self)
            self.iteration_timer.before_step()
            self.writers = (
                default_writers(self.cfg.OUTPUT_DIR, self.max_iter)
                if comm.is_main_process()
                else {}
            )

        loss_dict = self.model(batch)
        SimpleTrainer.write_metrics(loss_dict, data_time)

        opt = self.optimizers()
        self.storage.put_scalar(
            "lr", opt.param_groups[self._best_param_group_id]["lr"], smoothing_hint=False
        )
        self.iteration_timer.after_step()
        self.storage.step()
        # A little odd to put before step here, but it's the best way to get a proper timing
        self.iteration_timer.before_step()

        if self.storage.iter % 20 == 0:
            for writer in self.writers:
                writer.write()
        return sum(loss_dict.values())

    def training_step_end(self, training_step_outpus):
        # Restart the data-loading timer for the next batch.
        self.data_start = time.perf_counter()
        return training_step_outpus

    def training_epoch_end(self, training_step_outputs):
        """Finish timing, save "model_final" on the main process, and close writers."""
        self.iteration_timer.after_train()
        if comm.is_main_process():
            self.checkpointer.save("model_final")
        for writer in self.writers:
            writer.write()
            writer.close()
        self.storage.__exit__(None, None, None)

    def _process_dataset_evaluation_results(self) -> OrderedDict:
        """Run every evaluator and collect its results keyed by dataset name.

        When only one dataset is evaluated, that dataset's result is returned
        directly instead of the OrderedDict.
        """
        results = OrderedDict()
        for idx, dataset_name in enumerate(self.cfg.DATASETS.TEST):
            results[dataset_name] = self._evaluators[idx].evaluate()
            if comm.is_main_process():
                print_csv_format(results[dataset_name])

        if len(results) == 1:
            results = list(results.values())[0]
        return results

    def _reset_dataset_evaluators(self):
        """Build and reset one evaluator per dataset in ``cfg.DATASETS.TEST``."""
        self._evaluators = []
        for dataset_name in self.cfg.DATASETS.TEST:
            evaluator = build_evaluator(self.cfg, dataset_name)
            evaluator.reset()
            self._evaluators.append(evaluator)

    def on_validation_epoch_start(self, _outputs=None):
        """Prepare fresh evaluators before each validation run."""
        # Lightning invokes this hook without arguments; the previously required
        # ``_outputs`` parameter made that call fail, so it is now optional.
        self._reset_dataset_evaluators()

    def validation_epoch_end(self, _outputs):
        """Compute evaluation results and log them as scalars to the event storage."""
        # The evaluators accumulate predictions themselves in validation_step, so
        # the step outputs are not needed; the helper takes no arguments.
        results = self._process_dataset_evaluation_results()

        flattened_results = flatten_results_dict(results)
        for k, v in flattened_results.items():
            try:
                v = float(v)
            except Exception as e:
                raise ValueError(
                    "[EvalHook] eval_function should return a nested dict of float. "
                    "Got '{}: {}' instead.".format(k, v)
                ) from e
        # self.storage is only created by training_step, so it is still None when
        # only trainer.validate() runs (eval-only). In that case the results were
        # already printed by print_csv_format in the helper on the main process.
        if self.storage is not None:
            self.storage.put_scalars(**flattened_results, smoothing_hint=False)

    def validation_step(self, batch, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Run inference on ``batch`` and feed it to the evaluator for ``dataloader_idx``."""
        if not isinstance(batch, list):
            batch = [batch]
        outputs = self.model(batch)
        self._evaluators[dataloader_idx].process(batch, outputs)

    def configure_optimizers(self):
        """Build the d2 optimizer and an LR scheduler stepped every iteration."""
        optimizer = build_optimizer(self.cfg, self.model)
        # Index of the param group whose "lr" training_step logs.
        self._best_param_group_id = hooks.LRScheduler.get_best_param_group_id(optimizer)
        scheduler = build_lr_scheduler(self.cfg, optimizer)
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
class DataModule(LightningDataModule):
    """LightningDataModule serving detectron2's train and test data loaders."""

    def __init__(self, cfg):
        super().__init__()
        world_size = comm.get_world_size()
        # Rescale batch size / LR / iterations for the current world size.
        self.cfg = DefaultTrainer.auto_scale_workers(cfg, world_size)

    def train_dataloader(self):
        """Return detectron2's training loader built from the config."""
        return build_detection_train_loader(self.cfg)

    def val_dataloader(self):
        """Return one test loader per dataset in ``cfg.DATASETS.TEST``, in order."""
        cfg = self.cfg
        return [build_detection_test_loader(cfg, name) for name in cfg.DATASETS.TEST]
def main(args):
    """Build the config from the parsed ``args`` and hand it to ``train``."""
    train(setup(args), args)
def train(cfg, args):
    """Configure a ``pl.Trainer`` from ``cfg``/``args`` and run fit or validate.

    With ``args.eval_only`` only validation is run; otherwise training runs.
    """
    # Large enough to never be reached, effectively disabling the limit.
    huge_limit = 10 ** 8
    eval_period = cfg.TEST.EVAL_PERIOD
    trainer_params = dict(
        # training loop is bounded by max steps, use a large max_epochs to make
        # sure max_steps is met first
        max_epochs=huge_limit,
        max_steps=cfg.SOLVER.MAX_ITER,
        val_check_interval=eval_period if eval_period > 0 else huge_limit,
        num_nodes=args.num_machines,
        gpus=args.num_gpus,
        num_sanity_val_steps=0,
    )
    if cfg.SOLVER.AMP.ENABLED:
        trainer_params["precision"] = 16

    last_checkpoint = os.path.join(cfg.OUTPUT_DIR, "last.ckpt")
    if args.resume:
        # resume training from checkpoint
        trainer_params["resume_from_checkpoint"] = last_checkpoint
        logger.info(f"Resuming training from checkpoint: {last_checkpoint}.")

    trainer = pl.Trainer(**trainer_params)
    logger.info(f"start to train with {args.num_machines} nodes and {args.num_gpus} GPUs")

    module = TrainingModule(cfg)
    data_module = DataModule(cfg)
    if not args.eval_only:
        logger.info("Running training")
        trainer.fit(module, data_module)
        return
    logger.info("Running inference")
    trainer.validate(module, data_module)
def setup(args):
    """
    Build a frozen detectron2 config from the command-line arguments
    (config file first, then ``--opts`` overrides) and run d2's default setup.
    """
    config = get_cfg()
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    default_setup(config, args)
    return config
if __name__ == "__main__":
    parser = default_argument_parser()
    args = parser.parse_args()
    # Lazy %-formatting: passing ``args`` without a placeholder made logging's
    # ``msg % args`` fail and report a "Logging error" instead of the arguments.
    logger.info("Command Line Args: %s", args)
    main(args)