# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, Optional, Union | |
import numpy as np | |
import torch | |
from mmengine.registry import HOOKS | |
from mmengine.utils import get_git_hash | |
from mmengine.version import __version__ | |
from .hook import Hook | |
# Alias for the ``data_batch`` argument of the train-iter hooks: data from
# the dataloader, or ``None`` when no batch accompanies the call.
# ``Union[..., None]`` is the normalized form of ``Optional[Union[...]]``.
DATA_BATCH = Union[dict, tuple, list, None]
def _is_scalar(value: Any) -> bool: | |
"""Determine the value is a scalar type value. | |
Args: | |
value (Any): value of log. | |
Returns: | |
bool: whether the value is a scalar type value. | |
""" | |
if isinstance(value, np.ndarray): | |
return value.size == 1 | |
elif isinstance(value, (int, float, np.number)): | |
return True | |
elif isinstance(value, torch.Tensor): | |
return value.numel() == 1 | |
return False | |
class RuntimeInfoHook(Hook):
    """A hook that updates runtime information into message hub.

    E.g. ``epoch``, ``iter``, ``max_epochs``, and ``max_iters`` for the
    training state. Components that cannot access the runner can get runtime
    information through the message hub.
    """

    priority = 'VERY_HIGH'

    def before_run(self, runner) -> None:
        """Record static experiment metadata in the message hub.

        Args:
            runner (Runner): The runner of the training process.
        """
        runner.message_hub.update_info_dict(
            dict(
                cfg=runner.cfg.pretty_text,
                seed=runner.seed,
                experiment_name=runner.experiment_name,
                mmengine_version=__version__ + get_git_hash()))
        # Remembers which loop (if any) was active before a nested ValLoop.
        self.last_loop_stage = None

    def before_train(self, runner) -> None:
        """Push the (possibly resumed) training state into the message hub.

        Args:
            runner (Runner): The runner of the training process.
        """
        training_state = {
            'loop_stage': 'train',
            'epoch': runner.epoch,
            'iter': runner.iter,
            'max_epochs': runner.max_epochs,
            'max_iters': runner.max_iters,
        }
        for info_key, info_value in training_state.items():
            runner.message_hub.update_info(info_key, info_value)
        dataset = runner.train_dataloader.dataset
        if hasattr(dataset, 'metainfo'):
            runner.message_hub.update_info('dataset_meta', dataset.metainfo)

    def after_train(self, runner) -> None:
        """Drop the training loop-stage marker once training finishes."""
        runner.message_hub.pop_info('loop_stage')

    def before_train_epoch(self, runner) -> None:
        """Refresh the current epoch number before every epoch.

        Args:
            runner (Runner): The runner of the training process.
        """
        runner.message_hub.update_info('epoch', runner.epoch)

    def before_train_iter(self,
                          runner,
                          batch_idx: int,
                          data_batch: DATA_BATCH = None) -> None:
        """Refresh the iteration counter and learning rates before every
        iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (Sequence[dict], optional): Data from dataloader.
                Defaults to None.
        """
        runner.message_hub.update_info('iter', runner.iter)
        lr_dict = runner.optim_wrapper.get_lr()
        assert isinstance(lr_dict, dict), (
            '`runner.optim_wrapper.get_lr()` should return a dict '
            'of learning rate when training with OptimWrapper(single '
            'optimizer) or OptimWrapperDict(multiple optimizer), '
            f'but got {type(lr_dict)} please check your optimizer '
            'constructor return an `OptimWrapper` or `OptimWrapperDict` '
            'instance')
        for name, lr in lr_dict.items():
            # Only the first parameter group's lr is logged per entry.
            runner.message_hub.update_scalar(f'train/{name}', lr[0])

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Copy the model's per-iteration outputs into the message hub.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (Sequence[dict], optional): Data from dataloader.
                Defaults to None.
            outputs (dict, optional): Outputs from model. Defaults to None.
        """
        if outputs is None:
            return
        for key, value in outputs.items():
            # 'vis_'-prefixed entries are skipped rather than logged as
            # scalars.
            if not key.startswith('vis_'):
                runner.message_hub.update_scalar(f'train/{key}', value)

    def before_val(self, runner) -> None:
        """Remember the enclosing loop stage, then switch the stage to
        'val'."""
        self.last_loop_stage = runner.message_hub.get_info('loop_stage')
        runner.message_hub.update_info('loop_stage', 'val')

    def after_val_epoch(self,
                        runner,
                        metrics: Optional[Dict[str, float]] = None) -> None:
        """Record validation metrics in the message hub.

        Args:
            runner (Runner): The runner of the validation process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on validation dataset. The keys are the names of the
                metrics, and the values are corresponding results.
        """
        # Scalar-like values are logged as scalars; everything else is
        # stored as plain runtime info.
        for key, value in (metrics or {}).items():
            if _is_scalar(value):
                runner.message_hub.update_scalar(f'val/{key}', value)
            else:
                runner.message_hub.update_info(f'val/{key}', value)

    def after_val(self, runner) -> None:
        """Restore or clear the loop stage after validation.

        ValLoop may be called within the TrainLoop
        (before_train -> before_val -> after_val -> after_train), in which
        case the 'train' stage must be restored instead of popped.
        """
        if self.last_loop_stage != 'train':
            runner.message_hub.pop_info('loop_stage')
        else:
            runner.message_hub.update_info('loop_stage', self.last_loop_stage)
            self.last_loop_stage = None

    def before_test(self, runner) -> None:
        """Mark the loop stage as 'test'."""
        runner.message_hub.update_info('loop_stage', 'test')

    def after_test(self, runner) -> None:
        """Drop the test loop-stage marker."""
        runner.message_hub.pop_info('loop_stage')

    def after_test_epoch(self,
                        runner,
                        metrics: Optional[Dict[str, float]] = None) -> None:
        """Record test metrics in the message hub.

        Args:
            runner (Runner): The runner of the testing process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on test dataset. The keys are the names of the
                metrics, and the values are corresponding results.
        """
        if metrics is None:
            return
        for metric_name, metric_value in metrics.items():
            if _is_scalar(metric_value):
                runner.message_hub.update_scalar(f'test/{metric_name}',
                                                 metric_value)
            else:
                runner.message_hub.update_info(f'test/{metric_name}',
                                               metric_value)