# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, Optional, Union
import numpy as np
import torch
from mmengine.registry import HOOKS
from mmengine.utils import get_git_hash
from mmengine.version import __version__
from .hook import Hook
DATA_BATCH = Optional[Union[dict, tuple, list]]
def _is_scalar(value: Any) -> bool:
"""Determine the value is a scalar type value.
Args:
value (Any): value of log.
Returns:
bool: whether the value is a scalar type value.
"""
if isinstance(value, np.ndarray):
return value.size == 1
elif isinstance(value, (int, float, np.number)):
return True
elif isinstance(value, torch.Tensor):
return value.numel() == 1
return False
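# Illustration of how ``_is_scalar`` classifies common values:
#     _is_scalar(3)                    # True  (int)
#     _is_scalar(np.float32(0.5))      # True  (np.number)
#     _is_scalar(np.zeros(1))          # True  (size == 1)
#     _is_scalar(torch.tensor(1.0))    # True  (numel() == 1)
#     _is_scalar(torch.ones(2))        # False
#     _is_scalar('accuracy')           # False (strings are not scalars here)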
@HOOKS.register_module()
class RuntimeInfoHook(Hook):
"""A hook that updates runtime information into message hub.
E.g. ``epoch``, ``iter``, ``max_epochs``, and ``max_iters`` for the
training state. Components that cannot access the runner can get runtime
information through the message hub.
"""
priority = 'VERY_HIGH'
def before_run(self, runner) -> None:
"""Update metainfo.
Args:
runner (Runner): The runner of the training process.
"""
metainfo = dict(
cfg=runner.cfg.pretty_text,
seed=runner.seed,
experiment_name=runner.experiment_name,
mmengine_version=__version__ + get_git_hash())
runner.message_hub.update_info_dict(metainfo)
self.last_loop_stage = None
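        # Each key in the metainfo dict above is stored as a separate entry
        # (``update_info_dict`` updates the hub key by key), so it can be read
        # back later with, e.g., ``runner.message_hub.get_info('seed')`` or
        # ``runner.message_hub.get_info('experiment_name')``.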
def before_train(self, runner) -> None:
"""Update resumed training state.
Args:
runner (Runner): The runner of the training process.
"""
runner.message_hub.update_info('loop_stage', 'train')
runner.message_hub.update_info('epoch', runner.epoch)
runner.message_hub.update_info('iter', runner.iter)
runner.message_hub.update_info('max_epochs', runner.max_epochs)
runner.message_hub.update_info('max_iters', runner.max_iters)
if hasattr(runner.train_dataloader.dataset, 'metainfo'):
runner.message_hub.update_info(
'dataset_meta', runner.train_dataloader.dataset.metainfo)
def after_train(self, runner) -> None:
runner.message_hub.pop_info('loop_stage')
def before_train_epoch(self, runner) -> None:
"""Update current epoch information before every epoch.
Args:
runner (Runner): The runner of the training process.
"""
runner.message_hub.update_info('epoch', runner.epoch)
def before_train_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None) -> None:
"""Update current iter and learning rate information before every
iteration.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from the
                dataloader. Defaults to None.
"""
runner.message_hub.update_info('iter', runner.iter)
lr_dict = runner.optim_wrapper.get_lr()
        assert isinstance(lr_dict, dict), (
            '`runner.optim_wrapper.get_lr()` should return a dict of '
            'learning rates when training with an OptimWrapper (single '
            'optimizer) or an OptimWrapperDict (multiple optimizers), '
            f'but got {type(lr_dict)}. Please check that your optimizer '
            'constructor returns an `OptimWrapper` or `OptimWrapperDict` '
            'instance.')
for name, lr in lr_dict.items():
runner.message_hub.update_scalar(f'train/{name}', lr[0])
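        # Note: ``get_lr()`` returns a dict mapping a name to a list with one
        # learning rate per parameter group, e.g. ``{'lr': [0.01]}`` for a
        # single OptimWrapper; an OptimWrapperDict prefixes the keys per
        # optimizer. Only the first parameter group's value is logged above.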
def after_train_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[dict] = None) -> None:
"""Update ``log_vars`` in model outputs every iteration.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from the
                dataloader. Defaults to None.
outputs (dict, optional): Outputs from model. Defaults to None.
"""
if outputs is not None:
for key, value in outputs.items():
if key.startswith('vis_'):
continue
runner.message_hub.update_scalar(f'train/{key}', value)
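        # ``outputs`` holds the ``log_vars`` returned by the model's train
        # step, e.g. ``{'loss': 0.73, 'loss_cls': 0.41}``; every entry is
        # logged as ``train/<key>`` except keys prefixed with ``vis_``, which
        # are skipped here (typically visualization outputs, not scalars).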
def before_val(self, runner) -> None:
self.last_loop_stage = runner.message_hub.get_info('loop_stage')
runner.message_hub.update_info('loop_stage', 'val')
def after_val_epoch(self,
runner,
metrics: Optional[Dict[str, float]] = None) -> None:
"""All subclasses should override this method, if they need any
operations after each validation epoch.
Args:
runner (Runner): The runner of the validation process.
metrics (Dict[str, float], optional): Evaluation results of all
metrics on validation dataset. The keys are the names of the
metrics, and the values are corresponding results.
"""
if metrics is not None:
for key, value in metrics.items():
if _is_scalar(value):
runner.message_hub.update_scalar(f'val/{key}', value)
else:
runner.message_hub.update_info(f'val/{key}', value)
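        # Scalar metrics (e.g. ``{'accuracy': 93.1}``) go to ``update_scalar``
        # so loggers can plot them as ``val/accuracy``; non-scalar values such
        # as a confusion matrix are stored as plain info under the same
        # ``val/<key>`` name.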
def after_val(self, runner) -> None:
        # The ValLoop may be called within the TrainLoop, so we restore the
        # previous loop_stage instead of popping it.
        # Workflow: before_train -> before_val -> after_val -> after_train
if self.last_loop_stage == 'train':
runner.message_hub.update_info('loop_stage', self.last_loop_stage)
self.last_loop_stage = None
else:
runner.message_hub.pop_info('loop_stage')
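        # For a standalone ``runner.val()`` call, ``last_loop_stage`` is not
        # 'train', so the 'loop_stage' entry is simply popped here.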
def before_test(self, runner) -> None:
runner.message_hub.update_info('loop_stage', 'test')
def after_test(self, runner) -> None:
runner.message_hub.pop_info('loop_stage')
def after_test_epoch(self,
runner,
metrics: Optional[Dict[str, float]] = None) -> None:
"""All subclasses should override this method, if they need any
operations after each test epoch.
Args:
runner (Runner): The runner of the testing process.
metrics (Dict[str, float], optional): Evaluation results of all
metrics on test dataset. The keys are the names of the
metrics, and the values are corresponding results.
"""
if metrics is not None:
for key, value in metrics.items():
if _is_scalar(value):
runner.message_hub.update_scalar(f'test/{key}', value)
else:
runner.message_hub.update_info(f'test/{key}', value)