Spaces:
Paused
Paused
# Copyright (c) OpenMMLab. All rights reserved. | |
import os.path as osp | |
import platform | |
import shutil | |
import time | |
import warnings | |
import torch | |
import annotator.uniformer.mmcv as mmcv | |
from .base_runner import BaseRunner | |
from .builder import RUNNERS | |
from .checkpoint import save_checkpoint | |
from .utils import get_host_info | |
# NOTE(review): ``RUNNERS`` is imported by this file but never used in the
# visible code. If this runner is meant to be built from a config through the
# registry it probably needs ``@RUNNERS.register_module()`` -- confirm.
class EpochBasedRunner(BaseRunner):
    """Epoch-based Runner.

    This runner train models epoch by epoch.
    """

    def run_iter(self, data_batch, train_mode, **kwargs):
        """Process one batch and store the result in ``self.outputs``.

        ``self.batch_processor`` is used when it is set; otherwise the batch
        is handed to ``model.train_step`` or ``model.val_step`` depending on
        ``train_mode``.

        Args:
            data_batch: One batch yielded by the current data loader.
            train_mode (bool): Whether the batch is processed for training.
            **kwargs: Forwarded to the batch processor / model step.

        Raises:
            TypeError: If the processing call does not return a dict.
            KeyError: If the outputs contain ``'log_vars'`` but not
                ``'num_samples'``.
        """
        if self.batch_processor is not None:
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=train_mode, **kwargs)
        elif train_mode:
            outputs = self.model.train_step(data_batch, self.optimizer,
                                            **kwargs)
        else:
            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('"batch_processor()" or "model.train_step()" '
                            'and "model.val_step()" must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs

    def train(self, data_loader, **kwargs):
        """Train for one epoch over ``data_loader``.

        Fires the ``before/after_train_epoch`` and ``before/after_train_iter``
        hooks, advances ``self._iter`` once per batch and ``self._epoch`` once
        at the end of the epoch.

        Args:
            data_loader: Iterable (with ``len``) yielding training batches.
            **kwargs: Forwarded to :meth:`run_iter`.
        """
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        """Validate for one epoch over ``data_loader``.

        Runs with autograd disabled, since validation does not update the
        model and building a graph only costs memory. Fires the
        ``before/after_val_epoch`` and ``before/after_val_iter`` hooks.
        Neither ``self._iter`` nor ``self._epoch`` is advanced.

        Args:
            data_loader: Iterable yielding validation batches.
            **kwargs: Accepted so that :meth:`run` can call every phase the
                same way; they are not forwarded to :meth:`run_iter`.
        """
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            self.run_iter(data_batch, train_mode=False)
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')

    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g, [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.
            max_epochs (int, optional): Deprecated. Overrides the configured
                maximum number of epochs when given. Defaults to None.
            **kwargs: Forwarded to the method that runs each phase.

        Raises:
            ValueError: If a phase names a method the runner does not have.
            TypeError: If a phase name is not a str.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        if max_epochs is not None:
            warnings.warn(
                'setting max_epochs in run is deprecated, '
                'please set max_epochs in runner_config', DeprecationWarning)
            self._max_epochs = max_epochs

        assert self._max_epochs is not None, (
            'max_epochs must be specified during instantiation')

        # The iteration budget is derived from the first 'train' phase only.
        for i, flow in enumerate(workflow):
            mode, _ = flow
            if mode == 'train':
                self._max_iters = self._max_epochs * len(data_loaders[i])
                break

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('Hooks will be executed in the following order:\n%s',
                         self.get_hook_info())
        self.logger.info('workflow: %s, max: %d epochs', workflow,
                         self._max_epochs)
        self.call_hook('before_run')

        # NOTE(review): within this class only ``train`` advances
        # ``self._epoch``; a workflow without a 'train' phase would presumably
        # never leave this loop -- confirm against ``BaseRunner.epoch``.
        while self.epoch < self._max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            'epoch')
                    epoch_runner = getattr(self, mode)
                else:
                    raise TypeError(
                        'mode in workflow must be a str, but got {}'.format(
                            type(mode)))

                for _ in range(epochs):
                    # Stop training as soon as the epoch budget is used up,
                    # even in the middle of a multi-epoch phase.
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        """Save the checkpoint.

        Args:
            out_dir (str): The directory that checkpoints are saved.
            filename_tmpl (str, optional): The checkpoint filename template,
                which contains a placeholder for the epoch number.
                Defaults to 'epoch_{}.pth'.
            save_optimizer (bool, optional): Whether to save the optimizer to
                the checkpoint. Defaults to True.
            meta (dict, optional): The meta information to be saved in the
                checkpoint. The passed dict is not modified.
                Defaults to None.
            create_symlink (bool, optional): Whether to create a symlink
                "latest.pth" to point to the latest checkpoint.
                Defaults to True.

        Raises:
            TypeError: If ``meta`` is neither a dict nor None.
        """
        if meta is None:
            meta = {}
        elif not isinstance(meta, dict):
            raise TypeError(
                f'meta should be a dict or None, but got {type(meta)}')
        else:
            # Work on a copy so the caller's dict is not polluted with the
            # runner meta and the epoch/iter counters added below.
            meta = meta.copy()
        if self.meta is not None:
            meta.update(self.meta)
            # Note: meta.update(self.meta) should be done before
            # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise
            # there will be problems with resumed checkpoints.
            # More details in https://github.com/open-mmlab/mmcv/pull/1108
        meta.update(epoch=self.epoch + 1, iter=self.iter)

        filename = filename_tmpl.format(self.epoch + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # in some environments, `os.symlink` is not supported, you may need to
        # set `create_symlink` to False
        if create_symlink:
            dst_file = osp.join(out_dir, 'latest.pth')
            if platform.system() != 'Windows':
                mmcv.symlink(filename, dst_file)
            else:
                # Fall back to a plain copy of the checkpoint on Windows.
                shutil.copy(filepath, dst_file)
class Runner(EpochBasedRunner):
    """Deprecated name of EpochBasedRunner."""

    def __init__(self, *args, **kwargs):
        """Warn about the deprecated name, then build an EpochBasedRunner.

        All positional and keyword arguments are passed through unchanged to
        the parent constructor.
        """
        # Use the DeprecationWarning category, matching how the deprecated
        # ``max_epochs`` argument of ``EpochBasedRunner.run`` is reported.
        warnings.warn(
            'Runner was deprecated, please use EpochBasedRunner instead',
            DeprecationWarning)
        super().__init__(*args, **kwargs)