import os.path as osp
import platform
import shutil
import time
import warnings

import torch
from torch.optim import Optimizer

import annotator.uniformer.mmcv as mmcv
from .base_runner import BaseRunner
from .builder import RUNNERS
from .checkpoint import save_checkpoint
from .hooks import IterTimerHook
from .utils import get_host_info

class IterLoader:
    """Wrap a ``DataLoader`` so that batches can be drawn indefinitely.

    When the underlying dataloader is exhausted, it is re-created and the
    internal epoch counter is advanced (and propagated to the sampler if it
    supports ``set_epoch``).
    """

    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._epoch = 0

    @property
    def epoch(self):
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            self._epoch += 1
            if hasattr(self._dataloader.sampler, 'set_epoch'):
                self._dataloader.sampler.set_epoch(self._epoch)
            time.sleep(2)  # Prevent possible deadlock during epoch transition
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)

        return data

    def __len__(self):
        return len(self._dataloader)
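
# A minimal usage sketch (``dataset`` is a hypothetical ``Dataset``), kept
# as comments so this module stays import-safe: ``IterLoader`` turns a
# finite dataloader into an endless batch source, which is what the
# iteration-driven loop in ``IterBasedRunner.run`` expects.
#
#   loader = IterLoader(torch.utils.data.DataLoader(dataset, batch_size=2))
#   batch = next(loader)  # transparently restarts the loader at epoch end
#   assert loader.epoch == 0  # bumped each time the loader wraps around
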
@RUNNERS.register_module()
class IterBasedRunner(BaseRunner):
    """Iteration-based Runner.

    This runner trains models iteration by iteration.
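
    Examples:
        >>> # A sketch with assumed pre-built objects; ``max_iters`` is a
        >>> # ``BaseRunner`` argument and must be set at instantiation.
        >>> runner = IterBasedRunner(
        ...     model, optimizer=optimizer, work_dir='work_dirs/demo',
        ...     logger=logger, max_iters=10000)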
    """

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        # keep the runner's epoch in sync with the wrapped IterLoader
        self._epoch = data_loader.epoch
        data_batch = next(data_loader)
        self.call_hook('before_train_iter')
        outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('model.train_step() must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'],
                                   outputs['num_samples'])
        self.outputs = outputs
        self.call_hook('after_train_iter')
        self._inner_iter += 1
        self._iter += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        data_batch = next(data_loader)
        self.call_hook('before_val_iter')
        outputs = self.model.val_step(data_batch, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('model.val_step() must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'],
                                   outputs['num_samples'])
        self.outputs = outputs
        self.call_hook('after_val_iter')
        self._inner_iter += 1

    def run(self, data_loaders, workflow, max_iters=None, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, iters) pairs specifying
                the running order and number of iterations. E.g.,
                [('train', 10000), ('val', 1000)] means running 10000
                iterations of training and 1000 iterations of validation,
                alternately.
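
        Examples:
            >>> # A sketch; ``train_loader``/``val_loader`` are assumed
            >>> # pre-built dataloaders on an existing runner.
            >>> workflow = [('train', 10000), ('val', 1000)]
            >>> runner.run([train_loader, val_loader], workflow)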
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        if max_iters is not None:
            warnings.warn(
                'setting max_iters in run is deprecated, '
                'please set max_iters in runner_config', DeprecationWarning)
            self._max_iters = max_iters
        assert self._max_iters is not None, (
            'max_iters must be specified during instantiation')

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('Hooks will be executed in the following order:\n%s',
                         self.get_hook_info())
        self.logger.info('workflow: %s, max: %d iters', workflow,
                         self._max_iters)
        self.call_hook('before_run')

        iter_loaders = [IterLoader(x) for x in data_loaders]

        self.call_hook('before_epoch')

        while self.iter < self._max_iters:
            for i, flow in enumerate(workflow):
                self._inner_iter = 0
                mode, iters = flow
                if not isinstance(mode, str) or not hasattr(self, mode):
                    raise ValueError(
                        'runner has no method named "{}" to run a workflow'.
                        format(mode))
                iter_runner = getattr(self, mode)
                for _ in range(iters):
                    if mode == 'train' and self.iter >= self._max_iters:
                        break
                    iter_runner(iter_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_epoch')
        self.call_hook('after_run')

    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        """Resume model from checkpoint.

        Args:
            checkpoint (str): Checkpoint to resume from.
            resume_optimizer (bool, optional): Whether to resume the
                optimizer(s) if the checkpoint file includes optimizer(s).
                Defaults to True.
            map_location (str, optional): Same as :func:`torch.load`.
                Defaults to 'default'.
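
        Examples:
            >>> # A sketch with a hypothetical path following the default
            >>> # ``iter_{}.pth`` template used by ``save_checkpoint``.
            >>> runner.resume('work_dirs/demo/iter_10000.pth')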
        """
        if map_location == 'default':
            device_id = torch.cuda.current_device()
            checkpoint = self.load_checkpoint(
                checkpoint,
                map_location=lambda storage, loc: storage.cuda(device_id))
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        self._inner_iter = checkpoint['meta']['iter']
        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='iter_{}.pth',
                        meta=None,
                        save_optimizer=True,
                        create_symlink=True):
        """Save checkpoint to file.

        Args:
            out_dir (str): Directory to save checkpoint files.
            filename_tmpl (str, optional): Checkpoint file template.
                Defaults to 'iter_{}.pth'.
            meta (dict, optional): Metadata to be saved in checkpoint.
                Defaults to None.
            save_optimizer (bool, optional): Whether to save the optimizer.
                Defaults to True.
            create_symlink (bool, optional): Whether to create a symlink to
                the latest checkpoint file. Defaults to True.
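
        Examples:
            >>> # A sketch with a hypothetical output directory; at
            >>> # ``self.iter == 999`` this writes ``iter_1000.pth`` and
            >>> # points ``latest.pth`` at it.
            >>> runner.save_checkpoint('work_dirs/demo')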
        """
        if meta is None:
            meta = {}
        elif not isinstance(meta, dict):
            raise TypeError(
                f'meta should be a dict or None, but got {type(meta)}')
        if self.meta is not None:
            meta.update(self.meta)
            # Note: meta.update(self.meta) should be done before
            # meta.update(epoch=self.epoch + 1, iter=self.iter), otherwise
            # the resumed epoch/iter would be overwritten by stale values.
        meta.update(epoch=self.epoch + 1, iter=self.iter)

        filename = filename_tmpl.format(self.iter + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)

        # use a relative symlink on non-Windows platforms; fall back to a
        # plain copy where symlinks are unsupported
        if create_symlink:
            dst_file = osp.join(out_dir, 'latest.pth')
            if platform.system() != 'Windows':
                mmcv.symlink(filename, dst_file)
            else:
                shutil.copy(filepath, dst_file)

    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None,
                                momentum_config=None,
                                custom_hooks_config=None):
        """Register default hooks for iter-based training.

        Checkpoint hook, optimizer stepper hook and logger hooks will be
        set to `by_epoch=False` by default.

        Default hooks include:

        +----------------------+-------------------------+
        | Hooks                | Priority                |
        +======================+=========================+
        | LrUpdaterHook        | VERY_HIGH (10)          |
        +----------------------+-------------------------+
        | MomentumUpdaterHook  | HIGH (30)               |
        +----------------------+-------------------------+
        | OptimizerStepperHook | ABOVE_NORMAL (40)       |
        +----------------------+-------------------------+
        | CheckpointSaverHook  | NORMAL (50)             |
        +----------------------+-------------------------+
        | IterTimerHook        | LOW (70)                |
        +----------------------+-------------------------+
        | LoggerHook(s)        | VERY_LOW (90)           |
        +----------------------+-------------------------+
        | CustomHook(s)        | defaults to NORMAL (50) |
        +----------------------+-------------------------+

        If custom hooks have the same priority as default hooks, custom hooks
        will be triggered after default hooks.
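
        Examples:
            >>> # A sketch with assumed config dicts following mmcv's hook
            >>> # config conventions; ``interval`` values are counted in
            >>> # iterations since ``by_epoch`` defaults to False here.
            >>> runner.register_training_hooks(
            ...     lr_config=dict(policy='fixed'),
            ...     checkpoint_config=dict(interval=1000),
            ...     log_config=dict(
            ...         interval=50, hooks=[dict(type='TextLoggerHook')]))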
        """
        if checkpoint_config is not None:
            checkpoint_config.setdefault('by_epoch', False)
        if lr_config is not None:
            lr_config.setdefault('by_epoch', False)
        if log_config is not None:
            for info in log_config['hooks']:
                info.setdefault('by_epoch', False)
        super(IterBasedRunner, self).register_training_hooks(
            lr_config=lr_config,
            momentum_config=momentum_config,
            optimizer_config=optimizer_config,
            checkpoint_config=checkpoint_config,
            log_config=log_config,
            timer_config=IterTimerHook(),
            custom_hooks_config=custom_hooks_config)